1. Design Goals
This architecture targets the matching difficulty over glass regions in active polarization stereo systems, addressing the following three core problems.
Problem 1: The Polarization Context Encoder Receives Too Little Information
If the polarization context encoder (Pol CNet) only consumes the binary GT mask (0/1), the information content is extremely low and the Pol CNet cannot learn a representation different from the RGB CNet — no matter how much data you feed it, it still only sees 0s and 1s. This is fundamentally not a data-quantity problem. This architecture lets the Pol CNet directly consume the real polarization Cost Volume, providing a continuous, structured polarization signal.
Problem 2: Soft Modulation Is Insufficient to Move a Large Pretrained FNet
Treating glass regions with FiLM soft modulation:
modulated_features = gamma * features + beta
The FNet weights were trained on a large stereo-matching dataset with hundreds of thousands of samples. Trying to move them with gradients that come only from glass regions, which make up 10–50% of roughly a thousand scenes, is the wrong lever ratio: in practice the FNet is barely modulated. This architecture therefore does not touch the FNet at all and instead injects the polarization signal through a parallel path.
Problem 3: The Cost Volume Is Passive
Whatever features the FNet outputs, that is what Correlation computes. Photometric inconsistency on glass has already broken FNet features, so the downstream Correlation is inevitably garbage. Rather than patching an already broken cost, it is better to tell the GRU at the context level that “the cost here is unreliable”.
Design Philosophy
| Design Choice | Description |
|---|---|
| FNet untouched | The existing RGB FNet stays frozen |
| Weighted fusion of dual Cost Volumes | RGB and Pol Cost Volumes, weighted by α |
| Pol CNet consumes PolCostVolume | Real polarization signal, not GT mask |
| Add only a parallel path | Do not modify the core RAFT-Stereo architecture |
GT disparity is used only for supervision and never as input to any module, so there is no Oracle–Real Gap.
2. Full Architecture Diagram: True Dual-Stream
3. Components and Modules
3.1 PolCostVolume: Polarization Difference Volume
Purpose: compute the polarization difference at every possible disparity for each pixel, without warping using GT disparity.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PolCostVolume(nn.Module):
        """Compute the polarization-difference Cost Volume"""
        def __init__(self, max_disp=192):
            super().__init__()  # required before using nn.Module machinery
            self.max_disp = max_disp

        def forward(self, left, right):
            """
            Args:
                left:  (B, C, H, W) - I∥
                right: (B, C, H, W) - I⊥
            Returns:
                pol_volume: (B, D, H, W) - polarization difference for each disparity
            """
            B, C, H, W = left.shape
            pol_volume = []
            for d in range(self.max_disp):
                if d == 0:
                    diff = torch.abs(left - right)
                else:
                    # Zero-pad d columns on the left and crop to W,
                    # so right_shifted[..., x] == right[..., x - d]
                    right_shifted = F.pad(right, (d, 0, 0, 0))[:, :, :, :W]
                    diff = torch.abs(left - right_shifted)
                diff = diff.mean(dim=1, keepdim=True)  # (B, 1, H, W)
                pol_volume.append(diff)
            return torch.cat(pol_volume, dim=1)  # (B, D, H, W)
Physical meaning:
- Glass region: at the correct disparity, left >> right, so pol_volume shows a clear peak.
- Non-glass region: left ≈ right, so pol_volume is overall flat.
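To make the shift convention of the volume concrete, here is a minimal sketch on synthetic data. It only illustrates the indexing (slice d compares left[x] with right[x - d]); it does not model glass. The function name `pol_cost_volume`, the toy tensor sizes, and the synthetic disparity of 3 are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def pol_cost_volume(left, right, max_disp):
    """Mean absolute difference between left and right shifted by each disparity d."""
    width = left.shape[-1]
    slices = []
    for d in range(max_disp):
        # Pad d zero columns on the left, then crop back to the original width,
        # so that shifted[..., x] == right[..., x - d] for x >= d.
        shifted = F.pad(right, (d, 0, 0, 0))[..., :width]
        slices.append((left - shifted).abs().mean(dim=1, keepdim=True))
    return torch.cat(slices, dim=1)  # (B, D, H, W)


torch.manual_seed(0)
true_disp = 3  # illustrative value
left = torch.rand(1, 3, 4, 16)
# Synthetic right view: the left content moved by true_disp pixels.
right = torch.zeros_like(left)
right[..., : 16 - true_disp] = left[..., true_disp:]

volume = pol_cost_volume(left, right, max_disp=8)
# Skip the first true_disp columns, which only see zero padding at d = true_disp.
valid = volume[..., true_disp:]
best = valid.argmin(dim=1)  # per-pixel disparity with the smallest difference
```

For this identical-content pair the difference vanishes exactly at the synthetic disparity, which confirms that slice `d` of the volume is aligned with disparity `d`. In the real system the two inputs are I∥ and I⊥, so the shape of the response along `d` depends on the polarization contrast rather than on an exact match.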
3.2 PolVolumeEncoder: Low-Dimensional Abstraction
Purpose: compress the high-dimensional pol_volume into a low-dimensional feature with a 3D CNN, performing denoising and feature extraction.

    import torch.nn as nn

    class PolVolumeEncoder(nn.Module):
        """Encode pol_volume into a low-dimensional feature"""
        def __init__(self, max_disp=192, out_dim=8):
            super().__init__()
            # Compress along the disparity dimension.
            # Two spatial stride-2 layers give H/4, W/4 overall.
            self.encoder = nn.Sequential(
                nn.Conv3d(1, 8, (7, 3, 3), stride=(4, 2, 2), padding=(3, 1, 1)),
                nn.ReLU(),
                nn.Conv3d(8, 16, (5, 3, 3), stride=(4, 2, 2), padding=(2, 1, 1)),
                nn.ReLU(),
                nn.Conv3d(16, out_dim, (3, 3, 3), stride=(2, 1, 1), padding=(1, 1, 1)),
            )
            # Finally squeeze the D dimension (None keeps H and W unchanged)
            self.squeeze = nn.AdaptiveAvgPool3d((1, None, None))

        def forward(self, pol_volume):
            # pol_volume: (B, D, H, W) -> (B, 1, D, H, W)
            x = pol_volume.unsqueeze(1)
            x = self.encoder(x)     # (B, out_dim, D', H/4, W/4)
            x = self.squeeze(x)     # (B, out_dim, 1, H/4, W/4)
            return x.squeeze(2)     # (B, out_dim, H/4, W/4)
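The compression along the disparity axis can be checked with plain arithmetic. The sketch below applies the standard convolution output-size formula (dilation 1) to the depth kernel, stride, and padding of the three Conv3d layers; the helper name `conv_out_size` is an assumption introduced here.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) along the disparity axis for the three Conv3d layers
depth_layers = [(7, 4, 3), (5, 4, 2), (3, 2, 1)]

depth = 192  # max_disp
depth_trace = [depth]
for kernel, stride, padding in depth_layers:
    depth = conv_out_size(depth, kernel, stride, padding)
    depth_trace.append(depth)

print(depth_trace)  # [192, 48, 12, 6]
```

So D' is 6 before the adaptive average pool collapses it to 1. The same helper can be applied to H and W to confirm the spatial downsampling factor of any stride configuration.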
3.3 Pol FNet
The feature extractor of the Pol path in the dual stream, using 1/8 of the RGB FNet channels (256 → 32). Input is pol_feature (8ch) + left (3ch) = 11ch; output is a 32-channel feature map.
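A minimal sketch of such a module is given below. The document fixes only the channel counts (11 in, 32 out), so the layer stack is hypothetical. The sketch also assumes that `left` arrives at full resolution and is bilinearly resized to the pol_feature grid before concatenation; the class name `PolFNetSketch` is likewise an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PolFNetSketch(nn.Module):
    """Hypothetical small feature extractor: 11 input channels -> 32 output channels."""

    def __init__(self, pol_dim=8, img_dim=3, out_dim=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(pol_dim + img_dim, out_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1),
        )

    def forward(self, pol_feature, left):
        # Assumption: bring the full-resolution image down to the pol_feature grid.
        left_small = F.interpolate(
            left, size=pol_feature.shape[-2:], mode="bilinear", align_corners=False
        )
        return self.net(torch.cat([pol_feature, left_small], dim=1))


# Toy shapes: pol_feature at 1/4 resolution of a 64x96 image.
fmap_pol = PolFNetSketch()(torch.rand(2, 8, 16, 24), torch.rand(2, 3, 64, 96))
```

Other ways of aligning the two inputs (for example, strided convolutions on `left`) would serve equally well; the point is only that the 8-channel and 3-channel inputs must share a spatial grid before they can be concatenated.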
3.4 Pol CNet (with α Prediction Head)
Input: pol_feature (8ch) + fmap_pol (32ch) = 40ch.
Output: context_pol (64), hidden_pol (128), α (1).
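
    import torch
    import torch.nn as nn

    class PolCNet(nn.Module):
        """
        Input:  pol_feature (8ch) + fmap_pol (32ch) = 40ch
        Output: context_pol (64), hidden_pol (128), α (1)
        """
        def __init__(self, in_dim=40, mid_dim=64):
            super().__init__()
            # Assumed minimal layers; the internal encoder design is not fixed by this document
            self.encoder = nn.Sequential(nn.Conv2d(in_dim, mid_dim, 3, padding=1), nn.ReLU())
            self.context_head = nn.Conv2d(mid_dim, 64, 3, padding=1)
            self.hidden_head = nn.Conv2d(mid_dim, 128, 3, padding=1)
            self.alpha_head = nn.Conv2d(mid_dim, 1, 3, padding=1)

        def forward(self, pol_feature, fmap_pol):
            # Fuse low-level polarization statistics + high-level learned features
            x = torch.cat([pol_feature, fmap_pol], dim=1)  # (B, 40, H/4, W/4)
            x = self.encoder(x)
            context_pol = self.context_head(x)
            hidden_pol = self.hidden_head(x)
            alpha = torch.sigmoid(self.alpha_head(x))  # (B, 1, H/4, W/4), values in (0, 1)
            return context_pol, hidden_pol, alpha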
pol_feature (8ch) is the raw polarization Cost Volume statistic output by the PolVolumeEncoder; fmap_pol (32ch) is the high-level semantic feature output by the Pol FNet. Combining them lets the Pol CNet see both the “raw polarization signal” and the “learned feature”.
3.5 RGB FNet / RGB CNet
These reuse the original RAFT-Stereo structure unchanged and stay frozen in the dual stream, so the original RAFT-Stereo performance acts as a floor.
4. Tensor Dimensions
| Module | Input/Output Dimensions | Description |
|---|---|---|
| PolCostVolume | (B,3,H,W) ×2 → (B, 192, H, W) | Polarization difference for each disparity |
| PolVolumeEncoder | (B,192,H,W) → (B, 8, H/4, W/4) | Low-dim feature after 3D Conv compression |
| Pol FNet | 11ch → (B, 32, H/4, W/4) | Input: pol_feature(8) + left(3) |
| RGB FNet | 3ch → (B, 256, H/4, W/4) | Original RAFT-Stereo |
| Pol CNet | 40ch → 64+128+1 | Input: pol_feature(8) + fmap_pol(32) |
| RGB CNet | 3ch → 64+128 | Original RAFT-Stereo |
| cost_pol | (B, 36, H/4, W/4) | 4-level pyramid × 9 samples |
| cost_rgb | (B, 36, H/4, W/4) | 4-level pyramid × 9 samples |
| cost_fused | (B, 36, H/4, W/4) | α-weighted fusion |
| context_fused | (B, 128, H/4, W/4) | concat(rgb 64, pol 64) |
| hidden_fused | (B, 256, H/4, W/4) | concat(rgb 128, pol 128) |
| α | (B, 1, H/4, W/4) | Per-pixel fusion weight, output by Pol CNet |
5. Hyperparameters
| Parameter | Value | Description |
|---|---|---|
| max_disp | 192 | Disparity range of PolCostVolume |
| GRU iterations | 12 | Number of GRU Update Block iterations |
| Loss decay coefficient γ | — | Per-iteration loss weight: γ^(n-i-1) |
| glass_weight | 5.0 | Loss weight on glass edges |
| strict_weight | 0.5 | Extra loss weight on the glass core |
| Background loss weight | 1.0 | Loss weight on non-glass regions |
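The per-iteration weighting γ^(n-i-1) from the table can be sketched as follows. The table leaves γ unspecified, so the value 0.9 used here is purely illustrative, and the function name `sequence_loss_weights` is an assumption.

```python
def sequence_loss_weights(num_iters, gamma):
    """Weight for iteration i (0-based) is gamma ** (num_iters - i - 1)."""
    return [gamma ** (num_iters - i - 1) for i in range(num_iters)]


# 12 GRU iterations as in the table; gamma = 0.9 is an illustrative choice.
weights = sequence_loss_weights(num_iters=12, gamma=0.9)
```

With any γ below 1, the final iteration receives weight 1 and earlier iterations are discounted geometrically, so the last prediction dominates the loss.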
6. Design Decisions and Rationale
6.1 α Weighting Mechanism
No separate network is built; α is a Pol CNet head output passed through a sigmoid. Rationale: the Pol CNet is already learning where the glass is, and α directly expresses that judgment, so a separate network would add redundant complexity.
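As a minimal sketch of how α could enter the fusion, the snippet below assumes a convex blend, cost_fused = α · cost_rgb + (1 − α) · cost_pol, which matches the convention stated in the Highlights (α≈1 trusts RGB, α≈0 trusts Pol). The exact fusion formula is an assumption, as are the function name `fuse_costs` and the toy shapes.

```python
import torch


def fuse_costs(cost_rgb, cost_pol, alpha):
    """Per-pixel blend; alpha has shape (B, 1, H, W) and broadcasts over channels."""
    return alpha * cost_rgb + (1.0 - alpha) * cost_pol


B, H, W = 1, 4, 6
cost_rgb = torch.rand(B, 36, H, W)
cost_pol = torch.rand(B, 36, H, W)

alpha = torch.ones(B, 1, H, W)
alpha[..., :3] = 0.0  # pretend the left half is glass: trust Pol there

cost_fused = fuse_costs(cost_rgb, cost_pol, alpha)

# Context and hidden states are concatenated rather than blended.
context_fused = torch.cat([torch.rand(B, 64, H, W), torch.rand(B, 64, H, W)], dim=1)
```

In this toy case the fused cost equals `cost_pol` on the pretend glass half and `cost_rgb` elsewhere. In training α would take intermediate values and receive gradients through this blend from the disparity loss.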
6.2 Pol FNet Size
Starts with 1/8 of the RGB FNet channels (256 → 32). Rationale: its job is correction, not reconstruction, so a small capacity is enough; start small.
6.3 Pol CNet Output Dimensions
Match those of the RGB CNet (context=64, hidden=128). Rationale: weighted fusion and concat at the end require aligned dimensions, but the internal channel count can be narrow as long as the output layer projects to the correct dimensions.
6.4 Training Strategy: RGB Fully Frozen, Pol Fully Trainable
| Module | Status | Rationale |
|---|---|---|
| RGB FNet | Frozen | Original RAFT floor |
| RGB CNet | Frozen | Original RAFT floor |
| PolVolumeEncoder | Fully trainable | Learns polarization Cost Volume compression from scratch |
| Pol FNet | Fully trainable | Learns polarization features from scratch |
| Pol CNet | Fully trainable | Learns glass judgment + α from scratch |
| GRU input layer | Fully trainable | Adapts to the new dimensions (context 128, hidden 256) |
| Other GRU layers | Frozen | Preserves the original RAFT behavior |
| Disp Head | Frozen | Preserves the original RAFT behavior |
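The table above could be applied in code roughly as follows. This is a sketch only: the submodule names (`rgb_fnet`, `gru_rest`, and so on) are hypothetical stand-ins for whatever the real model calls these parts, and each stand-in module is a single convolution.

```python
import torch.nn as nn

# Hypothetical attribute names, one per frozen row of the table.
FROZEN_PREFIXES = ("rgb_fnet.", "rgb_cnet.", "gru_rest.", "disp_head.")


def apply_freeze_plan(model, frozen_prefixes=FROZEN_PREFIXES):
    """Turn off gradients for frozen submodules and return the trainable parameter names."""
    trainable = []
    for name, param in model.named_parameters():
        param.requires_grad = not name.startswith(frozen_prefixes)
        if param.requires_grad:
            trainable.append(name)
    return trainable


# Tiny stand-in model: each entry is a single conv, not the real module.
model = nn.ModuleDict({
    "rgb_fnet": nn.Conv2d(3, 8, 3),
    "rgb_cnet": nn.Conv2d(3, 8, 3),
    "pol_volume_encoder": nn.Conv2d(1, 8, 3),
    "pol_fnet": nn.Conv2d(11, 8, 3),
    "pol_cnet": nn.Conv2d(40, 8, 3),
    "gru_input": nn.Conv2d(8, 8, 3),
    "gru_rest": nn.Conv2d(8, 8, 3),
    "disp_head": nn.Conv2d(8, 1, 3),
})
trainable_names = apply_freeze_plan(model)
```

Note that `requires_grad = False` does not stop normalization layers such as BatchNorm from updating their running statistics in training mode, so frozen modules containing such layers usually also need `.eval()`.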
6.5 Loss Design
α is not supervised with the GT mask; it learns by itself from the disparity loss. Rationale: supervising with the GT mask is cheating during training and creates a distribution gap at inference. α should be the model’s own judgment based on cost quality, not an external label.
6.6 GRU Warmup
Not added for now. Rationale: the Pol FNet is already low-dimensional (1/8 channels) and the PolVolumeEncoder has done denoising, so the signal is much cleaner than the raw Pol Volume; add warmup later if GRU convergence has problems.
6.7 PolCostVolume Replaces GT Warp
No disparity is used for warping; the Cost Volume approach is used instead, and GT disparity is used only for loss supervision.
| | GT Warp Approach | Cost Volume Approach |
|---|---|---|
| Role of GT | Input (cheating) | Supervision (normal) |
| Pretrain input | `abs(L - warp(R, GT_d))` | `PolCostVolume(L, R)` |
| Finetune input | `abs(L - warp(R, est_d))` | `PolCostVolume(L, R)` |
| Pretrain/Finetune gap | Yes | No |
| Oracle–Real gap | Yes | No |
Therefore Pretrain and Finetune share the exact same pipeline:
1. pol_volume = PolCostVolume(left, right) # no disparity needed
2. pol_feature = PolVolumeEncoder(pol_volume)
3. fmap_pol = PolFNet(pol_feature, left)
4. context_pol, hidden_pol, α = PolCNet(pol_feature, fmap_pol)
5. ... (RGB stream, fusion, GRU) ...
6. predicted_disp = GRU output
7. loss = L1(predicted_disp, GT_disp) # GT only appears here!
6.8 Checkpoint Partial Loading Strategy
After the GRU input layer dimensions change from (context=64, hidden=128) to (context=128, hidden=256), original RAFT weights cannot be loaded directly. The solution is Partial Loading + Zero Init:
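
    import torch

    def load_raft_weights(model, raft_checkpoint):
        # Assumes the checkpoint is a plain state_dict whose keys match the model's parameter names
        raft_weights = torch.load(raft_checkpoint, map_location="cpu")
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in raft_weights:
                    continue  # new Pol modules keep their default initialization
                raft_param = raft_weights[name]
                if param.shape == raft_param.shape:
                    # Exact match, load directly
                    param.copy_(raft_param)
                elif (param.dim() == raft_param.dim() and param.dim() > 1
                      and param.shape[0] == raft_param.shape[0]
                      and param.shape[2:] == raft_param.shape[2:]
                      and param.shape[1] > raft_param.shape[1]):
                    # Input channels expanded (GRU input layer);
                    # assumes the RGB channels come first in the concatenation.
                    # Load RAFT weights into the RGB part
                    param[:, :raft_param.shape[1], ...].copy_(raft_param)
                    # Initialize the Pol part to 0
                    param[:, raft_param.shape[1]:, ...].zero_()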
Effect: initially the model ≈ original RAFT (Pol weights are 0 and have no effect); during training the Pol contribution grows gradually, smoothly transitioning without breaking existing performance.
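The claimed effect can be checked on a single layer. The sketch below uses a stand-in `nn.Conv2d` with illustrative channel counts rather than the actual GRU input layer. It shows that an expanded convolution whose extra input channels are zero-initialized reproduces the original output, while those zeroed weights still receive gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
old_conv = nn.Conv2d(64, 32, kernel_size=3, padding=1)   # stands in for a pretrained layer
new_conv = nn.Conv2d(128, 32, kernel_size=3, padding=1)  # same layer with doubled input channels

with torch.no_grad():
    new_conv.weight[:, :64].copy_(old_conv.weight)  # RGB part: pretrained weights
    new_conv.weight[:, 64:].zero_()                 # Pol part: zeros
    new_conv.bias.copy_(old_conv.bias)

rgb_in = torch.rand(1, 64, 8, 8)
pol_in = torch.rand(1, 64, 8, 8)

out_old = old_conv(rgb_in)
out_new = new_conv(torch.cat([rgb_in, pol_in], dim=1))
same_output = torch.allclose(out_old, out_new, atol=1e-5)

# The zeroed weights are not stuck: their gradient depends on pol_in, not on their value.
out_new.sum().backward()
pol_grad_norm = new_conv.weight.grad[:, 64:].abs().sum().item()
```

Because the gradient with respect to a convolution weight is determined by the layer input and the upstream gradient, the Pol slice can move away from zero as soon as training starts.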
7. Highlights
- True parallel dual-stream design: the RGB stream is fully frozen as a performance floor while the Pol stream learns the polarization signal from scratch at full speed; the two paths are entirely parallel and do not modify the RAFT-Stereo core.
- α per-pixel weighted fusion: a single sigmoid output α switches the trusted source per pixel between glass and non-glass regions (α≈1 trusts RGB, α≈0 trusts Pol), with no extra network needed.
- Zero GT leakage: GT disparity appears only in the loss and never as input to any module, so Pretrain and Finetune share an identical pipeline, completely eliminating both the Oracle–Real Gap and the Pretrain/Finetune Gap.
- PolCostVolume + 3D CNN encoding: a full-disparity polarization difference volume replaces GT warp, then a 3D CNN compresses and denoises it into a low-dimensional feature, reducing both noise and compute.
- Partial Loading + Zero Init: the expanded GRU input layer is initialized with “pretrained weights for the RGB part, zeros for the Pol part”, so the starting point equals the original RAFT and the polarization contribution grows smoothly during training.