1. Design Goals
1.1 The 65 mm baseline alignment problem
The dual-stream polarization stereo-matching architecture needs to compare I∥ and I⊥ when computing polarization features:
|I∥ - I⊥| → polarization difference features
But the two cameras have a 65 mm baseline; comparing them directly at the same pixel coordinates compares different 3D points:
Left camera (I∥): sees point A
Right camera (I⊥): at the same pixel location, sees a different point B (the projection of A is shifted horizontally by its disparity, e.g. ~60 px)
Therefore, before computing the polarization difference, the right image must be warped into the left view, and warping requires disparity.
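The size of this shift follows from standard rectified-stereo geometry, in which disparity is inversely proportional to depth. The focal length and depth below are illustrative assumptions chosen to reproduce the ~60 px figure, not parameters of this system:

```latex
d = \frac{f\,B}{Z}, \qquad
B = 0.065\,\text{m},\; f = 1000\,\text{px},\; Z \approx 1.08\,\text{m}
\;\Rightarrow\; d \approx \frac{1000 \times 0.065}{1.08} \approx 60\,\text{px}
```

Because d varies with Z across the scene, no single global shift can align the pair; a per-pixel disparity map is required.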
1.2 The training–inference asymmetry
| Stage | GT Disparity | Handling |
|---|---|---|
| Training | Yes | Warp with GT disparity for alignment |
| Inference | No | ??? (no disparity available to warp with) |
At training time, the GT disparity can be used to warp the right image into the left view, yielding a “pure polarization difference.” At inference time there is no GT disparity, so this alignment cannot be performed directly. The design goal of this architecture is to bridge this gap at inference.
1.3 Approach: self-align using “estimated disparity”
The core idea of two-stage inference:
- First, use unaligned polarization features to obtain an initial disparity estimate.
- Use this estimated disparity to warp and recompute aligned polarization features.
- Continue iterating, using the more accurate polarization features to refine disparity.
That is, “rough estimate → align using the rough estimate → refine.”
2. Architecture
2.1 warp_with_disparity utility
```python
from typing import Optional, Tuple

import torch


def warp_with_disparity(
    img: torch.Tensor,
    disparity: torch.Tensor,
    return_valid_mask: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Warp the right image into the left view using a disparity map.

    Formula: warped(x, y) = img(x - disparity, y)

    Returns:
        warped: warped right image (aligned to the left view)
        valid_mask: validity mask (0 outside boundaries)
    """
    ...
```
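Pixels whose source column x − disparity falls outside the right image are filled with 0 after warping, so the polarization difference computed there is a fake signal; a validity mask is therefore needed:

```python
# xx_warped = x - disparity: the source column in the right image for each left-view pixel
valid_mask = (xx_warped >= 0) & (xx_warped <= W - 1)
```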
Uses of valid_mask:
- Training: compute loss only over valid regions.
- Inference: ignore polarization features in invalid regions.
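The docstring formula and mask condition above can be realized with `torch.nn.functional.grid_sample`. This is a minimal sketch, assuming float inputs, disparity measured in pixels on the left-view grid, and bilinear sampling; the actual implementation may differ (e.g., in interpolation mode or mask dtype).

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def warp_with_disparity(
    img: torch.Tensor,
    disparity: torch.Tensor,
    return_valid_mask: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Sketch: warped(x, y) = img(x - d, y). img (B, C, H, W), disparity (B, 1, H, W)."""
    b, _, h, w = img.shape
    # Pixel-coordinate grids of the left view, each of shape (H, W).
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=img.dtype, device=img.device),
        torch.arange(w, dtype=img.dtype, device=img.device),
        indexing="ij",
    )
    # Source column in the right image for every left-view pixel: (B, H, W).
    xx_warped = xs.unsqueeze(0) - disparity[:, 0]
    yy = ys.unsqueeze(0).expand(b, h, w)
    # Normalize to [-1, 1] as grid_sample expects (align_corners=True convention).
    grid = torch.stack(
        (2.0 * xx_warped / max(w - 1, 1) - 1.0, 2.0 * yy / max(h - 1, 1) - 1.0),
        dim=-1,
    )
    # Bilinear sampling; samples that fall outside the image are filled with 0.
    warped = F.grid_sample(img, grid, mode="bilinear",
                           padding_mode="zeros", align_corners=True)
    if not return_valid_mask:
        return warped, None
    # 1 where the source column lies inside the right image, 0 otherwise.
    valid_mask = ((xx_warped >= 0) & (xx_warped <= w - 1)).unsqueeze(1).to(img.dtype)
    return warped, valid_mask
```

With a constant disparity of 2 px, the two leftmost columns have no source pixel in the right image, so they come back as 0 and are marked invalid by the mask.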
3. Components and Modules (two inference methods)
3.1 Method 1: forward_inference (efficient, recommended)
Updates polarization features only 1–2 times; the GRU state stays continuous:
```python
def forward_inference(
    self,
    left: torch.Tensor,
    right: torch.Tensor,
    iters: Optional[int] = None,
    pol_update_iters: Optional[List[int]] = None,  # default [iters//2]
) -> torch.Tensor:
    ...
```
Flow:
- Start iterating with unaligned polarization features.
- At a designated iteration (e.g., the 6th), warp using the current disparity estimate.
- Recompute aligned polarization features and rebuild the correlation volume.
- Continue the remaining iterations.
Advantages:
- The feature network `fnet` is computed only once (stereo features are not recomputed).
- The GRU hidden state stays continuous (no restart).
- Compute cost is about 1.1–1.2× that of a base forward pass.
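The flow above can be sketched as one generic refinement loop. The function name `iterative_refine_with_pol_updates` and both callables are hypothetical stand-ins: `build_pol_and_corr` is assumed to warp the right image with the given disparity (or skip warping when it is `None`), recompute the polarization features, and rebuild the correlation volume from cached `fnet` features, while `gru_step` performs one GRU update. Iteration indices are assumed to be 0-based.

```python
from typing import Any, Callable, List, Optional, Tuple

import torch

GruStep = Callable[[torch.Tensor, torch.Tensor, Any], Tuple[torch.Tensor, torch.Tensor]]


def iterative_refine_with_pol_updates(
    disp0: torch.Tensor,
    hidden0: torch.Tensor,
    build_pol_and_corr: Callable[[Optional[torch.Tensor]], Any],
    gru_step: GruStep,
    iters: int,
    pol_update_iters: Optional[List[int]] = None,
) -> torch.Tensor:
    """Sketch: one GRU run that swaps in re-aligned polarization features mid-way."""
    if pol_update_iters is None:
        pol_update_iters = [iters // 2]  # default from the design: the midpoint
    updates = set(pol_update_iters)
    # Start from unaligned polarization features (no disparity to warp with yet).
    context = build_pol_and_corr(None)
    disp, hidden = disp0, hidden0
    for it in range(iters):  # 0-based iteration index
        if it in updates:
            # Warp with the current estimate, recompute polarization features and
            # the correlation volume; the GRU hidden state is NOT reset.
            context = build_pol_and_corr(disp.detach())
        hidden, disp = gru_step(hidden, disp, context)
    return disp
```

The only state that changes at an update step is `context`; `hidden` and `disp` flow through every iteration unchanged in role, which is what keeps the extra cost close to that of a single pass.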
3.2 Method 2: forward_two_pass (full)
A full two-stage pass; each stage is independent:
```python
def forward_two_pass(
    self,
    left: torch.Tensor,
    right: torch.Tensor,
    iters_pass1: int = 6,
    iters_pass2: int = 6,
) -> torch.Tensor:
    ...
```
Flow:
- Pass 1: no polarization alignment → `disparity_v1`.
- Pass 2: warp with `disparity_v1` → aligned polarization features → `disparity_v2`.
Advantages: more accurate polarization alignment; suitable when precision is the priority.
Disadvantages: compute cost is 2× that of a single forward pass.
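For contrast with the single-run loop, the two-pass method can be sketched as two independent runs. The names `two_pass_refine`, `init_state`, `build_pol_and_corr`, and `gru_step` are hypothetical; in particular, it is an assumption here that Pass 2 also restarts from the initial disparity (the design only states that the stages are independent and that the GRU is restarted).

```python
from typing import Any, Callable, Optional, Tuple

import torch

GruStep = Callable[[torch.Tensor, torch.Tensor, Any], Tuple[torch.Tensor, torch.Tensor]]


def two_pass_refine(
    init_state: Callable[[], Tuple[torch.Tensor, torch.Tensor]],
    build_pol_and_corr: Callable[[Optional[torch.Tensor]], Any],
    gru_step: GruStep,
    iters_pass1: int = 6,
    iters_pass2: int = 6,
) -> torch.Tensor:
    """Sketch: two independent passes; Pass 2 aligns with Pass 1's output."""

    def run_pass(align_disp: Optional[torch.Tensor], iters: int) -> torch.Tensor:
        # Polarization features: unaligned if align_disp is None, else warped.
        context = build_pol_and_corr(align_disp)
        # Fresh (hidden, disparity) state: nothing carries over between passes.
        hidden, disp = init_state()
        for _ in range(iters):
            hidden, disp = gru_step(hidden, disp, context)
        return disp

    disparity_v1 = run_pass(None, iters_pass1)                    # Pass 1: no alignment
    disparity_v2 = run_pass(disparity_v1.detach(), iters_pass2)   # Pass 2: aligned
    return disparity_v2
```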
4. Data Flow
[Figure: forward_inference data flow]
[Figure: forward_two_pass data flow]
5. Tensor Dimensions
| Tensor | Dimensions / Type | Description |
|---|---|---|
| left, right | torch.Tensor, (B, C, H, W) | Input left/right images |
| disparity | torch.Tensor, (B, 1, H, W) | Disparity map |
| warped | Same as img | Warped right image (aligned to left view) |
| valid_mask | Optional[torch.Tensor], (B, 1, H, W) | Validity mask (0 outside boundary) |
| pol_update_iters | Optional[List[int]] | List of iteration indices at which polarization features are updated; default [iters//2] |
| Output disparity | torch.Tensor | Matches input resolution |
6. Hyperparameters
| Parameter | Default | Applies to | Description |
|---|---|---|---|
| iters | Same as training (e.g., 24) | forward_inference | Total GRU iterations |
| pol_update_iters | [iters//2] | forward_inference | When polarization features are updated (iteration index) |
| iters_pass1 | 6 | forward_two_pass | First-stage iterations |
| iters_pass2 | 6 | forward_two_pass | Second-stage iterations |
| return_valid_mask | True | warp_with_disparity | Whether to return the validity mask |
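One way to keep these defaults in a single place is a small config object. `PolInferenceConfig` is a hypothetical name, the default of 24 iterations is only the example value from the table, and the range check assumes 0-based iteration indices; the point of the sketch is how the `[iters//2]` default is derived from `iters` rather than hard-coded.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PolInferenceConfig:
    """Hypothetical container for the inference hyperparameters listed above."""

    iters: int = 24                               # e.g. same as training
    pol_update_iters: Optional[List[int]] = None  # None -> [iters // 2]
    iters_pass1: int = 6
    iters_pass2: int = 6
    return_valid_mask: bool = True

    def __post_init__(self) -> None:
        if self.pol_update_iters is None:
            self.pol_update_iters = [self.iters // 2]  # midpoint default
        # Every update index must fall inside [0, iters).
        bad = [i for i in self.pol_update_iters if not 0 <= i < self.iters]
        if bad:
            raise ValueError(f"pol_update_iters out of range [0, {self.iters}): {bad}")
        # Deduplicate and sort so progressive updates (e.g. [6, 12, 18]) run in order.
        self.pol_update_iters = sorted(set(self.pol_update_iters))
```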
Usage Recommendations
| Scenario | Recommended Method | Description |
|---|---|---|
| Training | forward(..., disparity_gt=gt) | Use GT disparity |
| Inference (real-time) | forward_inference(..., pol_update_iters=[6]) | Balance efficiency and accuracy |
| Inference (high precision) | forward_two_pass(...) | Highest accuracy |
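At inference time the two recommended paths might be selected as below. `estimate_disparity` and its `mode` strings are hypothetical; only the method names and keyword arguments come from the usage table. Callers would typically also put the model in `eval()` mode beforehand.

```python
from typing import Any

import torch


def estimate_disparity(model: Any, left: torch.Tensor, right: torch.Tensor,
                       mode: str = "realtime") -> torch.Tensor:
    """Hypothetical helper that picks an inference path per the usage table."""
    with torch.no_grad():  # inference only: no GT disparity, no gradients
        if mode == "realtime":
            # Single polarization update at iteration 6, as recommended for real time.
            return model.forward_inference(left, right, pol_update_iters=[6])
        if mode == "precision":
            # Full two-pass method with its default iters_pass1 / iters_pass2.
            return model.forward_two_pass(left, right)
    raise ValueError(f"unknown mode: {mode!r}")
```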
7. Design Decisions and Rationale
7.1 Why two inference methods
forward_two_pass is the most conceptually direct option (two fully independent stages) but costs 2× compute. forward_inference reduces the cost to about 1.1–1.2× by computing fnet once and keeping the GRU hidden state continuous, which makes it the recommended choice for real-time inference. Offering both lets the user pick “efficiency first” or “precision first.”
7.2 Why forward_inference keeps the GRU continuous
forward_two_pass restarts the GRU in Pass 2, discarding the hidden state accumulated in Pass 1. forward_inference does not restart the GRU; it only “swaps the polarization features and correlation volume” at designated iterations, letting the GRU keep refining on a continuous hidden state and avoiding recomputation of stereo features.
7.3 Why valid_mask is needed
Regions outside the boundary are filled with 0 after warping, so the “polarization difference” computed there is a fake signal. valid_mask marks the valid range: training excludes invalid-region loss and inference ignores invalid-region polarization features, preventing fake signals from contaminating results.
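Both uses of the mask can be sketched in a few lines. The helper names are hypothetical, an L1 disparity loss is used purely for illustration (the training loss is not specified here), and zeroing invalid pixels is only one simple way to "ignore" them; an implementation might instead feed the mask to the network as an extra channel.

```python
import torch


def masked_pol_difference(left: torch.Tensor, warped_right: torch.Tensor,
                          valid_mask: torch.Tensor) -> torch.Tensor:
    """|I_par - I_perp| on aligned pixels, zeroed where the warp left the image."""
    # valid_mask (B, 1, H, W) broadcasts over the channel dimension.
    return (left - warped_right).abs() * valid_mask


def masked_l1_loss(pred: torch.Tensor, target: torch.Tensor,
                   valid_mask: torch.Tensor) -> torch.Tensor:
    """Mean absolute error over valid entries only (illustrative loss choice)."""
    m = valid_mask.expand_as(pred)
    err = (pred - target).abs() * m
    # Normalize by the number of valid entries, not by the full tensor size;
    # clamp avoids division by zero when nothing is valid.
    return err.sum() / m.sum().clamp_min(1.0)
```

Normalizing by the valid count matters: dividing by the full tensor size would silently shrink the loss for image pairs with large disparities, where more columns are invalid.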
7.4 Why pol_update_iters defaults to the midpoint
The alignment quality of polarization features depends on the disparity-estimate quality. Updating too early (disparity is still coarse) gives poor alignment; updating too late leaves too few iterations to benefit from the aligned features. The default [iters//2] (midpoint) is a compromise between efficiency and accuracy. It can also be changed to multiple updates (e.g., [6, 12, 18]) for progressive alignment.
8. Highlights
- Self-alignment using estimated disparity: when no GT disparity is available at inference, a coarse disparity is first estimated from unaligned features and then used to warp and align the polarization features, bridging the alignment gap between training and inference.
- Continuous-GRU design for efficient inference: `forward_inference` does not restart the GRU and computes `fnet` only once, swapping the polarization features and the correlation volume only at designated iterations; this compresses the cost of two-stage alignment to about 1.1–1.2×.
- Twin efficiency/precision tracks: both an efficient and a full two-pass inference method are provided, so the user can choose based on real-time or precision needs.
- Validity mask against boundary artifacts: `warp_with_disparity` returns a `valid_mask` that marks invalid regions outside the warp boundary, preventing fake polarization differences from zero-filled areas from contaminating loss and inference results.
- Tunable polarization-update timing: `pol_update_iters` exposes the update timing as a hyperparameter, defaulting to the iteration midpoint as a compromise between alignment quality and refinement time, and supports multiple progressive updates.