Blueprint · 2026

S2M2 6-Channel Input Architecture

This document describes the S2M2 6-channel input architecture (`s2m2_6ch.py`, `CNNEncoder6ch`, `S2M2_6ch`, `load_pretrained_3ch`) and the 6ch fixes relative to the original S2M2. It focuses on the architectural design itself and does not cover experimental results or performance numbers.

  • stereo matching
  • polarization
  • RAFT-Stereo

Using these blueprints

Everything here is an architecture proposal I designed and chose to publish openly. Free to use, adapt, or build on — no permission needed.

If one turns out useful and crediting is convenient, a link back to this site is appreciated. It's never required.

1. Design Goals

To use the polarization pair [I∥, I⊥] in a stereo matching architecture, the most direct approach is to feed the pair into the input layer rather than attach a correction module or extra training stages outside the backbone. The 6-Channel Input architecture takes this route: it widens the Feature Pyramid's input layer to 6 channels and feeds I∥ and I⊥ together.

Design goals:

  • No extra correction module is required.
  • Polarization information is fused from the very first layer of feature extraction.
  • The model decides on its own how to use the polarization information.
  • The architecture is as concise as possible, with minimal changes to the original S2M2.

2. Architecture: 6-Channel Input

[Figure: S2M2 6-Channel input architecture]


3. Components and Modules

Implementation file: s2m2_6ch.py

| Component | Responsibility |
| --- | --- |
| CNNEncoder6ch | CNN encoder with conv0 modified to take a 6-channel input |
| S2M2_6ch | S2M2 variant that uses CNNEncoder6ch |
| load_pretrained_3ch() | Loads 3ch weights and expands them to 6ch (duplicate + /2) |

3.1 CNNEncoder6ch

CNNEncoder6ch changes the first convolution conv0 of the original CNNEncoder from a 3-channel input to a 6-channel input, so that the polarization pair [I∥, I⊥] is fused into feature extraction from the very first layer.
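As a rough illustration of the idea, the sketch below shows a first convolution that takes the stacked pair as a 6-channel input. The class name `FirstConv6ch`, the output width, and the kernel, stride, and padding values are placeholders chosen for the example; they are assumptions and not the actual CNNEncoder6ch definition.

```python
import torch
import torch.nn as nn


class FirstConv6ch(nn.Module):
    """Hypothetical stand-in for the first layer of CNNEncoder6ch."""

    def __init__(self, out_channels: int = 32):
        super().__init__()
        # Only in_channels (6 instead of 3) reflects the design described above.
        # Kernel size, stride, padding and width are placeholder values.
        self.conv0 = nn.Conv2d(6, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, i_par: torch.Tensor, i_perp: torch.Tensor) -> torch.Tensor:
        # Stack I∥ and I⊥ along the channel axis: two (B, 3, H, W) -> (B, 6, H, W).
        x = torch.cat([i_par, i_perp], dim=1)
        return self.conv0(x)


stem = FirstConv6ch()
i_par = torch.rand(2, 3, 16, 24)
i_perp = torch.rand(2, 3, 16, 24)
features = stem(i_par, i_perp)
print(features.shape)  # torch.Size([2, 32, 16, 24])
```

Because both images enter the same kernel, every output feature of conv0 can already mix I∥ and I⊥ at each pixel.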

3.2 S2M2_6ch

S2M2_6ch is the S2M2 variant that uses CNNEncoder6ch. The subsequent components (Unet / MRT / DispInit / Refiner / Upsampling, etc.) are taken from the original S2M2 without modification.

3.3 load_pretrained_3ch()

load_pretrained_3ch() is used to expand the original 3-channel pretrained weights to 6 channels:

  • Strategy: duplicate + /2 — copy the 3ch conv0 weights twice and concatenate them into 6ch, then divide the weights by 2.
  • Rationale: duplicating and dividing by 2 makes the 6ch conv0 output equivalent to the original 3ch conv0 under a duplicated [RGB, RGB] input, preserving pretrained properties as the training starting point.
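The equivalence follows from linearity: with halved, duplicated weights, (W/2)·x + (W/2)·x = W·x. A minimal sketch of the expansion is given below. The helper name `expand_conv0_3ch_to_6ch` is hypothetical, and copying the bias unchanged is an assumption of this sketch (the text above only specifies the weight handling); the real `load_pretrained_3ch()` may differ in how it reads and maps the checkpoint.

```python
import torch
import torch.nn as nn


def expand_conv0_3ch_to_6ch(conv3: nn.Conv2d) -> nn.Conv2d:
    """Build a 6-channel conv from a 3-channel one via duplicate + /2.

    Hypothetical helper; assumes default dilation, groups and padding_mode.
    """
    conv6 = nn.Conv2d(
        6,
        conv3.out_channels,
        kernel_size=conv3.kernel_size,
        stride=conv3.stride,
        padding=conv3.padding,
        bias=conv3.bias is not None,
    )
    with torch.no_grad():
        # Duplicate along the input-channel axis (dim=1), then halve.
        conv6.weight.copy_(torch.cat([conv3.weight, conv3.weight], dim=1) / 2)
        if conv3.bias is not None:
            # Assumption: the bias is added once per output, so it is copied as-is.
            conv6.bias.copy_(conv3.bias)
    return conv6


torch.manual_seed(0)
conv3 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv6 = expand_conv0_3ch_to_6ch(conv3)

# Under a duplicated [RGB, RGB] input the two layers should agree.
rgb = torch.rand(1, 3, 16, 16)
with torch.no_grad():
    out3 = conv3(rgb)
    out6 = conv6(torch.cat([rgb, rgb], dim=1))
max_diff = (out3 - out6).abs().max().item()  # expected to be at float rounding level
```

If the bias were also halved or duplicated, the outputs would no longer match, which is why this sketch leaves it untouched.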

4. 6ch Fixes Relative to the Original S2M2

Refactoring the original S2M2 into s2m2_6ch.py requires fixing the following two issues.

4.1 Fix 1: Transformer Call Convention

# Wrong: pass a list
self.transformer([left_py_4x, left_py_8x, ...], [right_py_4x, ...])

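# Correct: concatenate left and right along the batch dimension first;
# the 4 pyramid levels computed from imgs are then passed as separate arguments
imgs = torch.cat([left, right], dim=0)  # batch concat: (2B, 6, H, W)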
feature_tr_4x = self.transformer(feature_py_4x, feature_py_8x, feature_py_16x, feature_py_32x)

The Transformer does not accept a list. The left and right images should first be concatenated along the batch dimension, then the 4 pyramid levels (4x / 8x / 16x / 32x) should be passed in separately.
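The calling pattern can be tried out with toy stand-ins, as sketched below. `ToyPyramidTransformer` and `toy_pyramid` are hypothetical placeholders (average pooling instead of a real encoder, an identity instead of real attention), and splitting the result back into left and right with `chunk` is an assumption about what happens downstream. Only the batch concat and the four separate arguments come from the text above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyPyramidTransformer(nn.Module):
    """Hypothetical placeholder that takes four pyramid levels as separate arguments."""

    def forward(self, f4, f8, f16, f32):
        # A real module would attend within and across levels; this one returns the 4x level.
        return f4


def toy_pyramid(imgs: torch.Tensor):
    """Hypothetical encoder stand-in: average-pool to 1/4, 1/8, 1/16 and 1/32 resolution."""
    return [F.avg_pool2d(imgs, kernel_size=s) for s in (4, 8, 16, 32)]


left = torch.rand(2, 6, 64, 96)
right = torch.rand(2, 6, 64, 96)

# Left and right share one forward pass by stacking along the batch dimension.
imgs = torch.cat([left, right], dim=0)  # (2B, 6, H, W)
feature_py_4x, feature_py_8x, feature_py_16x, feature_py_32x = toy_pyramid(imgs)

transformer = ToyPyramidTransformer()
feature_tr_4x = transformer(feature_py_4x, feature_py_8x, feature_py_16x, feature_py_32x)

# Assumption: the two views are recovered by splitting the batch in half.
feature0_4x, feature1_4x = feature_tr_4x.chunk(2, dim=0)
```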

4.2 Fix 2: UpsampleMask1x Expects 3ch RGB

# Wrong: pass 6ch
filter_weights = self.upsample_mask_1x(disp_up, img0_nor, feature0_2x)

# Correct: pass only the first 3ch for edge detection
filter_weights = self.upsample_mask_1x(disp_up, img0_nor[:, :3], feature0_2x)

Design decision: edge-guided sharpening only needs the RGB of I∥; the polarization information is already inside the CNN features. UpsampleMask1x expects 3ch RGB for edge detection, so only the first 3 channels (I∥) are taken.
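The effect of the slice can be checked in isolation. In the sketch below, `rgb_branch` is a hypothetical stand-in for whatever 3-channel layer UpsampleMask1x applies to the image; the actual module is not reproduced here.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the image branch of UpsampleMask1x: a conv that expects 3 channels.
rgb_branch = nn.Conv2d(3, 4, kernel_size=3, padding=1)

img0_nor = torch.rand(2, 6, 32, 48)  # [I∥ (3ch), I⊥ (3ch)]

# Feeding all 6 channels into a 3-channel conv raises a RuntimeError.
try:
    rgb_branch(img0_nor)
    accepted_6ch = True
except RuntimeError:
    accepted_6ch = False

# Taking the first 3 channels keeps I∥ only and matches the expected input.
i_par = img0_nor[:, :3]
edges = rgb_branch(i_par)
```

The slice is a view on the same tensor, so no copy of the image is made.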


5. Tensor Dimensions

| Tensor | Shape | Description |
| --- | --- | --- |
| 6ch input | (B, 6, H, W) | [I∥ (3ch), I⊥ (3ch)] |
| Stage 1 input | (B, 6, H, W) | [RGB, RGB]; the two halves are identical |
| Stage 2 input | (B, 6, H, W) | [I∥, I⊥]; the two halves differ |
| conv0 input (CNNEncoder6ch) | 6 channels | original was 3 channels |
| Transformer input | 4 pyramid levels | feature_py_4x / 8x / 16x / 32x; left and right concatenated along batch |
| Image input of upsample_mask_1x | img0_nor[:, :3] | only the first 3ch (I∥) for edge detection |
| pretrained 3ch conv0 weights | 3 channels | expanded to 6 channels via duplicate + /2 |
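The Stage 1 and Stage 2 rows differ only in what fills the second half of the channel axis. A small sketch of building both inputs follows; the helper names `make_stage1_input` and `make_stage2_input` are hypothetical and do not come from s2m2_6ch.py.

```python
import torch


def make_stage1_input(rgb: torch.Tensor) -> torch.Tensor:
    """Stage 1: duplicate the RGB image so both halves are identical."""
    return torch.cat([rgb, rgb], dim=1)  # (B, 3, H, W) -> (B, 6, H, W)


def make_stage2_input(i_par: torch.Tensor, i_perp: torch.Tensor) -> torch.Tensor:
    """Stage 2: stack the polarization pair so the two halves differ."""
    return torch.cat([i_par, i_perp], dim=1)  # two (B, 3, H, W) -> (B, 6, H, W)


rgb = torch.rand(2, 3, 32, 48)
i_par = torch.rand(2, 3, 32, 48)
i_perp = torch.rand(2, 3, 32, 48)

x1 = make_stage1_input(rgb)
x2 = make_stage2_input(i_par, i_perp)

# Same shape in both stages; only the relation between the halves changes.
halves_identical_stage1 = torch.equal(x1[:, :3], x1[:, 3:])
halves_identical_stage2 = torch.equal(x2[:, :3], x2[:, 3:])
```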

6. Polarization Injection Points

Injection point: the CNNEncoder input layer (injection point A).

The polarization pair [I∥, I⊥] is fused at the first convolution conv0 of CNNEncoder6ch, so polarization information is carried from the very start of feature extraction. Subsequently:

  • MRT cross-attention is left to learn the polarization difference between the two halves of the input (I∥ and I⊥).
  • The model decides on its own how to use the polarization information; no extra correction module is needed.

7. Design Decisions and Rationale

| Decision | Rationale |
| --- | --- |
| Use 6-channel input | Most concise; no extra correction module; polarization is fused from the first layer |
| Modify only conv0 (CNNEncoder6ch) | Minimal change; the subsequent Unet/MRT/Refiner are reused from the original |
| load_pretrained_3ch uses duplicate + /2 | Equivalent to the original 3ch under a duplicated [RGB, RGB] input, preserving pretrained properties |
| Stage 1 uses duplicated [RGB, RGB] input | The model first learns basic stereo matching; both halves are identical |
| Stage 2 uses [I∥, I⊥] | The model sees two different halves for the first time; MRT cross-attention learns the polarization difference |
| Transformer changed to take 4 pyramid levels | The original Transformer does not accept a list; batch concat is required and levels must be passed in separately |
| UpsampleMask1x only takes the first 3ch (I∥) | Edge-guided sharpening only needs I∥ RGB; the polarization information is already in the CNN features |

8. Highlights

  • Most concise polarization injection: changing the single layer conv0 is enough to fuse the polarization pair from the start of feature extraction; no correction module or extra training stage outside the backbone is needed.
  • Weight-equivalent pretrained transfer: load_pretrained_3ch uses duplicate + /2 so that under a duplicated [RGB, RGB] input the 6ch conv0 is equivalent to the original 3ch, so the pretrained features carry over unchanged at initialization.
  • Model decides how to use polarization: once polarization is fused from the first layer, MRT cross-attention is left to learn the difference between the two halves of the input (I∥ and I⊥); no handcrafted usage is needed.
  • Stage 1 and Stage 2 share the same input format: both stages use (B, 6, H, W); the only difference is whether the two halves are identical, allowing curriculum training to proceed seamlessly on the same architecture.
