Blueprint · 2026

Cross-Attention Fusion Architecture

A specification of the Cross-Attention Fusion mechanism: cross-attention replaces `concat + conv` fusion, letting stereo features actively "query" polarization features and automatically learn where polarization information is needed.

  • stereo matching
  • polarization
  • RAFT-Stereo

Using these blueprints

Everything here is an architecture proposal I designed and chose to publish openly. Free to use, adapt, or build on — no permission needed.

If one turns out useful and crediting is convenient, a link back to this site is appreciated. It's never required.

1. Design Goals

The stereo_feat from the stereo stream and the pol_feat from the polarization stream must be fused into a single feature for the subsequent matching pipeline. The baseline concat + conv fusion is:

fused = torch.cat([stereo_feat, pol_feat], dim=1)  # (B, 256, H, W)
fused = self.fusion_conv(fused)                     # (B, 128, H, W)

Problem: Concatenation treats all spatial positions equally, but the polarization signal is meaningful only in glass regions. Mixing polarization features uniformly into every pixel requires the network to figure out, on its own, “where in this mixed soup of features is the polarization information useful.”

Design goal: Use cross-attention so stereo features can “query” polarization features, automatically learning where polarization information is needed. Attention naturally concentrates weight on relevant regions (glass) and drives weight toward 0 elsewhere—effectively letting the network learn “where to use polarization.”


2. Architecture

Pol→Stereo cross-attention fusion

  • Q (Query): from stereo features.
  • K, V (Key, Value): from polarization features.
  • The output is residual: stereo_out = stereo_feat + alpha * attended_stereo.
  • alpha is a learnable gate that is then multiplied by a training-warmup cap, alpha_cap.

This architecture uses Pol→Stereo unidirectional attention: only stereo features query polarization information; the reverse direction (stereo→pol) is disabled by default.
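The Q/K/V roles and the residual output described above can be sketched as follows. This is a minimal single-head illustration, not the blueprint's implementation: the class name PolToStereoAttention, the 1×1 convolution projections, and passing alpha as a plain float are all assumptions made for the example, and K, V are left unpooled to keep it short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PolToStereoAttention(nn.Module):
    """Hypothetical single-head sketch: stereo queries, polarization keys/values."""

    def __init__(self, channels: int):
        super().__init__()
        # Assumption: 1x1 convolutions serve as the Q, K, V projections.
        self.q_proj = nn.Conv2d(channels, channels, kernel_size=1)  # Q from stereo
        self.k_proj = nn.Conv2d(channels, channels, kernel_size=1)  # K from polarization
        self.v_proj = nn.Conv2d(channels, channels, kernel_size=1)  # V from polarization
        self.scale = channels ** -0.5

    def forward(self, stereo_feat, pol_feat, alpha: float):
        b, c, h, w = stereo_feat.shape
        q = self.q_proj(stereo_feat).flatten(2).transpose(1, 2)  # (B, N, C)
        k = self.k_proj(pol_feat).flatten(2)                     # (B, C, N)
        v = self.v_proj(pol_feat).flatten(2).transpose(1, 2)     # (B, N, C)
        attn = F.softmax(torch.bmm(q, k) * self.scale, dim=-1)   # (B, N, N)
        attended = torch.bmm(attn, v)                            # (B, N, C)
        attended_stereo = attended.transpose(1, 2).reshape(b, c, h, w)
        # Residual output: the stereo features are always passed through.
        return stereo_feat + alpha * attended_stereo


torch.manual_seed(0)
layer = PolToStereoAttention(channels=16)
stereo_feat = torch.randn(2, 16, 6, 8)
pol_feat = torch.randn(2, 16, 6, 8)
stereo_out = layer(stereo_feat, pol_feat, alpha=0.007)
```

With alpha set to 0 the layer returns stereo_feat unchanged, which is the property the gating in §3.2 relies on.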


3. Components and Modules

3.1 PooledCrossAttention (memory-efficient)

Vanilla attention has memory complexity O(N^2) (N = H*W), which is infeasible on high-resolution feature maps. PooledCrossAttention pools and downsamples K and V:

# Vanilla attention: O(H*W * H*W) = O(N^2)
# Pooled attention:  O(H*W * pool_h*pool_w) = O(N * M), M << N
# Memory savings ~64x (pool_size=8)

Q: (B, C, H, W) - keep original resolution
K, V: pooled to (B, C, H/8, W/8) - reduce memory

  • Q keeps the original resolution (every pixel can query).
  • K, V are pooled to H/8 × W/8 (pool_size=8).
  • Complexity drops from O(N^2) to O(N*M) with M << N, saving roughly 64× memory.
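The shape bookkeeping behind these bullets can be checked with a short sketch. It assumes average pooling via F.avg_pool2d (the blueprint only says K, V are "pooled") and skips the learned projections; the tensor sizes are arbitrary illustration values.

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 32, 64, 96
pool_size = 8

stereo_feat = torch.randn(B, C, H, W)
pol_feat = torch.randn(B, C, H, W)

# Assumption: average pooling; another pooling operator would give the same shapes.
pol_pooled = F.avg_pool2d(pol_feat, kernel_size=pool_size)  # (B, C, H/8, W/8)

q = stereo_feat.flatten(2).transpose(1, 2)  # (B, N, C), N = H*W queries
k = pol_pooled.flatten(2)                   # (B, C, M), M = N / 64 keys
scores = torch.bmm(q, k)                    # (B, N, M) instead of (B, N, N)

n = H * W
m = pol_pooled.shape[-2] * pol_pooled.shape[-1]
savings = (n * n) / (n * m)                 # equals pool_size ** 2
```

The score matrix has one row per full-resolution pixel but only M columns, which is where the pool_size² = 64× reduction comes from.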

3.2 SafeCrossAttentionFusion (stable training)

The attention fusion layer contains multiple safeguards to ensure stable training:

  • Alpha Gate: alpha is initialized as sigmoid(-5) ≈ 0.007, near 0—early in training, the attention output is hardly injected.
  • Zero Init: the output projection is initialized to 0—early in training, the attention branch has no effect.
  • Alpha Cap Warmup: the alpha cap grows linearly from 0.05 to 1.0 (see §3.3).
  • Residual Connection: out = stereo + alpha * attended—even if the attention branch fails, the original stereo features are preserved.
  • NaN Protection: numerically stable softmax + nan_to_num (see §3.4).
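A minimal sketch of how the alpha gate, zero init, cap, and residual might combine is shown below. The class name GatedResidual, the scalar gate, and the 1×1 convolution output projection are assumptions for illustration, not the blueprint's actual module.

```python
import torch
import torch.nn as nn


class GatedResidual(nn.Module):
    """Hypothetical sketch of the alpha gate, zero init, cap, and residual."""

    def __init__(self, channels: int, init_logit: float = -5.0):
        super().__init__()
        # Alpha Gate: sigmoid(-5) is roughly 0.007 at initialization.
        self.alpha_logit = nn.Parameter(torch.tensor(init_logit))
        # Zero Init: the output projection starts as an all-zero map.
        self.out_proj = nn.Conv2d(channels, channels, kernel_size=1)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, stereo_feat, attended, alpha_cap: float = 1.0):
        # alpha = sigmoid(learnable_logit) * alpha_cap
        alpha = torch.sigmoid(self.alpha_logit) * alpha_cap
        # Residual Connection: stereo features pass through untouched.
        return stereo_feat + alpha * self.out_proj(attended)


torch.manual_seed(0)
gate = GatedResidual(channels=8)
stereo_feat = torch.randn(1, 8, 4, 4)
attended = torch.randn(1, 8, 4, 4)
out = gate(stereo_feat, attended, alpha_cap=0.05)
initial_alpha = torch.sigmoid(gate.alpha_logit).item()
```

At initialization the output equals stereo_feat exactly, because the zero-initialized projection contributes nothing regardless of the gate value.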

3.3 Alpha Cap Warmup

The cap on alpha grows linearly with the training step:

def _compute_alpha_cap(self) -> float:
    """Linearly warm up the alpha cap (0.05 -> 1.0 over 5000 steps by default)."""
    # After the warmup window the cap stays fully open.
    if self.global_step >= self.args.alpha_cap_warmup:
        return 1.0
    # Fraction of the warmup completed, in [0, 1).
    progress = self.global_step / self.args.alpha_cap_warmup
    return self.args.alpha_cap_start + progress * (1.0 - self.args.alpha_cap_start)

  • Starts at alpha_cap_start = 0.05, ends at 1.0, growing linearly over alpha_cap_warmup = 5000 steps.
  • Limits the attention branch’s maximum influence early in training, preventing the randomly initialized attention from disturbing the pre-trained backbone.
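To make the schedule concrete, here is a standalone version of the same linear rule that can be run without a trainer object. The function names and the assumption that the gate logit stays at its initial value of -5 are for illustration only; in training the logit is learned and will drift.

```python
import math


def alpha_cap(step: int, start: float = 0.05, warmup: int = 5000) -> float:
    """Linear warmup of the cap from `start` to 1.0 over `warmup` steps."""
    if step >= warmup:
        return 1.0
    progress = step / warmup
    return start + progress * (1.0 - start)


def effective_alpha(step: int, gate_logit: float = -5.0) -> float:
    """alpha = sigmoid(gate_logit) * alpha_cap(step); the logit is held fixed here."""
    gate = 1.0 / (1.0 + math.exp(-gate_logit))
    return gate * alpha_cap(step)


for step in (0, 2500, 5000):
    print(step, round(alpha_cap(step), 4), f"{effective_alpha(step):.6f}")
```

The cap is 0.05 at step 0, 0.525 halfway through, and 1.0 from step 5000 onward, so the cap and the near-zero gate together keep the injected signal very small during the first steps.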

3.4 NaN Protection

When every entry in an attention row is masked to -inf, softmax produces NaN (0/0) for that row. The safeguards:

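import torch
import torch.nn.functional as F

# 1. Apply the mask only if every query row keeps at least one valid key;
#    otherwise the mask is skipped for the whole batch.
valid_count = (mask_pooled >= 0.5).sum(dim=-1, keepdim=True)
if valid_count.min() > 0:
    attn = attn.masked_fill(mask_pooled < 0.5, float('-inf'))

# 2. Numerically stable softmax: subtract the row max first
#    (F.softmax already does this internally; this is an explicit extra guard).
attn_max = attn.max(dim=-1, keepdim=True)[0]
attn = attn - attn_max
attn = F.softmax(attn, dim=-1)

# 3. Last line of defense: replace any remaining NaN with 0.
attn = torch.nan_to_num(attn, nan=0.0)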

Three lines of defense: (1) apply the mask only when every row keeps at least one valid value; (2) subtract the max before softmax for numerical stability; (3) nan_to_num as a final safety net.
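The failure mode itself is easy to reproduce on a toy example. The sketch below uses made-up scores and a hand-written boolean validity mask (both hypothetical) to show that a fully masked row turns into NaN under softmax and that nan_to_num maps it to zeros.

```python
import torch
import torch.nn.functional as F

# Two query rows, three keys each; the values are arbitrary.
scores = torch.tensor([[0.5, 1.0, -0.2],
                       [0.3, 0.1, 0.7]])
# The second row has no valid key at all.
valid = torch.tensor([[True, False, True],
                      [False, False, False]])

masked = scores.masked_fill(~valid, float("-inf"))
raw = F.softmax(masked, dim=-1)        # second row becomes NaN
safe = torch.nan_to_num(raw, nan=0.0)  # second row becomes all zeros
```

The first row still sums to 1 with zero weight on its masked key, while the fully masked row contributes nothing to the weighted sum of V instead of propagating NaN into the loss.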


4. Data Flow

  1. The stereo stream produces stereo_feat; the polarization stream (PolarizationEncoder) produces pol_feat.
  2. PooledCrossAttention:
    • Q: from stereo_feat (keeps the original resolution H × W).
    • K, V: from pol_feat, pooled to H/8 × W/8.
  3. Compute attention scores → NaN-protected softmax → weighted aggregation of V → attended_stereo.
  4. SafeCrossAttentionFusion: stereo_out = stereo_feat + alpha * attended_stereo, where alpha = sigmoid(learnable_logit) * alpha_cap and alpha_cap grows with the warmup.
  5. stereo_out enters the correlation pyramid + GRU for iterative refinement → disparity.

5. Tensor Dimensions

Tensor          | Dimensions       | Description
stereo_feat     | (B, C, H, W)     | Stereo-stream features
pol_feat        | (B, C, H, W)     | Polarization-stream features
Q               | (B, C, H, W)     | Projected from stereo_feat; keeps original resolution
K, V            | (B, C, H/8, W/8) | Projected from pol_feat and pooled (pool_size=8)
attended_stereo | (B, C, H, W)     | Attention output
stereo_out      | (B, C, H, W)     | stereo_feat + alpha * attended_stereo

Model parameter count

Model                  | Total params | Trainable params
Cross-Attention Fusion | 4,363,715    | 4,363,715

Parameters are concentrated in PooledCrossAttention (Q, K, V projections) and SafeCrossAttentionFusion (alpha gates, output projections).


6. Hyperparameters

Cross-Attention-Specific Parameters

Parameter            | Default | Description
num_heads            | 4       | Number of multi-head attention heads
pool_size            | 8       | K, V pooling size (8 = 64× memory savings)
enable_stereo_to_pol | False   | Whether to enable bidirectional attention (unidirectional by default)
alpha_cap_start      | 0.05    | Alpha-warmup starting value
alpha_cap_warmup     | 5000    | Alpha-warmup steps
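One way to carry these defaults around in code is a small frozen dataclass. The class name CrossAttentionConfig and its validation rules are hypothetical; the divisibility check also assumes the attention width equals feature_dim = 128 from the training hyperparameters, which the blueprint does not state explicitly.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CrossAttentionConfig:
    """Hypothetical container for the cross-attention defaults listed above."""
    num_heads: int = 4
    pool_size: int = 8
    enable_stereo_to_pol: bool = False
    alpha_cap_start: float = 0.05
    alpha_cap_warmup: int = 5000
    feature_dim: int = 128  # assumed attention width, from the training table

    def __post_init__(self) -> None:
        # Multi-head attention splits the channels evenly across heads.
        if self.feature_dim % self.num_heads != 0:
            raise ValueError("feature_dim must be divisible by num_heads")
        if not 0.0 <= self.alpha_cap_start <= 1.0:
            raise ValueError("alpha_cap_start must lie in [0, 1]")
        if self.pool_size < 1 or self.alpha_cap_warmup < 0:
            raise ValueError("pool_size must be >= 1 and alpha_cap_warmup >= 0")

    @property
    def head_dim(self) -> int:
        """Channels per attention head."""
        return self.feature_dim // self.num_heads


cfg = CrossAttentionConfig()
```

Under these assumptions each of the 4 heads would work on 32 channels.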

Training Hyperparameters

Parameter        | Value                    | Description
pretrained       | raftstereo-sceneflow.pth | SceneFlow pre-trained weights
pol_dim          | 128                      | Polarization feature dimension
pol_threshold    | 0.05                     | Soft Threshold value
pol_weight       | 2.0                      | Polarization weight in Polarization-aware Loss
glass_weight     | 5.0                      | Loss weight on glass regions
pol_lr_mult      | 5.0                      | Learning-rate multiplier for polarization layers
hidden_dim       | 128                      | GRU hidden state dimension
context_dim      | 128                      | Context dimension
feature_dim      | 128                      | Feature dimension
iters            | 24                       | GRU iterations
batch_size       | 8                        | Training batch size
num_steps        | 70000                    | Training steps
lr               | 0.0003                   | Learning rate
num_heads        | 4                        | Attention heads
pool_size        | 8                        | K, V pooling size
alpha_cap_start  | 0.05                     | Alpha-warmup starting value
alpha_cap_warmup | 5000                     | Alpha-warmup steps

7. Design Decisions and Rationale

7.1 Why cross-attention instead of concatenation

Concatenation treats all spatial positions equally, but the polarization signal is meaningful only in glass regions. Cross-attention lets stereo features actively “query” polarization features; attention weights automatically concentrate on relevant regions (glass), effectively letting the network learn “where polarization information is needed” by itself.

7.2 Why Pol→Stereo unidirectional attention

The final task is stereo matching, which needs “stereo features augmented by polarization information.” Stereo serves as Query and polarization as Key/Value—i.e., “stereo queries polarization.” The reverse (enable_stereo_to_pol) is disabled by default to avoid unnecessary parameters and training complexity.

7.3 Why PooledCrossAttention is needed

Full-resolution attention has O(N^2) memory, which is infeasible on a 640×480 feature map. Pooling K, V to H/8 × W/8 reduces complexity to O(N*M), saving about 64× memory and making attention practical on high-resolution feature maps. Polarization features themselves are relatively “macroscopic” (glass regions are large), so downsampling K, V loses little effective information.
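The 64× figure can be turned into concrete sizes with a little arithmetic. The helper below is a hypothetical sketch that counts only the attention matrix for one head and one sample, assuming 4-byte (fp32) elements; activations, projections, and gradients are ignored.

```python
def attention_matrix_bytes(h: int, w: int, pool_size: int = 1,
                           bytes_per_element: int = 4) -> int:
    """Bytes for one (N x M) attention matrix: one head, one sample."""
    n = h * w                                # number of queries
    m = (h // pool_size) * (w // pool_size)  # number of keys after pooling
    return n * m * bytes_per_element


H, W = 480, 640
full = attention_matrix_bytes(H, W)                 # keys at full resolution
pooled = attention_matrix_bytes(H, W, pool_size=8)  # keys pooled to H/8 x W/8

print(f"full:   {full / 2**30:.2f} GiB")
print(f"pooled: {pooled / 2**30:.2f} GiB")
print(f"ratio:  {full // pooled}x")
```

At 640×480 this comes to roughly 352 GiB versus 5.5 GiB per head and per sample in fp32, so the ratio is fixed at 64× while the absolute footprint still scales with feature-map resolution, head count, batch size, and dtype.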

7.4 Why SafeCrossAttentionFusion needs multiple safeguards

Cross-attention is a newly added, randomly initialized module stacked on top of a strong SceneFlow pre-trained backbone; full injection from the start would disturb the backbone. The safeguards let the attention branch start from “almost no effect” and gradually open up during training:

  • Zero Init + Alpha Gate ≈ 0.007: early in training, the attention branch has nearly no impact on the output.
  • Alpha Cap Warmup: the cap grows linearly from 0.05 to 1.0, avoiding violent early perturbation.
  • Residual Connection: a safety net; even if attention fails, the original stereo features are preserved.

7.5 Why NaN Protection is needed

When the K/V for some Query position are all invalid after masking, the entire attention row is -inf, softmax produces 0/0 = NaN, and the loss becomes NaN. The three lines of defense (conditional mask application, numerically stable softmax, nan_to_num safety net) ensure no NaN loss.


8. Highlights

  • Spatially selective polarization fusion: cross-attention replaces uniform concat; stereo features actively query polarization features, and attention weights automatically concentrate in glass regions—the network learns “where to use polarization” by itself.
  • Memory-efficient pooled attention: PooledCrossAttention pools K, V to H/8 × W/8, reducing memory complexity from O(N^2) to O(N*M) and saving about 64×, making attention practical on high-resolution feature maps.
  • Learnable gating plus warmup for gradual injection: alpha is a learnable gate, then multiplied by a linearly growing alpha_cap, letting the attention branch open up gradually from “nearly no effect” and avoiding disturbance to the pre-trained backbone.
  • Multiple training-stability safeguards: Zero Init, Alpha Gate ≈ 0.007, Alpha Cap Warmup, and Residual Connection stack to ensure the new attention module converges stably on top of the pre-trained backbone.
  • Three lines of NaN defense: conditional mask application, numerically stable softmax, and nan_to_num safety net eliminate NaN loss caused by fully -inf mask rows.
  • Unidirectional design aligned with the task: Pol→Stereo unidirectional attention maps directly to the final task—“stereo features to be augmented by polarization”—and avoids the extra parameters and complexity of reverse attention.
