1. Design Goals
In polarized stereo matching, a common polarization stereo module handles the polarization signal in a simple manner:
```
pol_diff = left - right                                   # simple subtraction
pol_corr = query(pol_volume, disp)
output   = concat(corr, pol_corr, disp) → encoder → GRU
```
This approach has the following issues:
- No learnable polarization feature extraction.
- No spatial attention mechanism (the model does not know where the glass is).
- Simple concat fusion, with no learned weighting between stereo and pol.
This architecture addresses these issues by introducing a spatial attention mechanism and learnable fusion, so that the model can mark glass locations and dynamically decide whether to trust stereo or pol information. The first issue (no learnable polarization feature extraction) is addressed by the optional Learnable Pol Encoder.
The core consists of two modules:
- Polarization Attention: generates a spatial attention map from `pol_corr` to mark glass locations.
- Gated Fusion: learns dynamic fusion weights between stereo and pol features.
Two optional extension modules can also be added:
- Learnable Pol Encoder (optional): Uses a Conv encoder to learn richer polarization features rather than relying solely on raw subtraction.
- Glass-aware Auxiliary Head (optional): Predicts a glass mask as an auxiliary task, useful for visualization and interpretability.
2. Architecture
Data Flow
- `pol_diff` / `pol_corr` enters `PolarizationAttention`, which produces `pol_attention_map` via Conv + sigmoid (marking glass locations).
- `GatedFusion` receives three inputs: `corr`, `pol_corr`, and `pol_attention_map`.
- The three are concatenated along the channel axis and used to compute the gate (sigmoid, [0, 1]).
- `corr` and `pol_corr` are each enhanced by a 1×1 convolution into `enhanced_corr` / `enhanced_pol` (`corr_feat` / `pol_feat` in the code).
- `fused = gate * enhanced_corr + (1 - gate) * enhanced_pol` performs the weighted fusion.
- `concat(fused, disp)` is fed into the GRU.
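To make the order of operations concrete, the sketch below traces the data flow for a single pixel with one channel per stream. It is a toy, not the module itself: every weight in `w` is a hypothetical scalar standing in for a learned convolution, the enhance projections drop their biases, and the 3×3 spatial context of `gate_net` is ignored.

```python
import math
from typing import Dict, List, Union


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def relu(x: float) -> float:
    return max(0.0, x)


def fuse_one_pixel(corr: float, pol_corr: float, disp: float,
                   w: Dict[str, float]) -> Dict[str, Union[float, List[float]]]:
    """Toy trace of the data flow for ONE pixel and ONE channel per stream.

    All entries of `w` are hypothetical scalars standing in for learned conv
    weights/biases; the real gate_net uses a 3x3 conv, which this ignores.
    """
    # PolarizationAttention: 1x1 conv -> ReLU -> 1x1 conv -> sigmoid
    pol_attn = sigmoid(w["a2"] * relu(w["a1"] * pol_corr + w["a1_b"]) + w["a2_b"])

    # GatedFusion gate: computed from corr, pol_corr and pol_attn together
    h = relu(w["g_corr"] * corr + w["g_pol"] * pol_corr
             + w["g_attn"] * pol_attn + w["g_b"])
    gate = sigmoid(w["g2"] * h + w["g2_b"])

    # 1x1 "enhance" projections (biases omitted in this toy)
    enhanced_corr = w["corr_enh"] * corr
    enhanced_pol = w["pol_enh"] * pol_corr

    # Weighted fusion, then concat with disp to form the GRU input
    fused = gate * enhanced_corr + (1.0 - gate) * enhanced_pol
    return {"pol_attn": pol_attn, "gate": gate, "fused": fused,
            "gru_input": [fused, disp]}


weights = {"a1": 1.0, "a1_b": 0.0, "a2": 2.0, "a2_b": -1.0,
           "g_corr": 1.0, "g_pol": -1.0, "g_attn": -1.0, "g_b": 0.5,
           "g2": 1.0, "g2_b": 0.0, "corr_enh": 1.0, "pol_enh": 1.0}
out = fuse_one_pixel(corr=0.8, pol_corr=0.3, disp=12.0, w=weights)
```

With identity enhance weights, `out["fused"]` always lands between the two inputs (0.3 and 0.8 here), and pushing the gate bias `g2_b` very high drives the result to the stereo value.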
3. Components and Modules
3.1 PolarizationAttention
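```python
import torch
import torch.nn as nn


class PolarizationAttention(nn.Module):
    """Generate a spatial attention map from the polarization features (pol_diff / pol_corr)."""

    def __init__(self, in_channels: int, reduction: int = 4):
        super().__init__()
        hidden = max(in_channels // reduction, 1)  # guard against in_channels < reduction
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1),  # 1x1 conv: channel compression
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 1, 1),            # 1x1 conv: collapse to one channel
            nn.Sigmoid(),                       # attention values in [0, 1]
        )

    def forward(self, pol_corr: torch.Tensor) -> torch.Tensor:
        # (B, in_channels, H, W) -> (B, 1, H, W)
        return self.conv(pol_corr)
```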
- Two 1×1 convolutions, with channel compression by `reduction=4` in between.
- A final Sigmoid produces a single-channel spatial attention map in [0, 1].
- Output shape: `(B, 1, H, W)`.
- Intuition: regions with a high polarization difference (glass) are expected to receive higher weights after training, so the model learns “where to trust the polarization information”.
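The "lightweight" claim can be checked by counting parameters. The helper below counts weights and biases of the two 1×1 convolutions; the channel sizes 36 and 256 used with it are hypothetical examples, not values fixed by this design.

```python
def polarization_attention_params(in_channels: int, reduction: int = 4) -> int:
    """Count weights + biases in PolarizationAttention's two 1x1 convs."""
    hidden = max(in_channels // reduction, 1)  # equals in_channels // reduction when in_channels >= reduction
    conv1 = in_channels * hidden + hidden       # 1x1 conv in_channels -> hidden
    conv2 = hidden * 1 + 1                      # 1x1 conv hidden -> 1
    return conv1 + conv2


print(polarization_attention_params(36))   # → 343
print(polarization_attention_params(256))  # → 16513
```

In PyTorch, the same count can be cross-checked with `sum(p.numel() for p in PolarizationAttention(36).parameters())`.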
3.2 GatedFusion
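```python
import torch
import torch.nn as nn


class GatedFusion(nn.Module):
    """Learn a per-pixel fusion weight between the stereo and pol streams."""

    def __init__(self, corr_dim: int, pol_dim: int, out_dim: int):
        super().__init__()
        # +1 input channel for the single-channel pol_attn map
        self.gate_net = nn.Sequential(
            nn.Conv2d(corr_dim + pol_dim + 1, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid(),  # gate in [0, 1]
        )
        # 1x1 projections so both streams share out_dim and can be summed
        self.corr_enhance = nn.Conv2d(corr_dim, out_dim, 1)
        self.pol_enhance = nn.Conv2d(pol_dim, out_dim, 1)

    def forward(self, corr, pol_corr, pol_attn):
        # all three inputs must share (B, H, W) for the channel concat
        gate = self.gate_net(torch.cat([corr, pol_corr, pol_attn], dim=1))  # (B, 1, H, W)
        corr_feat = self.corr_enhance(corr)    # (B, out_dim, H, W)
        pol_feat = self.pol_enhance(pol_corr)  # (B, out_dim, H, W)
        # gate broadcasts over channels: 1 -> stereo, 0 -> pol
        fused = gate * corr_feat + (1 - gate) * pol_feat
        return fused, gate  # gate is returned for visualization
```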
- `gate_net`: the input has `corr_dim + pol_dim + 1` channels (the +1 is the single-channel `pol_attn`); the sequence 3×3 Conv → ReLU → 1×1 Conv → Sigmoid produces a single-channel gate.
- `corr_enhance` / `pol_enhance`: each is a 1×1 convolution projecting corr / pol to `out_dim`.
- `forward` returns `fused` and `gate` (the gate can be visualized).
- Gate semantics: `gate = 1` → trust stereo, `gate = 0` → trust pol. In glass regions, the model is expected to learn a lower gate, reducing the stereo weight and increasing the pol weight.
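Written per pixel, the fusion is a convex combination, which makes the gate semantics precise. The notation here is introduced for illustration: $x$ indexes pixels, $c$ indexes the `out_dim` channels, $g$ is `gate`, and $\tilde{c}$, $\tilde{p}$ denote `corr_feat` and `pol_feat`.

$$
f_c(x) = g(x)\,\tilde{c}_c(x) + \bigl(1 - g(x)\bigr)\,\tilde{p}_c(x), \qquad g(x) \in [0, 1]
$$

$$
\min\bigl(\tilde{c}_c(x), \tilde{p}_c(x)\bigr) \;\le\; f_c(x) \;\le\; \max\bigl(\tilde{c}_c(x), \tilde{p}_c(x)\bigr)
$$

Because `gate` has shape `(B, 1, H, W)`, the same $g(x)$ is broadcast across all `out_dim` channels: the gate decides how much to trust each source per pixel, not per channel.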
3.3 UpdateBlockV2
Builds on the standard RAFT update block, integrating `PolarizationAttention` and `GatedFusion` and concatenating the fused features with `disp` to form the GRU input.
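The sketch below shows where such an update block would sit in a RAFT-style iterative loop: both volumes are re-queried at the current disparity each iteration, and the block returns a residual update. It is framework-free and schematic; `refine`, `toy_step`, and the lookup callables are hypothetical stand-ins, not the actual `UpdateBlockV2` interface.

```python
from typing import Any, Callable, List, Tuple

# update_step(hidden, corr, pol_corr, disp) -> (new_hidden, delta_disp, gate)
UpdateStep = Callable[[Any, float, float, float], Tuple[Any, float, float]]


def refine(disp: float, hidden: Any, iters: int,
           lookup_corr: Callable[[float], float],
           lookup_pol: Callable[[float], float],
           update_step: UpdateStep) -> Tuple[float, List[float]]:
    """RAFT-style iterative refinement with a pluggable update block."""
    gates = []
    for _ in range(iters):
        corr = lookup_corr(disp)     # stereo correlation sampled at the current disp
        pol_corr = lookup_pol(disp)  # pol_corr = query(pol_volume, disp)
        # stand-in for UpdateBlockV2: attention + gated fusion + GRU + disparity head
        hidden, delta, gate = update_step(hidden, corr, pol_corr, disp)
        disp = disp + delta          # residual disparity update
        gates.append(gate)
    return disp, gates


# Toy usage: the "correlation" points toward a true disparity of 10, and the
# stand-in block trusts stereo fully (gate = 1), halving the error each step.
def toy_step(hidden, corr, pol_corr, disp):
    gate = 1.0
    fused = gate * corr + (1.0 - gate) * pol_corr
    return hidden, 0.5 * fused, gate


final_disp, gate_history = refine(0.0, None, 24,
                                  lookup_corr=lambda d: 10.0 - d,
                                  lookup_pol=lambda d: 0.0,
                                  update_step=toy_step)
```

With `iters = 24`, the toy error shrinks by a factor of two per iteration, so `final_disp` ends within about 6e-7 of 10.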
3.4 Learnable Pol Encoder (Optional)
Uses a Conv encoder to learn a richer polarization feature representation rather than relying solely on raw subtraction.
3.5 Glass-aware Auxiliary Head (Optional)
Predicts glass locations as an auxiliary task, useful for visualization and interpretability.
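One common way to attach such a head is to add a down-weighted binary cross-entropy term on the predicted glass mask to the disparity loss. The sketch below assumes ground-truth glass masks are available and uses a hypothetical `aux_weight = 0.1`; masks are flattened to per-pixel lists to keep it dependency-free.

```python
import math
from typing import Sequence


def binary_cross_entropy(pred: Sequence[float], target: Sequence[float],
                         eps: float = 1e-7) -> float:
    """Mean BCE over pixels; pred is a per-pixel glass probability in [0, 1]."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(pred)


def total_loss(disp_loss: float, glass_pred: Sequence[float],
               glass_gt: Sequence[float], aux_weight: float = 0.1) -> float:
    """Disparity loss plus a down-weighted auxiliary glass-mask term."""
    return disp_loss + aux_weight * binary_cross_entropy(glass_pred, glass_gt)
```

In PyTorch, `torch.nn.functional.binary_cross_entropy` (or `binary_cross_entropy_with_logits`, if the head outputs logits) would replace the hand-written helper.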
4. Tensor Dimensions
| Tensor | Shape | Description |
|---|---|---|
| `pol_diff` / `pol_corr` input | `(B, in_channels, H, W)` | Input to `PolarizationAttention` |
| `pol_attention_map` | `(B, 1, H, W)` | Spatial attention map |
| `gate_net` input | `(B, corr_dim + pol_dim + 1, H, W)` | Concat of corr / pol / attn |
| `gate` | `(B, 1, H, W)` | Fusion weight |
| `corr_feat` / `pol_feat` | `(B, out_dim, H, W)` | Enhanced features |
| `fused` | `(B, out_dim, H, W)` | Fused output |
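The table can be reproduced, and extended with the GRU input, by a small shape tracer. Two assumptions are labeled in the code: `PolarizationAttention` consumes `pol_corr`, so its `in_channels` equals `pol_dim`, and `disp` is treated as a single-channel map (hypothetical `disp_dim=1`). The example sizes are illustrative only.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int, int]


def trace_shapes(B: int, H: int, W: int, corr_dim: int, pol_dim: int,
                 out_dim: int, disp_dim: int = 1) -> Dict[str, Shape]:
    """Shapes through PolarizationAttention -> GatedFusion -> GRU input.

    Assumes PolarizationAttention's in_channels == pol_dim and that disp
    has disp_dim channels (1 by default, an assumption).
    """
    return {
        "pol_corr": (B, pol_dim, H, W),
        "pol_attention_map": (B, 1, H, W),
        "gate_net_input": (B, corr_dim + pol_dim + 1, H, W),
        # a single-channel gate broadcasts over out_dim in gate * corr_feat
        "gate": (B, 1, H, W),
        "corr_feat": (B, out_dim, H, W),
        "pol_feat": (B, out_dim, H, W),
        "fused": (B, out_dim, H, W),
        "gru_input": (B, out_dim + disp_dim, H, W),  # concat(fused, disp)
    }


shapes = trace_shapes(B=2, H=64, W=128, corr_dim=36, pol_dim=36, out_dim=128)
```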
5. Hyperparameters
| Hyperparameter | Value | Description |
|---|---|---|
| `pol_levels` | 4 | Number of pyramid levels in the polarization volume |
| `pol_radius` | 4 | Lookup radius of the polarization volume |
| `iters` | 24 | Number of GRU iterations |
| `reduction` | 4 | Channel compression ratio in `PolarizationAttention` |
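`pol_levels` and `pol_radius` together determine how many channels `pol_corr` carries. Assuming the polarization volume is queried like RAFT-Stereo's 1D correlation pyramid (each level samples `2 * radius + 1` disparity offsets, stacked along channels), the count works out as follows; `lookup_channels` is a hypothetical helper.

```python
def lookup_channels(levels: int, radius: int) -> int:
    """Channels from a RAFT-Stereo-style 1D pyramid lookup (assumed here).

    Each pyramid level samples 2 * radius + 1 disparity offsets around the
    current estimate; the levels are stacked along the channel axis.
    """
    return levels * (2 * radius + 1)


pol_dim = lookup_channels(levels=4, radius=4)  # pol_levels=4, pol_radius=4 -> 36
attention_hidden = pol_dim // 4                # reduction=4 -> 9 hidden channels
```

Under this assumption, `pol_corr` has 36 channels, which would be both the `in_channels` of `PolarizationAttention` and the `pol_dim` of `GatedFusion`.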
6. Design Decisions and Rationale
| Decision | Rationale |
|---|---|
| Introduce `PolarizationAttention` | Produces a spatial attention map so the model knows “where the glass is” |
| Use 1×1 conv with reduction compression | Keeps the attention module lightweight |
| Introduce `GatedFusion` | Makes stereo / pol fusion weights learnable, replacing simple concat |
| Gate takes corr, pol, and attn together | Fusion decisions consider all three sources, dynamically choosing which to trust |
| `corr_enhance` / `pol_enhance` project to the same `out_dim` | Both streams must share the same dimension to be combined by a weighted sum |
| Learnable Encoder and Auxiliary Head are optional | Core is attention + gated fusion; the encoder is a capacity extension, the head is for interpretability |
7. Highlights
- Uses `PolarizationAttention` to derive a spatial attention map from the polarization difference, explicitly marking glass locations so the model “knows where to trust polarization information”.
- `GatedFusion` learns a [0, 1] gate that dynamically weights the stereo and pol streams, replacing static concat fusion.
- The gate input covers `corr`, `pol_corr`, and the attention map simultaneously, so each fusion decision draws on all three sources.
- The attention module stays lightweight via 1×1 convolutions and channel reduction; the gate output can be visualized directly, which aids interpretability.
- Two optional modules, the Learnable Pol Encoder and the Glass-aware Auxiliary Head, extend the core architecture with deeper polarization feature extraction and auxiliary supervision.