1. Design Goals
Letting polarization information enter feature extraction is a valuable direction, but the choice of injection method directly determines whether pretrained weights can be preserved.
If fnet’s input is simply changed to a 6-channel concat, conv1’s input channel count changes from 3 to 6, so the pretrained conv1 weights, which were trained for a 3-channel input, no longer fit. Feature extraction then has to be relearned, which damages the base geometric capability.
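The shape mismatch described above can be checked directly. The following is a minimal sketch, assuming PyTorch; `pretrained_conv1` is only a stand-in for the real pretrained layer, and `try_load` is a hypothetical helper introduced for illustration.

```python
import torch.nn as nn

# Stand-in for the pretrained first layer (3-channel input); not real pretrained weights.
pretrained_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)

# Naive alternative: concatenate img and pol_diff into a 6-channel input.
concat_conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3)

print(tuple(pretrained_conv1.weight.shape))  # (64, 3, 7, 7)
print(tuple(concat_conv1.weight.shape))      # (64, 6, 7, 7)


def try_load(target: nn.Module, source: nn.Module) -> bool:
    """Hypothetical helper: True if source weights load into target, False on a shape mismatch."""
    try:
        target.load_state_dict(source.state_dict())
        return True
    except RuntimeError:
        # load_state_dict raises RuntimeError when tensor sizes differ.
        return False


loads_into_concat = try_load(concat_conv1, pretrained_conv1)
loads_into_same = try_load(
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3), pretrained_conv1
)
print(loads_into_concat, loads_into_same)
```

The 3-channel weights load into an identically shaped layer but are rejected by the 6-channel one, which is the incompatibility this design avoids.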
The design goal of this architecture is to keep the idea of letting polarization enter feature extraction while injecting it non-destructively. The concrete approach is to leave the main branch’s 3-channel input and pretrained conv1 untouched, and instead introduce an independent polarization side branch that adds polarization features through soft additive fusion.
The core design splits the structure into two branches:
- Main branch: original 3-channel image → conv1 (uses pretrained weights) → 64-channel features.
- Pol side branch: pol_diff (3 channels) → pol_conv1 (random init, learned independently) → 64-channel polarization features.
- The two are combined via soft additive fusion x + pol_scale * pol_feat, then jointly passed through the shared downstream layers layer1 → layer2 → layer3 → conv_out.
pol_scale (e.g. 0.1) is a small coefficient that lets the pretrained main branch dominate while the pol side branch only provides a small supplementary contribution.
The design principle this architecture follows: polarization does not enter the main input of fnet (which would break its pretrained structure); instead, it is fused additively through an independent branch, so polarization can influence feature matching without damaging existing geometric capability.
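The effect of pol_scale on the fused features can be illustrated numerically. This is a rough sketch, assuming PyTorch; the random tensors are stand-ins for real branch outputs, and `fuse` is a hypothetical helper that only restates the formula x + pol_scale * pol_feat.

```python
import torch

torch.manual_seed(0)

# Stand-ins for post-ReLU branch outputs of shape (B, 64, H/2, W/2); not real features.
x = torch.relu(torch.randn(1, 64, 32, 48))
pol_feat = torch.relu(torch.randn(1, 64, 32, 48))


def fuse(x: torch.Tensor, pol_feat: torch.Tensor, pol_scale: float) -> torch.Tensor:
    """Soft additive fusion as written in the text."""
    return x + pol_scale * pol_feat


relative_change = {}
for pol_scale in (0.0, 0.1, 1.0):
    fused = fuse(x, pol_feat, pol_scale)
    # How far the fused features move away from the main-branch features.
    relative_change[pol_scale] = ((fused - x).norm() / x.norm()).item()
    print(f"pol_scale={pol_scale}: relative change = {relative_change[pol_scale]:.3f}")
```

With pol_scale = 0 the fused tensor is exactly the main-branch output, and the deviation grows linearly with pol_scale. Because pol_feat enters through a multiplication by pol_scale, the gradient flowing back into the pol branch is scaled by the same factor.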
2. Architecture (Data Flow)
The left and right images each call FeatureEncoderWithPolFusion once (with their respective img and the shared pol_diff), producing fmap1 / fmap2, which then feed into the downstream CorrBlock and UpdateBlock (original RAFT).
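The data flow above could be sketched as follows. This assumes PyTorch; `StubEncoder` is a hypothetical stand-in that only mimics the call signature of FeatureEncoderWithPolFusion, its stride-8 projections are an assumption chosen for illustration, and CorrBlock / UpdateBlock are not reproduced.

```python
import torch
import torch.nn as nn


class StubEncoder(nn.Module):
    """Hypothetical stand-in with the same forward(img, pol_diff) signature."""

    def __init__(self, output_dim: int = 128, pol_scale: float = 0.1):
        super().__init__()
        self.pol_scale = pol_scale
        # Assumption: a single stride-8 projection per branch, for shape illustration only.
        self.img_proj = nn.Conv2d(3, output_dim, kernel_size=8, stride=8)
        self.pol_proj = nn.Conv2d(3, output_dim, kernel_size=8, stride=8)

    def forward(self, img, pol_diff):
        return self.img_proj(img) + self.pol_scale * self.pol_proj(pol_diff)


def extract_feature_maps(encoder, img1, img2, pol_diff):
    """Call the same encoder (shared weights) once per image, with the shared pol_diff."""
    fmap1 = encoder(img1, pol_diff)
    fmap2 = encoder(img2, pol_diff)
    return fmap1, fmap2


encoder = StubEncoder()
img1 = torch.randn(1, 3, 64, 96)      # left image
img2 = torch.randn(1, 3, 64, 96)      # right image
pol_diff = torch.randn(1, 3, 64, 96)  # shared polarization difference

fmap1, fmap2 = extract_feature_maps(encoder, img1, img2, pol_diff)
print(tuple(fmap1.shape), tuple(fmap2.shape))
```

Both calls go through one module instance, so the two feature maps are produced by identical weights; fmap1 and fmap2 would then be handed to the correlation and update stages.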
3. Components and Modules
FeatureEncoderWithPolFusion
    import torch.nn as nn

    class FeatureEncoderWithPolFusion(nn.Module):
        def __init__(self, output_dim=128, pol_scale=0.1):
            super().__init__()
            self.pol_scale = pol_scale  # fusion coefficient used in forward()
            # Main branch (keeps pretrained)
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
            # Pol side branch (learned independently)
            self.pol_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
            # norm1/relu1, pol_norm1/pol_relu1 and conv_out are elided in this excerpt
            # Shared downstream layers (definitions elided)
            self.layer1, self.layer2, self.layer3 = ...

        def forward(self, img, pol_diff):
            x = self.relu1(self.norm1(self.conv1(img)))  # pretrained
            pol_feat = self.pol_relu1(self.pol_norm1(self.pol_conv1(pol_diff)))  # random init
            x = x + self.pol_scale * pol_feat  # soft additive fusion
            return self.conv_out(self.layer3(self.layer2(self.layer1(x))))
Item-by-item explanation:
- conv1: the first layer of the main branch, Conv2d(3, 64, kernel_size=7, stride=2, padding=3). The input remains 3 channels, so pretrained weights can be loaded in full.
- pol_conv1: the first layer of the pol side branch, same structure as conv1 (Conv2d(3, 64, kernel_size=7, stride=2, padding=3)), but with random init, learning the polarization representation independently.
- layer1/layer2/layer3: shared downstream layers that process the fused features.
- forward(img, pol_diff):
  - x = relu1(norm1(conv1(img))): main-branch features (pretrained).
  - pol_feat = pol_relu1(pol_norm1(pol_conv1(pol_diff))): pol-branch features (random init).
  - x = x + pol_scale * pol_feat: soft additive fusion; pol_scale=0.1 lets the pretrained branch dominate.
  - The fused x flows through layer1 → layer2 → layer3 → conv_out.
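To make the walkthrough above concrete, here is a runnable sketch of the encoder, assuming PyTorch. The class name `PolFusionEncoderSketch` is hypothetical, and two things are assumptions rather than the source's design: InstanceNorm2d as the norm layers, and single convolutions standing in for layer1 / layer2 / layer3.

```python
import torch
import torch.nn as nn


class PolFusionEncoderSketch(nn.Module):
    """Runnable stand-in for FeatureEncoderWithPolFusion (assumed norms and downstream layers)."""

    def __init__(self, output_dim: int = 128, pol_scale: float = 0.1):
        super().__init__()
        self.pol_scale = pol_scale
        # Main branch: 3-channel input, shape-compatible with a pretrained conv1.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.norm1 = nn.InstanceNorm2d(64)  # assumption
        self.relu1 = nn.ReLU()
        # Pol side branch: same structure, random init.
        self.pol_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.pol_norm1 = nn.InstanceNorm2d(64)  # assumption
        self.pol_relu1 = nn.ReLU()
        # Hypothetical shared downstream layers (placeholders, not the source's blocks).
        self.layer1 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(64, 96, 3, stride=2, padding=1), nn.ReLU())
        self.layer3 = nn.Sequential(nn.Conv2d(96, 128, 3, stride=2, padding=1), nn.ReLU())
        self.conv_out = nn.Conv2d(128, output_dim, kernel_size=1)

    def forward(self, img, pol_diff):
        x = self.relu1(self.norm1(self.conv1(img)))
        pol_feat = self.pol_relu1(self.pol_norm1(self.pol_conv1(pol_diff)))
        x = x + self.pol_scale * pol_feat  # soft additive fusion
        return self.conv_out(self.layer3(self.layer2(self.layer1(x))))


torch.manual_seed(0)
img = torch.randn(2, 3, 64, 96)
pol_a = torch.randn(2, 3, 64, 96)
pol_b = torch.randn(2, 3, 64, 96)

fmap = PolFusionEncoderSketch()(img, pol_a)
print(tuple(fmap.shape))

# With pol_scale=0 the output no longer depends on pol_diff.
ablated = PolFusionEncoderSketch(pol_scale=0.0).eval()
with torch.no_grad():
    pol_ignored = torch.equal(ablated(img, pol_a), ablated(img, pol_b))
print(pol_ignored)
```

Setting pol_scale to 0 gives a simple ablation: the encoder then behaves as the main branch alone, which is a convenient way to check that the pretrained path is intact.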
4. Design Principles
| Principle | Description |
|---|---|
| Fully preserve pretrained fnet.conv1 | Main-branch input stays 3-channel; geometric capability is not harmed |
| Independent pol_conv1 learns the polarization representation | The pol side branch is decoupled from the main branch; each learns its own representation |
| Soft additive fusion (pol_scale=0.1) | Addition + small coefficient lets the pretrained branch dominate while pol only supplements |
| Shared downstream layers (layer1/2/3) | The fused features share downstream layers, allowing polarization information to integrate gradually with network depth |
5. Tensor Dimensions
| Tensor | Dimensions | Description |
|---|---|---|
| img | (B, 3, H, W) | Original RGB image, main-branch input |
| pol_diff | (B, 3, H, W) | Polarization difference, pol-branch input |
| conv1 output x | (B, 64, H/2, W/2) | stride=2, main-branch features |
| pol_conv1 output pol_feat | (B, 64, H/2, W/2) | Pol-branch features |
| Fused x | (B, 64, H/2, W/2) | x + pol_scale * pol_feat |
| conv_out output fmap | Determined by output_dim=128 | Final features |
| pol_scale | scalar | Default 0.1 |
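The H/2 and W/2 entries follow from the standard convolution output-size formula with kernel_size=7, stride=2, padding=3. A small pure-Python check, where both function names are hypothetical helpers introduced for illustration:

```python
def conv2d_out_size(size: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length along one spatial axis for a convolution with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def conv1_output_shape(batch: int, height: int, width: int, out_channels: int = 64):
    """Shape produced by Conv2d(3, 64, kernel_size=7, stride=2, padding=3)."""
    h = conv2d_out_size(height, kernel_size=7, stride=2, padding=3)
    w = conv2d_out_size(width, kernel_size=7, stride=2, padding=3)
    return (batch, out_channels, h, w)


print(conv1_output_shape(2, 480, 640))  # (2, 64, 240, 320)
```

For even H and W the result is exactly H/2 and W/2; for odd sizes the formula rounds up (for example, 481 maps to 241). Since conv1 and pol_conv1 share the same settings, their outputs always have matching shapes and can be added elementwise.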
6. Hyperparameters
| Hyperparameter | Value | Description |
|---|---|---|
| pol_scale | 0.1 | Pol-branch fusion coefficient; lets pretrained branch dominate |
| pol_levels | 4 | Pol pyramid levels |
| pol_radius | 4 | Pol lookup radius |
| iters | 24 | GRU iterations |
| curriculum | enabled | Curriculum training schedule |
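One way to keep these values together is a small config object. This is only a sketch: the `PolFusionConfig` class is hypothetical and not part of the source; its field names and defaults mirror the table, with "enabled" represented as a boolean.

```python
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class PolFusionConfig:
    """Hypothetical container for the hyperparameters listed in the table."""

    pol_scale: float = 0.1   # pol-branch fusion coefficient
    pol_levels: int = 4      # pol pyramid levels
    pol_radius: int = 4      # pol lookup radius
    iters: int = 24          # GRU iterations
    curriculum: bool = True  # curriculum training schedule enabled


cfg = PolFusionConfig()
print(asdict(cfg))

# Example variation: switch the pol contribution off while keeping everything else.
ablation_cfg = replace(cfg, pol_scale=0.0)
print(ablation_cfg)
```

A frozen dataclass keeps the defaults immutable, so variations such as the pol_scale=0.0 ablation are created as explicit copies.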
7. Design Decisions and Rationale
| Decision | Rationale |
|---|---|
| Main branch keeps 3-channel input | Allows conv1 to fully load pretrained weights; preserves geometric capability |
| Add independent pol_conv1 (random init) | Decouples pol representation from the main branch; each learns independently without disturbing initialization |
| pol_conv1 has the same structure as conv1 | Both branch features share the same dimensionality (64ch) so they can be added directly |
| Soft additive fusion instead of concat | Concat would change conv1’s channel count; addition does not, and preserves the main-branch structure |
| Small coefficient pol_scale=0.1 | Lets pretrained branch dominate; pol only provides a small supplement to avoid destabilizing the system |
| Downstream layers layer1/2/3 shared | Polarization information integrates progressively with network depth; no need to re-inject at every layer |
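The first two decisions can be illustrated with a partial checkpoint load. This is a minimal sketch, assuming PyTorch; `FirstLayersSketch` is a hypothetical module holding only the two first layers, and `checkpoint` is a stand-in built from a freshly created conv layer, not a real pretrained file.

```python
import torch
import torch.nn as nn


class FirstLayersSketch(nn.Module):
    """Hypothetical module with just the two first-layer convolutions."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.pol_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)


# Stand-in for a pretrained checkpoint that only contains conv1 parameters.
pretrained_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
checkpoint = {f"conv1.{name}": tensor for name, tensor in pretrained_conv1.state_dict().items()}

model = FirstLayersSketch()
# strict=False tolerates parameters that the checkpoint does not provide.
result = model.load_state_dict(checkpoint, strict=False)

print(result.missing_keys)     # parameters left at their random init
print(result.unexpected_keys)  # checkpoint entries the model does not have
conv1_matches = torch.equal(model.conv1.weight, pretrained_conv1.weight)
print(conv1_matches)
```

Only the pol_conv1 parameters are reported as missing, so conv1 receives the checkpoint weights unchanged while the pol branch keeps its random initialization.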
8. Highlights
- Non-destructive early fusion: replaces channel concat with dual-branch soft additive fusion; the main branch’s conv1 fully retains pretrained weights, preserving geometric capability.
- Decoupling polarization and geometry: an independent pol_conv1 (random init) is dedicated to learning the polarization representation, without interfering with the pretrained main branch’s initialization.
- Controllable polarization strength: the small coefficient pol_scale=0.1 lets the pretrained branch dominate while polarization provides only a small supplement, balancing stability and polarization contribution.
- Progressive depth-wise fusion: after fusion the branches share layer1/2/3, allowing polarization information to integrate gradually with network depth without needing repeated injection at every layer.