1. Design Goals
Building on Dual Volume (Cost Volume + Pol Volume complementarity) and two-stage training, this architecture provides two core mechanisms:
- Context-Modulated FNet (FiLM): lets the context (which carries the glass information produced by the Pol CNet) modulate the FNet’s feature extraction via FiLM, so that the Correlation Volume becomes meaningful even on glass regions.
- Knowledge Transfer fix: changes the Pol CNet’s pretrain input from [GT_mask, zeros] to [soft_mask, edges], whose statistical distribution is much closer to the finetune input, enabling effective transfer of weights from Pretrain to Finetune.
The core problem being solved is that photometric inconsistency on glass distorts the features extracted by the FNet, causing the downstream Correlation Volume to produce garbage signals. This architecture intervenes during feature extraction so that the FNet, modulated by context, extracts features that are meaningful for matching even on glass regions. In addition, if the statistical distributions of the Pol CNet inputs in Pretrain and Finetune differ too much, the weights learned in Pretrain cannot transfer effectively. The fix is therefore to bring the Pretrain input format close to the Finetune format from the start.
Furthermore, the timing of the Pol Volume computation is explicitly defined as Phase 0 (the earliest), ensuring that it reflects the raw polarization difference and is not contaminated by context modulation.
2. Architecture (with Data Flow)
Five-Phase Data Flow
| Phase | Name | Highlights |
|---|---|---|
| Phase 0 | Pol Volume computed first | From raw images, `pol = 1 - \|L - R\|` |
| Phase 1 | Context Encoding | RGB CNet + Pol CNet → Context Fusion → context |
| Phase 2 | Context-Modulated Feature Extraction | FiLM: features' = γ * features + β |
| Phase 3 | Correlation Volume | Build CorrBlock from the modulated fmap |
| Phase 4 | GRU iteration (Dual Volume) | Dual-volume lookup of corr_feat + pol_feat |
3. Components and Modules (with Code)
3.1 PolVolumeBlock
Design rationale:
- The Pol Volume must be computed first, since it derives from the raw images (raw polarization difference).
- It cannot be built after the FNet, otherwise it would be affected by context modulation.
- The semantic design is aligned with that of the Correlation Volume (high value = good).
Core code:
import torch
import torch.nn.functional as F


class PolVolumeBlock:
    """
    Pol Volume computation and lookup

    Design considerations (semantically aligned with the Correlation Volume):
    - Correlation Volume: dot product; high value = similar features = good match
    - Pol Volume: polarization consistency; high value = polarization consistent
      (non-glass), low value = large polarization difference (glass)

    Benefits of this design:
    - Non-glass region: Corr high + Pol high -> both signals agree, GRU uses
      them confidently
    - Glass region: Corr low (garbage) + Pol low -> both signals say
      "something is wrong here"
    - The GRU can simply concat the two volumes without any extra fusion logic
    """

    def __init__(self, img_left: torch.Tensor, img_right: torch.Tensor,
                 num_levels: int = 4, radius: int = 4):
        B, C, H, W = img_left.shape
        # Convert to grayscale
        if C == 3:
            weights = torch.tensor([0.299, 0.587, 0.114], ...)
            left_gray = (img_left * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
            right_gray = (img_right * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
        else:
            left_gray = img_left
            right_gray = img_right
        # Downsample to 1/4 (aligned with the fmap resolution of CorrBlock)
        self.left_small = F.avg_pool2d(left_gray, kernel_size=4, stride=4)
        self.right_small = F.avg_pool2d(right_gray, kernel_size=4, stride=4)
        # Record maximum intensity for normalization
        self.max_intensity = max(self.left_small.max(), self.right_small.max())
        self.radius = radius

    def __call__(self, disp: torch.Tensor) -> torch.Tensor:
        """
        Lookup the Pol Volume given the current disparity estimate.

        Returns:
            pol_feat: [B, 2*radius+1, H/4, W/4] polarization consistency feature
        """
        B, _, H, W = self.left_small.shape
        # NOTE: this division assumes disp is expressed in full-resolution pixels.
        # In ContextModulatedStereo.forward, disp already lives at 1/4 scale
        # (it is multiplied by 4 when upsampled), so the division may need to be
        # dropped there.
        disp_small = disp / 4.0  # adjust to 1/4 resolution
        pol_samples = []
        dx = torch.linspace(-self.radius, self.radius, 2 * self.radius + 1)
        for offset in dx:
            # Compute sampling position
            disp_offset = disp_small + offset
            grid_x = ...  # omitted; see full code
            grid = torch.stack([grid_x, grid_y], dim=-1)
            # Sample right image
            right_sampled = F.grid_sample(self.right_small, grid,
                                          mode='bilinear',
                                          padding_mode='zeros',
                                          align_corners=True)
            # Polarization difference
            pol_diff = torch.abs(self.left_small - right_sampled)
            # Normalize to [0, 1]
            pol_diff_norm = pol_diff / (self.max_intensity + 1e-6)
            # Invert the semantics: high value = polarization consistent
            # (aligned with Correlation semantics)
            pol_consistency = 1.0 - pol_diff_norm
            pol_samples.append(pol_consistency)
        # Stack: [B, 2*radius+1, H, W] (H, W are the 1/4-resolution sizes here)
        return torch.cat(pol_samples, dim=1)
Dimension alignment:
| Volume | Input resolution | Output channels | Semantics |
|---|---|---|---|
| CorrBlock | H/4 × W/4 | 2×radius+1 = 9 | High = good match |
| PolVolumeBlock | H/4 × W/4 | 2×radius+1 = 9 | High = polarization consistent |
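The lookup listing elides how grid_x and grid_y are built. As a rough illustration only (not the author's full code), a sampling grid for F.grid_sample could be constructed as below. The sketch assumes rectified stereo where a left pixel at column x corresponds to the right pixel at column x - d, and that disp is expressed in pixels at the same resolution as the image being sampled. The helper name warp_right_by_disparity is hypothetical.

```python
import torch
import torch.nn.functional as F


def warp_right_by_disparity(right: torch.Tensor, disp: torch.Tensor) -> torch.Tensor:
    """Sample the right image at x - disp (hypothetical helper).

    right: [B, C, H, W]; disp: [B, 1, H, W], in pixels at this resolution.
    """
    B, _, H, W = right.shape
    ys, xs = torch.meshgrid(
        torch.arange(H, dtype=right.dtype, device=right.device),
        torch.arange(W, dtype=right.dtype, device=right.device),
        indexing="ij",
    )
    # Assumed sign convention: left column x matches right column x - d.
    x_src = xs.unsqueeze(0) - disp[:, 0]            # [B, H, W]
    y_src = ys.unsqueeze(0).expand(B, -1, -1)       # [B, H, W]
    # grid_sample expects coordinates in [-1, 1]; with align_corners=True,
    # -1 and +1 map to the centers of the first and last pixels.
    grid_x = 2.0 * x_src / max(W - 1, 1) - 1.0
    grid_y = 2.0 * y_src / max(H - 1, 1) - 1.0
    grid = torch.stack([grid_x, grid_y], dim=-1)    # [B, H, W, 2]
    return F.grid_sample(right, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=True)


# Toy check: a horizontal ramp shifted by one pixel.
right = torch.arange(8, dtype=torch.float32).view(1, 1, 1, 8).repeat(1, 1, 4, 1)
unshifted = warp_right_by_disparity(right, torch.zeros(1, 1, 4, 8))
shifted = warp_right_by_disparity(right, torch.ones(1, 1, 4, 8))
```

With padding_mode='zeros', positions that fall outside the right image contribute zeros, so the polarization difference there equals the left intensity.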
3.2 UpdateBlock
Structure:
- Includes a pol_dim parameter.
- motion_input is [corr_feat, pol_feat, disp].
Core code:
import torch
import torch.nn as nn


class UpdateBlock(nn.Module):
    """
    GRU update module

    Inputs:
    - hidden: GRU hidden state [B, hidden_dim, H/4, W/4]
    - context: fused context [B, context_dim, H/4, W/4]
    - corr_feat: Correlation Volume lookup [B, corr_dim, H/4, W/4]
    - pol_feat: Pol Volume lookup [B, pol_dim, H/4, W/4]
    - disp: current disparity estimate [B, 1, H/4, W/4]
    """

    def __init__(self, hidden_dim: int = 128, context_dim: int = 64,
                 corr_dim: int = 9, pol_dim: int = 9):
        super().__init__()
        # Motion encoder: corr + pol + disp
        input_dim = corr_dim + pol_dim + 1  # 9 + 9 + 1 = 19
        self.motion_encoder = nn.Sequential(
            nn.Conv2d(input_dim, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
        )
        # GRU
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=64 + context_dim)
        # Disparity head
        self.disp_head = nn.Sequential(
            nn.Conv2d(hidden_dim, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, 3, padding=1),
        )

    def forward(self, hidden, context, corr_feat, pol_feat, disp):
        # Dual Volume fusion (concat)
        motion_input = torch.cat([corr_feat, pol_feat, disp], dim=1)
        motion_features = self.motion_encoder(motion_input)
        # GRU update
        gru_input = torch.cat([motion_features, context], dim=1)
        hidden = self.gru(hidden, gru_input)
        # Predict disparity residual
        delta_disp = self.disp_head(hidden)
        return hidden, delta_disp
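UpdateBlock calls ConvGRU(hidden_dim=..., input_dim=...) but the listing does not define it. The following is a minimal sketch of what such a cell might look like, assuming the common convolutional GRU form used in RAFT-style update blocks. The 3×3 kernel size and the gate layout are assumptions, not taken from the original code.

```python
import torch
import torch.nn as nn


class ConvGRU(nn.Module):
    """Convolutional GRU cell (assumed form; not from the original listing)."""

    def __init__(self, hidden_dim: int = 128, input_dim: int = 128):
        super().__init__()
        in_ch = hidden_dim + input_dim
        self.convz = nn.Conv2d(in_ch, hidden_dim, 3, padding=1)  # update gate
        self.convr = nn.Conv2d(in_ch, hidden_dim, 3, padding=1)  # reset gate
        self.convq = nn.Conv2d(in_ch, hidden_dim, 3, padding=1)  # candidate state

    def forward(self, h: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
        # Blend the previous state and the candidate state per pixel.
        return (1 - z) * h + z * q


# Shapes follow the UpdateBlock defaults: 64 motion channels + context_dim 64.
gru = ConvGRU(hidden_dim=128, input_dim=64 + 64)
h = torch.zeros(2, 128, 8, 16)
x = torch.randn(2, 128, 8, 16)
h_new = gru(h, x)
```

The hidden state keeps its shape across iterations, which is what lets the GRU loop in Phase 4 reuse the same update block at every step.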
3.3 ContextModulatedFNet (FiLM)
Design rationale:
- The context knows “where the glass is” (information from the Pol CNet).
- FiLM lets the FNet adjust its feature extraction according to the context.
- Different features are extracted over glass regions, making the Correlation Volume meaningful there too.
Core code:
class ContextModulatedFNet(nn.Module):
    """
    Context-Modulated Feature Network

    Uses FiLM (Feature-wise Linear Modulation) to let the context modulate
    feature extraction:
        modulated = gamma * features + beta
    where gamma and beta are generated from the context.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 128,
                 context_dim: int = 64):
        super().__init__()
        # Backbone: extract base features
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels, 3, padding=1),
        )
        # FiLM generator: context -> (gamma, beta)
        self.film_generator = nn.Sequential(
            nn.Conv2d(context_dim, 128, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels * 2, 1),  # gamma + beta
        )

    def forward(self, img: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Base features
        features = self.backbone(img)  # [B, C, H/4, W/4]
        # FiLM parameters
        film_params = self.film_generator(context)  # [B, 2C, H/4, W/4]
        gamma, beta = film_params.chunk(2, dim=1)
        # FiLM modulation
        modulated = gamma * features + beta
        return modulated
Effect of FiLM modulation:
| Region | Context signal | FiLM effect |
|---|---|---|
| Non-glass | pol_ctx ≈ 0 | γ ≈ 1, β ≈ 0 → original feature preserved |
| Glass | pol_ctx high | γ, β adjust → feature modulated |
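The table describes the intended behavior of the modulation. The small sketch below only illustrates the arithmetic: with γ = 1 and β = 0 the features pass through unchanged, while other values rescale and shift them per pixel. The glass mask and the γ/β values are made up for illustration. Whether a trained film_generator actually outputs γ ≈ 1 on non-glass regions depends on training, since the listing does not enforce it.

```python
import torch


def film(features: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """FiLM: element-wise scale and shift."""
    return gamma * features + beta


features = torch.randn(1, 4, 2, 2)

# gamma = 1, beta = 0 leaves the features untouched.
identity = film(features, torch.ones_like(features), torch.zeros_like(features))

# Hypothetical per-pixel glass indicator: only pixel (0, 1) is glass.
glass = torch.tensor([[0.0, 1.0], [0.0, 0.0]]).view(1, 1, 2, 2)
gamma = 1.0 - 0.5 * glass   # illustrative values only
beta = 0.1 * glass
modulated = film(features, gamma, beta)  # broadcasts over the channel axis
```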
3.4 ContextModulatedStereo.forward()
Core code:
class ContextModulatedStereo(nn.Module):
    def forward(self, left, right, pol_input, iters=None, test_mode=False):
        if iters is None:
            iters = self.default_iters

        # ============================================================
        # Phase 0: Pol Volume (earliest! from raw images)
        # ============================================================
        pol_block = PolVolumeBlock(left, right, radius=self.pol_radius)

        # ============================================================
        # Phase 1: Context Encoding
        # ============================================================
        # RGB Context
        rgb_ctx, hidden = self.rgb_cnet(left)
        # Pol Context (input depends on the training stage)
        pol_ctx = self.pol_cnet(pol_input)
        # Context Fusion
        context = self.context_fusion(rgb_ctx, pol_ctx)

        # ============================================================
        # Phase 2: Context-Modulated Feature Extraction
        # ============================================================
        fmap_left = self.fnet(left, context)    # FiLM modulation!
        fmap_right = self.fnet(right, context)  # same context

        # ============================================================
        # Phase 3: Correlation Volume
        # ============================================================
        corr_block = CorrBlock(fmap_left, fmap_right,
                               num_levels=self.corr_levels,
                               radius=self.corr_radius)

        # ============================================================
        # Phase 4: GRU iteration (Dual Volume)
        # ============================================================
        B, _, H, W = left.shape
        disp = torch.zeros(B, 1, H // 4, W // 4, device=left.device)
        disp_predictions = []
        for _ in range(iters):
            # Dual Volume lookup
            corr_feat = corr_block(disp)  # [B, 9, H/4, W/4]
            pol_feat = pol_block(disp)    # [B, 9, H/4, W/4]
            # GRU update
            hidden, delta_disp = self.update_block(
                hidden, context, corr_feat, pol_feat, disp
            )
            # Update disparity
            disp = disp + delta_disp
            # Upsample to the original resolution
            disp_up = 4 * F.interpolate(
                disp, size=(H, W), mode='bilinear', align_corners=True
            )
            disp_predictions.append(disp_up)

        if test_mode:
            return disp_predictions[-1]
        return disp_predictions
4. Knowledge Transfer: Pretrain → Finetune
Problem Analysis
If the Pol CNet inputs of the two stages take the following form:
Pretrain: pol_input = [GT_mask (binary 0/1), zeros]
Finetune: pol_input = [pol_max (continuous 0~1), pol_var (continuous)]
Problem: the statistical distributions differ too much.
- GT_mask: binary, sharp edges.
- pol_max: continuous, with smooth gradients.
- zeros vs pol_var: entirely different signals.
Consequence: weights the Pol CNet learns in Pretrain cannot transfer effectively to Finetune.
Solution: make prepare_pol_input_pretrain output a format whose statistical distribution is closer to Finetune.
Core Code
def prepare_pol_input_pretrain(glass_mask: torch.Tensor,
                               training: bool = True) -> torch.Tensor:
    """
    Pol CNet input for the Pretrain stage

    Design considerations (to ensure Pretrain -> Finetune knowledge transfer):
    - Channel 0: softened GT_mask, mimicking the continuous nature of pol_max
    - Channel 1: edge information, mimicking pol_var (variance is high at edges)

    This way, the feature extraction the Pol CNet learns can transfer to
    the Finetune stage.
    """
    B, _, H, W = glass_mask.shape
    device = glass_mask.device

    # === Channel 0: softened mask (mimics the continuity of pol_max) ===
    if training:
        # Add noise
        noise = torch.randn_like(glass_mask) * 0.05
        ch0 = (glass_mask + noise).clamp(0, 1)
        # Slight blur (3x3 avg pool with padding)
        ch0 = F.avg_pool2d(
            F.pad(ch0, (1, 1, 1, 1), mode='replicate'),
            kernel_size=3, stride=1
        )
    else:
        # No noise at validation time
        ch0 = glass_mask

    # === Channel 1: edge detection (mimics pol_var) ===
    # pol_var is typically high at glass edges (different disparities
    # produce different polarization differences)
    # Sobel kernels
    sobel_x = torch.tensor([[-1, 0, 1],
                            [-2, 0, 2],
                            [-1, 0, 1]], dtype=torch.float32, device=device)
    sobel_y = torch.tensor([[-1, -2, -1],
                            [ 0,  0,  0],
                            [ 1,  2,  1]], dtype=torch.float32, device=device)
    sobel_x = sobel_x.view(1, 1, 3, 3)
    sobel_y = sobel_y.view(1, 1, 3, 3)
    # Edge gradients
    gx = F.conv2d(glass_mask, sobel_x, padding=1)
    gy = F.conv2d(glass_mask, sobel_y, padding=1)
    # Gradient magnitude
    ch1 = torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)
    # Normalize to [0, 1]
    ch1_max = ch1.amax(dim=(2, 3), keepdim=True)
    ch1 = ch1 / (ch1_max + 1e-6)

    # Combine
    return torch.cat([ch0, ch1], dim=1)
Comparison Table
| Stage | Channel 0 | Channel 1 | Statistics |
|---|---|---|---|
| Pretrain (binary input) | GT_mask (binary) | zeros | Discrete, no signal |
| Pretrain (continuous input) | soft_mask (continuous + noise) | edges (continuous) | Continuous, structured |
| Finetune | pol_max (continuous) | pol_var (continuous) | Continuous, structured |
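To make the "Statistics" column concrete, the toy example below contrasts a binary mask with its softened and edge-filtered versions. It is a simplified, deterministic sketch: the 8×8 mask is made-up data, and the random noise from the pretrain recipe is left out so the numbers are reproducible.

```python
import torch
import torch.nn.functional as F

# Toy 8x8 binary glass mask with a 4x4 glass square (illustrative data).
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0

# Channel 0: 3x3 box blur with replicate padding (noise omitted here).
soft = F.avg_pool2d(F.pad(mask, (1, 1, 1, 1), mode="replicate"),
                    kernel_size=3, stride=1)

# Channel 1: Sobel gradient magnitude, normalized per sample.
sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]]).view(1, 1, 3, 3)
sobel_y = sobel_x.transpose(2, 3)
gx = F.conv2d(mask, sobel_x, padding=1)
gy = F.conv2d(mask, sobel_y, padding=1)
edges = torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)
edges = edges / (edges.amax(dim=(2, 3), keepdim=True) + 1e-6)

print("distinct values in binary mask:", mask.unique().numel())
print("distinct values in soft mask:  ", soft.unique().numel())
print("edge response inside the glass:", edges[0, 0, 3, 3].item())
print("edge response on the boundary: ", edges[0, 0, 2, 3].item())
```

The binary mask takes only two values, the blurred mask takes several intermediate ones, and the edge channel is near zero inside the region but large on its boundary, which is the structure the table attributes to the continuous pretrain input.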
Visualization
Pretrain (binary): Pretrain (continuous): Finetune:
┌───────────────┐ ┌───────────────┐ ┌───────────────┐
│ ch0: ██████ │ │ ch0: ▓▓▓▓▓▓ │ │ ch0: ▒▒▒▒▒▒ │
│ ██████ │ │ ▓▓▓▓▓▓ │ │ ▓▓▓▓▓▓ │
│ (binary) │ │ (soft+noise) │ │ (pol_max) │
├───────────────┤ ├───────────────┤ ├───────────────┤
│ ch1: 000000 │ │ ch1: ░░████░░ │ │ ch1: ░░▓▓▓░░ │
│ 000000 │ │ (edges) │ │ (pol_var) │
│ (zeros) │ │ │ │ │
└───────────────┘ └───────────────┘ └───────────────┘
↓ transfer fails ↓ transferable ✓
Corresponding Setting in the Training Script
_get_pol_input gains a training flag:
def _get_pol_input(self, batch, training=True):
    """
    Prepare the Pol CNet input depending on the training stage.

    Args:
        batch: dict containing 'glass_mask', 'left', 'right'
        training: whether in training mode (affects noise in pretrain)
    """
    if self.args.stage == 'pretrain':
        # Pretrain: use the (softened) GT mask
        pol_input = prepare_pol_input_pretrain(
            batch['glass_mask'],
            training=training  # add noise during training, none for validation
        )
    else:
        # Finetune: use Pol Volume statistics
        pol_input = prepare_pol_input_finetune(
            batch['left'],
            batch['right'],
            max_disp=self.args.max_disp
        )
    return pol_input
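prepare_pol_input_finetune is called above but not shown. Purely as one plausible reading of "Pol Volume statistics" with channels [pol_max, pol_var], the sketch below computes the absolute left/right difference for each integer disparity candidate and reduces it with a maximum and a variance. The function name pol_stats_sketch, the grayscale [0, 1] inputs, and the integer-shift scheme are all assumptions, and the actual implementation may differ.

```python
import torch
import torch.nn.functional as F


def pol_stats_sketch(left_gray: torch.Tensor, right_gray: torch.Tensor,
                     max_disp: int) -> torch.Tensor:
    """Hypothetical [pol_max, pol_var] computation (not the original function).

    left_gray, right_gray: [B, 1, H, W] with values assumed to lie in [0, 1].
    """
    W = right_gray.shape[-1]
    diffs = []
    for d in range(max_disp):
        # Shift the right image d pixels to the right (zero fill), so that
        # column x reads right[x - d].
        shifted = F.pad(right_gray, (d, 0, 0, 0))[..., :W]
        diffs.append((left_gray - shifted).abs())
    vol = torch.cat(diffs, dim=1)                        # [B, max_disp, H, W]
    pol_max = vol.amax(dim=1, keepdim=True)              # [B, 1, H, W]
    mean = vol.mean(dim=1, keepdim=True)
    pol_var = ((vol - mean) ** 2).mean(dim=1, keepdim=True)
    return torch.cat([pol_max, pol_var], dim=1)          # [B, 2, H, W]


left = torch.rand(2, 1, 8, 16)
right = torch.rand(2, 1, 8, 16)
pol_input = pol_stats_sketch(left, right, max_disp=4)
flat = pol_stats_sketch(torch.full((1, 1, 4, 4), 0.5),
                        torch.full((1, 1, 4, 4), 0.5), max_disp=1)
```

Both channels are continuous by construction, which matches the property the pretrain input is designed to imitate.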
5. Tensor Dimensions
| Tensor | Dimensions | Description |
|---|---|---|
| left / right | [B, 3, H, W] | Raw polarization image pair |
| left_small / right_small | [B, 1, H/4, W/4] | PolVolumeBlock internal grayscale + avg_pool2d(k=4, s=4) |
| pol_feat (pol_block lookup) | [B, 2×radius+1, H/4, W/4] = [B, 9, H/4, W/4] | radius = 4 |
| pol_input | [B, 2, ·, ·] | 2 channels (Pretrain: [soft_mask, edges]; Finetune: [pol_max, pol_var]) |
| rgb_ctx / hidden | rgb_cnet output | hidden: [B, hidden_dim=128, H/4, W/4] |
| pol_ctx | pol_cnet output | context_dim = 64 |
| context | [B, context_dim, H/4, W/4] | Output of Context Fusion |
| backbone features (FNet) | [B, out_channels=128, H/4, W/4] | conv1(7,s2) → conv(3,s2) → conv(3) |
| film_params | [B, 2×out_channels, H/4, W/4] | gamma + beta |
| gamma / beta | each [B, out_channels, H/4, W/4] | film_params.chunk(2, dim=1) |
| fmap_left / fmap_right | [B, 128, H/4, W/4] | After FiLM modulation |
| corr_feat (corr_block lookup) | [B, 9, H/4, W/4] | 2×radius+1 = 9 |
| motion_input | [B, 19, H/4, W/4] | corr_dim 9 + pol_dim 9 + disp 1 |
| motion_features | [B, 64, H/4, W/4] | Output of motion_encoder |
| gru_input | [B, 64 + context_dim, H/4, W/4] | cat(motion_features, context) |
| disp | [B, 1, H/4, W/4] | Initialized to zeros |
| delta_disp | [B, 1, H/4, W/4] | Output of disp_head |
| disp_up | [B, 1, H, W] | 4 × F.interpolate(disp, (H,W)) |
UpdateBlock Dimension Derivation
input_dim = corr_dim + pol_dim + 1 = 9 + 9 + 1 = 19; GRU input dimension = 64 + context_dim.
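The derivation can be checked with a few lines of arithmetic. This sketch follows the channel counts stated in the table, where each lookup yields 2×radius+1 channels. The helper names are hypothetical, and motion_dim = 64 is read off the last layer of motion_encoder.

```python
def lookup_channels(radius: int) -> int:
    """Number of samples per pixel for a 1-D lookup window of the given radius."""
    return 2 * radius + 1


def update_block_dims(corr_radius: int = 4, pol_radius: int = 4,
                      context_dim: int = 64, motion_dim: int = 64) -> dict:
    corr_dim = lookup_channels(corr_radius)
    pol_dim = lookup_channels(pol_radius)
    return {
        "corr_dim": corr_dim,
        "pol_dim": pol_dim,
        "input_dim": corr_dim + pol_dim + 1,        # + 1 for the disparity channel
        "gru_input_dim": motion_dim + context_dim,  # cat(motion_features, context)
    }


dims = update_block_dims()
print(dims)
```

Changing pol_radius therefore changes input_dim, so pol_dim in UpdateBlock has to be kept in sync with it.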
6. Hyperparameters
| Parameter | Default | Description |
|---|---|---|
| pol_radius | 4 | Lookup radius of the Pol Volume; determines pol_feat channels = 2×radius+1 = 9 |
7. Design Decisions and Rationale
7.1 Why Correlation Instead of Cost?
Traditional stereo matching uses Cost Volumes (SAD/SSD), where lower is better; this architecture uses a Correlation Volume (dot product), where higher is better.
| Type | Computation | Semantics | Range |
|---|---|---|---|
| Cost (SAD) | Σ\|L − R\| | Low = good match | [0, ∞) |
| Cost (SSD) | Σ(L − R)² | Low = good match | [0, ∞) |
| Correlation | dot(L, R) | High = good match | (−∞, ∞) |
Reasons for choosing Correlation: features are already normalized, so the dot-product semantics are clean; gradients are more stable; consistent with optical flow architectures.
The Pol Volume semantics are aligned with this:
# Raw polarization difference
pol_diff = |left - right| # high = glass
# Inverted to align with Correlation
pol_consistency = 1 - pol_diff # high = non-glass = good
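The difference in semantics can be seen on a toy example. The sketch below uses plain Python lists as stand-in feature vectors and scalar intensities. All values are made up, and the normalization by max_intensity mirrors the PolVolumeBlock listing.

```python
def sad(l, r):
    """Sum of absolute differences: lower = better match."""
    return sum(abs(a - b) for a, b in zip(l, r))


def ssd(l, r):
    """Sum of squared differences: lower = better match."""
    return sum((a - b) ** 2 for a, b in zip(l, r))


def correlation(l, r):
    """Dot product: higher = better match."""
    return sum(a * b for a, b in zip(l, r))


def pol_consistency(l, r, max_intensity=1.0):
    """Inverted, normalized difference: higher = more consistent."""
    return 1.0 - abs(l - r) / max_intensity


feat = [0.6, 0.8]          # unit-norm feature
same = [0.6, 0.8]          # identical feature
other = [0.8, -0.6]        # unit-norm feature orthogonal to feat

print("SAD  same/other:", sad(feat, same), sad(feat, other))
print("SSD  same/other:", ssd(feat, same), ssd(feat, other))
print("Corr same/other:", correlation(feat, same), correlation(feat, other))
print("Pol consistency (small vs large difference):",
      pol_consistency(0.5, 0.4), pol_consistency(0.9, 0.1))
```

For the matching pair the costs are lowest while the correlation is highest. The inverted polarization term moves in the same direction as the correlation, which is what allows the two volumes to be concatenated without a sign flip.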
7.2 Why Must the Pol Volume Be Computed First?
The Pol Volume must reflect the “raw polarization difference” and must not be affected by context modulation. The ordering is:
1. Pol Volume <- raw left, right
2. Context <- RGB CNet + Pol CNet
3. FNet <- modulated by context
4. Corr Volume <- modulated fmap
5. GRU <- Corr + Pol (two complementary signals)
If the Pol Volume were built after the FNet, it would be affected by the context and would no longer represent the raw polarization signal. The two volumes received by the GRU would then be coupled and lose their complementarity.
7.3 Why Use Concat Instead of More Complex Fusion?
Options:
1. Concat (chosen)
2. Cross-attention
3. Gated fusion
4. Weighted sum
Reasons for choosing Concat:
- Simple and stable: the gains from more complex fusion do not scale with their complexity.
- Debug-friendly: per-channel activations can be inspected directly.
- Let the network learn: the motion encoder learns how to weigh the two signals.
- Avoid overfitting: complex fusion easily overfits on small datasets.
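The "debug-friendly" point can be illustrated directly: after a plain concat, fixed channel ranges still map back to their sources. This is a minimal sketch with random stand-in tensors, using the 9 + 9 + 1 channel layout from UpdateBlock.

```python
import torch

corr_feat = torch.randn(1, 9, 4, 4)   # stand-in for the Correlation lookup
pol_feat = torch.rand(1, 9, 4, 4)     # stand-in for the Pol Volume lookup
disp = torch.zeros(1, 1, 4, 4)

# Fusion is a single concat along the channel axis.
motion_input = torch.cat([corr_feat, pol_feat, disp], dim=1)  # [1, 19, 4, 4]

# Channels 0-8 are corr, 9-17 are pol, 18 is disp, so each source can be
# recovered and inspected on its own.
corr_part, pol_part, disp_part = torch.split(motion_input, [9, 9, 1], dim=1)
per_channel_mean = motion_input.mean(dim=(0, 2, 3))  # one statistic per channel
```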
7.4 Design Validation Checklist
- Pol Volume computed first (Phase 0).
- Pol Volume uses raw images.
- FNet modulated by context (FiLM).
- GRU receives the Dual Volume (Corr + Pol).
- Pretrain input statistics are close to Finetune.
- Correlation and Pol Volume semantics are aligned (high = good).
8. Polarization Injection Points
This architecture has three polarization entry points:
| Injection Point | Phase | Form | Description |
|---|---|---|---|
| Pol Volume → GRU | Built at Phase 0, used at Phase 4 | pol_feat = pol_block(disp), concatenated with corr_feat into motion_encoder | Directly complements the Correlation Volume, looked up every iteration |
| pol_input → Pol CNet → context | Phase 1 | pol_ctx = pol_cnet(pol_input), then Context Fusion | Polarization enters the context |
| context → FiLM → FNet | Phase 2 | modulated = γ * features + β, γ/β generated from context | Polarization information (via context) indirectly modulates the FNet’s feature extraction, making the Correlation Volume meaningful on glass |
The key injection point is the FiLM in Phase 2: the context (which includes the glass information from the Pol CNet) uses γ, β to modulate FNet features, so the Correlation Volume on glass regions also becomes meaningful for matching.
9. Highlights
- FiLM feature-level modulation: the context (carrying glass information from the Pol CNet) intervenes during feature extraction via γ·feat + β, making the Correlation Volume on glass meaningful at its source rather than patching an already broken cost.
- Pol Volume computed at Phase 0: deliberately placed earliest, ensuring it reflects the “raw polarization difference” without contamination from context modulation, preserving complementarity with the Correlation Volume.
- Semantic alignment design: the Pol Volume is inverted via 1 - |L-R| so that “high value = good” matches the Correlation Volume, allowing the GRU to simply concatenate the two volumes without extra fusion logic.
- Knowledge Transfer fix: switches the Pretrain input from [binary mask, zeros] to [soft_mask, edges], deliberately approximating the continuous statistics of the Finetune input so that Pol CNet weights can transfer effectively.
- Triple polarization injection: polarization information enters the network through three paths: the GRU loop (Pol Volume), the context (Pol CNet), and FNet modulation (FiLM). Together these operate at multiple levels.