1. Design Goals
The `stereo_feat` from the stereo stream and the `pol_feat` from the polarization stream must be fused into a single feature for the subsequent matching pipeline. A plain concat + conv fusion looks like this:

    fused = torch.cat([stereo_feat, pol_feat], dim=1)  # (B, 256, H, W)
    fused = self.fusion_conv(fused)                    # (B, 128, H, W)
Problem: Concatenation treats all spatial positions equally, but the polarization signal is meaningful only in glass regions. Mixing polarization features uniformly into every pixel requires the network to figure out, on its own, “where in this mixed soup of features is the polarization information useful.”
Design goal: Use cross-attention so stereo features can “query” polarization features, automatically learning where polarization information is needed. Attention naturally concentrates weight on relevant regions (glass) and drives weight toward 0 elsewhere—effectively letting the network learn “where to use polarization.”
2. Architecture
- Q (Query): from stereo features.
- K, V (Key, Value): from polarization features.
- The output is residual: `stereo_out = stereo_feat + alpha * attended_stereo`. `alpha` is a learnable gate that is then multiplied by a training-warmup cap, `alpha_cap`.
This architecture uses Pol→Stereo unidirectional attention: only stereo features query polarization information; the reverse direction (stereo→pol) is disabled by default.
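A minimal single-head sketch of this unidirectional layout, assuming PyTorch. The class name `UnidirectionalCrossAttention`, the 1×1 convolution projections, and the fixed scalar `alpha` are illustrative assumptions rather than the project's actual module, which additionally pools K/V and uses a learnable gate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UnidirectionalCrossAttention(nn.Module):
    """Hypothetical sketch: stereo features query polarization features."""

    def __init__(self, channels: int):
        super().__init__()
        self.q_proj = nn.Conv2d(channels, channels, kernel_size=1)  # Q from stereo
        self.k_proj = nn.Conv2d(channels, channels, kernel_size=1)  # K from polarization
        self.v_proj = nn.Conv2d(channels, channels, kernel_size=1)  # V from polarization
        self.scale = channels ** -0.5

    def forward(self, stereo_feat, pol_feat, alpha: float = 0.1):
        b, c, h, w = stereo_feat.shape
        q = self.q_proj(stereo_feat).flatten(2).transpose(1, 2)  # (B, N, C)
        k = self.k_proj(pol_feat).flatten(2)                     # (B, C, N)
        v = self.v_proj(pol_feat).flatten(2).transpose(1, 2)     # (B, N, C)
        attn = F.softmax(torch.bmm(q, k) * self.scale, dim=-1)   # (B, N, N)
        attended = torch.bmm(attn, v)                            # (B, N, C)
        attended_stereo = attended.transpose(1, 2).reshape(b, c, h, w)
        # Residual output: the stereo features are always preserved.
        return stereo_feat + alpha * attended_stereo


stereo_feat = torch.randn(2, 16, 6, 8)
pol_feat = torch.randn(2, 16, 6, 8)
stereo_out = UnidirectionalCrossAttention(16)(stereo_feat, pol_feat)
print(stereo_out.shape)
```

There is no path from stereo into the polarization stream here, which is what "unidirectional" means in this design.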
3. Components and Modules
3.1 PooledCrossAttention (memory-efficient)
Vanilla attention has memory complexity O(N^2) (N = H*W), which is infeasible on high-resolution feature maps. `PooledCrossAttention` pools and downsamples K and V:

    # Vanilla attention: O(H*W * H*W) = O(N^2)
    # Pooled attention:  O(H*W * pool_h*pool_w) = O(N * M), M << N
    # Memory savings ~64x (pool_size=8)
    Q:    (B, C, H, W)                - keep original resolution
    K, V: pooled to (B, C, H/8, W/8)  - reduce memory
- Q keeps the original resolution (every pixel can query).
- K, V are pooled to `H/8 × W/8` (`pool_size=8`).
- Complexity drops from `O(N^2)` to `O(N*M)` with `M << N`, saving roughly 64× memory.
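The pooling step might look roughly like the sketch below. Average pooling, the helper name `pooled_attention_scores`, and the omission of learned projections are assumptions; the source only states that K and V are pooled with `pool_size=8`. The sketch also assumes `H` and `W` are divisible by `pool_size`.

```python
import torch
import torch.nn.functional as F


def pooled_attention_scores(q_feat, kv_feat, pool_size: int = 8):
    """Return attention weights of shape (B, N, M), with M = N / pool_size**2."""
    b, c, h, w = q_feat.shape
    kv_pooled = F.avg_pool2d(kv_feat, kernel_size=pool_size)  # (B, C, H/p, W/p)
    q = q_feat.flatten(2).transpose(1, 2)                     # (B, N, C), full resolution
    k = kv_pooled.flatten(2)                                  # (B, C, M), pooled
    scores = torch.bmm(q, k) * (c ** -0.5)                    # (B, N, M) instead of (B, N, N)
    return torch.softmax(scores, dim=-1)


attn = pooled_attention_scores(torch.randn(1, 8, 32, 48), torch.randn(1, 8, 32, 48))
n, m = attn.shape[1], attn.shape[2]
print(attn.shape, "N / M =", n // m)
```

The attention matrix has `N * M` entries instead of `N * N`, and `N / M = pool_size**2 = 64`, which is where the roughly 64× figure comes from.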
3.2 SafeCrossAttentionFusion (stable training)
The attention fusion layer contains multiple safeguards to ensure stable training:
- Alpha Gate: `alpha` is initialized as `sigmoid(-5) ≈ 0.007`, near 0, so early in training the attention output is hardly injected.
- Zero Init: the output projection is initialized to 0, so early in training the attention branch has no effect.
- Alpha Cap Warmup: the `alpha` cap grows linearly from 0.05 to 1.0 (see §3.3).
- Residual Connection: `out = stereo + alpha * attended`; even if the attention branch fails, the original stereo features are preserved.
- NaN Protection: numerically stable softmax + `nan_to_num` (see §3.4).
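The gate, zero init, and residual from the list above can be sketched together as follows. The class name `GatedResidualFusion`, the 1×1 convolution used as the output projection, and the parameter names are assumptions made for illustration, not the actual `SafeCrossAttentionFusion` code.

```python
import torch
import torch.nn as nn


class GatedResidualFusion(nn.Module):
    """Hypothetical sketch of the alpha gate + zero-init projection + residual."""

    def __init__(self, channels: int, alpha_init_logit: float = -5.0):
        super().__init__()
        # Learnable gate logit; sigmoid(-5) is roughly 0.0067.
        self.alpha_logit = nn.Parameter(torch.tensor(alpha_init_logit))
        # Output projection starts at exactly zero (weights and bias).
        self.out_proj = nn.Conv2d(channels, channels, kernel_size=1)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, stereo_feat, attended, alpha_cap: float = 1.0):
        alpha = torch.sigmoid(self.alpha_logit) * alpha_cap
        # Residual: stereo features pass through unchanged when the branch is silent.
        return stereo_feat + alpha * self.out_proj(attended)


fusion = GatedResidualFusion(channels=8)
stereo_feat = torch.randn(2, 8, 4, 4)
attended = torch.randn(2, 8, 4, 4)
out = fusion(stereo_feat, attended, alpha_cap=0.05)
print(torch.equal(out, stereo_feat))  # identity at initialization
```

With this layout the module is an exact identity on `stereo_feat` at step 0, so the pre-trained backbone sees unchanged inputs until the projection weights move away from zero.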
3.3 Alpha Cap Warmup
The cap on `alpha` grows linearly with the training step:

    def _compute_alpha_cap(self) -> float:
        # Linear ramp: 0.05 -> 1.0 over 5000 steps, then constant
        if self.global_step >= self.args.alpha_cap_warmup:
            return 1.0
        progress = self.global_step / self.args.alpha_cap_warmup
        return self.args.alpha_cap_start + progress * (1.0 - self.args.alpha_cap_start)
- Starts at `alpha_cap_start = 0.05`, ends at 1.0, growing linearly over `alpha_cap_warmup = 5000` steps.
- Limits the attention branch's maximum influence early in training, preventing the randomly initialized attention from disturbing the pre-trained backbone.
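To make the schedule concrete, here is a standalone version of the ramp with the defaults quoted above. The free-function form and the helper `effective_alpha` are illustrative; `effective_alpha` additionally assumes the gate logit is still at its initial value of -5, which will not hold once training updates it.

```python
import math


def compute_alpha_cap(global_step: int, start: float = 0.05, warmup: int = 5000) -> float:
    """Linear ramp from `start` to 1.0 over `warmup` steps, then constant 1.0."""
    if global_step >= warmup:
        return 1.0
    progress = global_step / warmup
    return start + progress * (1.0 - start)


def effective_alpha(global_step: int, gate_logit: float = -5.0) -> float:
    """sigmoid(gate_logit) scaled by the current cap (gate assumed untrained)."""
    gate = 1.0 / (1.0 + math.exp(-gate_logit))
    return gate * compute_alpha_cap(global_step)


for step in (0, 2500, 5000, 10000):
    print(step, round(compute_alpha_cap(step), 4), round(effective_alpha(step), 6))
```

At step 0 the cap is 0.05, halfway through the warmup it is about 0.525, and from step 5000 onward it stays at 1.0.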
3.4 NaN Protection
When every key in an attention row is masked out, the row is entirely `-inf` and softmax produces NaN (0/0). The safeguards:

    # 1. Apply the mask only when every row keeps at least one valid entry
    valid_count = (mask_pooled >= 0.5).sum(dim=-1, keepdim=True)
    if valid_count.min() > 0:
        attn = attn.masked_fill(mask_pooled < 0.5, float('-inf'))
    # 2. Numerically stable softmax
    attn_max = attn.max(dim=-1, keepdim=True)[0]
    attn = attn - attn_max
    attn = F.softmax(attn, dim=-1)
    # 3. Last line of defense
    attn = torch.nan_to_num(attn, nan=0.0)

Three lines of defense: (1) apply the mask only when every row of the mask has at least one valid entry, so that no row becomes fully `-inf`; (2) subtract the max before softmax for numerical stability; (3) `nan_to_num` as a final safety net.
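The failure mode and the last line of defense can be reproduced in isolation. The tensor below is a made-up example, not data from the project: one normal row and one row in which every key has been masked to `-inf`.

```python
import torch
import torch.nn.functional as F

# Two query rows over three keys; the second row has every key masked out.
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [float("-inf"), float("-inf"), float("-inf")]])

raw = F.softmax(scores, dim=-1)        # fully masked row evaluates to NaN
safe = torch.nan_to_num(raw, nan=0.0)  # NaN entries become 0, valid row is untouched

print(raw)
print(safe)
```

A zeroed attention row contributes nothing to the weighted sum over V, so with the residual connection that query position simply keeps its original stereo feature.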
4. Data Flow
- The stereo stream produces `stereo_feat`; the polarization stream (`PolarizationEncoder`) produces `pol_feat`.
- `PooledCrossAttention`:
  - `Q` ← `stereo_feat` (keep original resolution `H × W`).
  - `K, V` ← `pol_feat`, pooled to `H/8 × W/8`.
- Compute attention scores → NaN-protected softmax → weighted aggregation of V → `attended_stereo`.
- `SafeCrossAttentionFusion`: `stereo_out = stereo_feat + alpha * attended_stereo`, where `alpha = sigmoid(learnable_logit) * alpha_cap` and `alpha_cap` grows with the warmup.
- `stereo_out` enters the correlation pyramid + GRU for iterative refinement → disparity.
5. Tensor Dimensions
| Tensor | Dimensions | Description |
|---|---|---|
| `stereo_feat` | (B, C, H, W) | Stereo-stream features |
| `pol_feat` | (B, C, H, W) | Polarization-stream features |
| `Q` | (B, C, H, W) | Projected from `stereo_feat`; keeps original resolution |
| `K, V` | (B, C, H/8, W/8) | Projected from `pol_feat` and pooled (`pool_size=8`) |
| `attended_stereo` | (B, C, H, W) | Attention output |
| `stereo_out` | (B, C, H, W) | `stereo_feat + alpha * attended_stereo` |
Model parameter count
| Model | Total params | Trainable params |
|---|---|---|
| Cross-Attention Fusion | 4,363,715 | 4,363,715 |
Parameters are concentrated in PooledCrossAttention (Q, K, V projections) and SafeCrossAttentionFusion (alpha gates, output projections).
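Counts like those in the table are typically obtained by summing `numel()` over a module's parameters. The helper name `count_parameters` and the toy 1×1 convolution below are illustrative assumptions; the snippet does not reproduce the 4,363,715 figure, which depends on the full model.

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> tuple[int, int]:
    """Return (total, trainable) parameter counts for a module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


# Toy stand-in: a 256 -> 128 1x1 convolution has 256 * 128 weights + 128 biases.
toy = nn.Conv2d(256, 128, kernel_size=1)
print(count_parameters(toy))
```

Total and trainable counts are equal whenever no parameter has `requires_grad=False`, as in the table above.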
6. Hyperparameters
Cross-Attention-Specific Parameters
| Parameter | Default | Description |
|---|---|---|
| `num_heads` | 4 | Number of multi-head attention heads |
| `pool_size` | 8 | K, V pooling size (8 = 64× memory savings) |
| `enable_stereo_to_pol` | False | Whether to enable bidirectional attention (unidirectional by default) |
| `alpha_cap_start` | 0.05 | Alpha-warmup starting value |
| `alpha_cap_warmup` | 5000 | Alpha-warmup steps |
Training Hyperparameters
| Parameter | Value | Description |
|---|---|---|
| `pretrained` | raftstereo-sceneflow.pth | SceneFlow pre-trained weights |
| `pol_dim` | 128 | Polarization feature dimension |
| `pol_threshold` | 0.05 | Soft Threshold value |
| `pol_weight` | 2.0 | Polarization weight in Polarization-aware Loss |
| `glass_weight` | 5.0 | Loss weight on glass regions |
| `pol_lr_mult` | 5.0 | Learning-rate multiplier for polarization layers |
| `hidden_dim` | 128 | GRU hidden state dimension |
| `context_dim` | 128 | Context dimension |
| `feature_dim` | 128 | Feature dimension |
| `iters` | 24 | GRU iterations |
| `batch_size` | 8 | Training batch size |
| `num_steps` | 70000 | Training steps |
| `lr` | 0.0003 | Learning rate |
| `num_heads` | 4 | Attention heads |
| `pool_size` | 8 | K, V pooling size |
| `alpha_cap_start` | 0.05 | Alpha-warmup starting value |
| `alpha_cap_warmup` | 5000 | Alpha-warmup steps |
7. Design Decisions and Rationale
7.1 Why cross-attention instead of concatenation
Concatenation treats all spatial positions equally, but the polarization signal is meaningful only in glass regions. Cross-attention lets stereo features actively “query” polarization features; attention weights automatically concentrate on relevant regions (glass), effectively letting the network learn “where polarization information is needed” by itself.
7.2 Why Pol→Stereo unidirectional attention
The final task is stereo matching, which needs “stereo features augmented by polarization information.” Stereo serves as Query and polarization as Key/Value—i.e., “stereo queries polarization.” The reverse (enable_stereo_to_pol) is disabled by default to avoid unnecessary parameters and training complexity.
7.3 Why PooledCrossAttention is needed
Full-resolution attention has O(N^2) memory, which is infeasible on a 640×480 feature map. Pooling K, V to H/8 × W/8 reduces complexity to O(N*M), saving about 64× memory and making attention practical on high-resolution feature maps. Polarization features themselves are relatively “macroscopic” (glass regions are large), so downsampling K, V loses little effective information.
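A back-of-the-envelope calculation shows the scale involved. It assumes a 480 × 640 (H × W) feature map, float32 attention weights, and a single head and batch element; these assumptions and the helper name `attention_matrix_bytes` are illustrative only.

```python
def attention_matrix_bytes(h: int, w: int, pool_size: int = 1, bytes_per_elem: int = 4) -> int:
    """Size of one (N x M) attention matrix: N query positions, M pooled key positions."""
    n = h * w
    m = (h // pool_size) * (w // pool_size)
    return n * m * bytes_per_elem


full = attention_matrix_bytes(480, 640)                 # N x N
pooled = attention_matrix_bytes(480, 640, pool_size=8)  # N x N/64

print(f"full:   {full / 2**30:.1f} GiB")
print(f"pooled: {pooled / 2**30:.1f} GiB")
print(f"ratio:  {full // pooled}x")
```

Under these assumptions the full matrix needs roughly 350 GiB while the pooled one needs about 5.5 GiB, a factor of exactly `pool_size**2 = 64`.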
7.4 Why SafeCrossAttentionFusion needs multiple safeguards
Cross-attention is a newly added, randomly initialized module stacked on top of a strong SceneFlow pre-trained backbone; full injection from the start would disturb the backbone. The safeguards let the attention branch start from “almost no effect” and gradually open up during training:
- Zero Init + Alpha Gate ≈ 0.007: early in training, the attention branch has nearly no impact on the output.
- Alpha Cap Warmup: the cap grows linearly from 0.05 to 1.0, avoiding violent early perturbation.
- Residual Connection: a safety net; even if attention fails, the original stereo features are preserved.
7.5 Why NaN Protection is needed
When the K/V for some Query position are all invalid after masking, the entire attention row is -inf, softmax produces 0/0 = NaN, and the loss becomes NaN. The three lines of defense (conditional mask application, numerically stable softmax, nan_to_num safety net) ensure no NaN loss.
8. Highlights
- Spatially selective polarization fusion: cross-attention replaces uniform concat; stereo features actively query polarization features, and attention weights automatically concentrate in glass regions—the network learns “where to use polarization” by itself.
- Memory-efficient pooled attention: `PooledCrossAttention` pools K, V to `H/8 × W/8`, reducing memory complexity from `O(N^2)` to `O(N*M)` and saving about 64×, making attention practical on high-resolution feature maps.
- Learnable gating plus warmup for gradual injection: `alpha` is a learnable gate, then multiplied by a linearly growing `alpha_cap`, letting the attention branch open up gradually from “nearly no effect” and avoiding disturbance to the pre-trained backbone.
- Multiple training-stability safeguards: Zero Init, Alpha Gate ≈ 0.007, Alpha Cap Warmup, and Residual Connection stack to ensure the new attention module converges stably on top of the pre-trained backbone.
- Three lines of NaN defense: conditional mask application, numerically stable softmax, and a `nan_to_num` safety net eliminate NaN loss caused by fully `-inf` mask rows.
- Unidirectional design aligned with the task: Pol→Stereo unidirectional attention maps directly to the final task, “stereo features to be augmented by polarization,” and avoids the extra parameters and complexity of reverse attention.