Blueprint · 2026

Context-Modulated FNet Architecture

This document is the standalone specification of the "Context-Modulated FNet" stereo matching architecture. It describes the architecture itself (design goals, data flow, components and code, tensor dimensions, design decisions, polarization injection points, and Knowledge Transfer).

  • stereo matching
  • polarization
  • RAFT-Stereo

Using these blueprints

Everything here is an architecture proposal I designed and chose to publish openly. Free to use, adapt, or build on — no permission needed.

If one turns out useful and crediting is convenient, a link back to this site is appreciated. It's never required.

1. Design Goals

Building on the Dual Volume design (complementary Cost and Pol Volumes) and two-stage training (Pretrain → Finetune), this architecture provides two core mechanisms:

  • Context-Modulated FNet (FiLM): lets the context (which carries the glass information produced by the Pol CNet) modulate the FNet’s feature extraction via FiLM, so that the Correlation Volume becomes meaningful even on glass regions.
  • Knowledge Transfer fix: changes the Pol CNet’s pretrain input from [GT_mask, zeros] to [soft_mask, edges], whose statistical distribution is much closer to the finetune input, enabling effective transfer of weights from Pretrain to Finetune.

The core problem is that photometric inconsistency on glass distorts the features the FNet extracts, so the downstream Correlation Volume produces garbage signals on those regions. This architecture intervenes during feature extraction: modulated by the context, the FNet extracts features that remain meaningful for matching even on glass.

A second problem concerns training: if the statistical distributions of the Pol CNet inputs in Pretrain and Finetune differ too much, the weights learned in Pretrain cannot transfer effectively. The fix is therefore to make the Pretrain input format close to the Finetune format from the start.

Furthermore, the timing of the Pol Volume computation is explicitly defined as Phase 0 (the earliest), ensuring that it reflects the raw polarization difference and is not contaminated by context modulation.


2. Architecture (with Data Flow)

[Figure: Context-Modulated FNet five-phase architecture]

Five-Phase Data Flow

Phase | Name | Highlights
Phase 0 | Pol Volume computed first | From raw images: `pol = 1 - |L - R|` (normalized)
Phase 1 | Context Encoding | RGB CNet + Pol CNet → Context Fusion → context
Phase 2 | Context-Modulated Feature Extraction | FiLM: features' = γ * features + β
Phase 3 | Correlation Volume | Build CorrBlock from the modulated fmap
Phase 4 | GRU iteration (Dual Volume) | Dual-volume lookup of corr_feat + pol_feat

3. Components and Modules (with Code)

3.1 PolVolumeBlock

Design rationale:

  • The Pol Volume must be computed first, since it derives from the raw images (raw polarization difference).
  • It cannot be built after the FNet, otherwise it would be affected by context modulation.
  • The semantic design is aligned with that of the Correlation Volume (high value = good).

Core code:

import torch
import torch.nn.functional as F


class PolVolumeBlock:
    """
    Pol Volume computation and lookup

    Design considerations (semantically aligned with the Correlation Volume):
    - Correlation Volume: dot product; high value = similar features = good match
    - Pol Volume: polarization consistency; high value = polarization consistent
      (non-glass), low value = large polarization difference (glass)

    Benefits of this design:
    - Non-glass region: Corr high + Pol high -> both signals agree, GRU uses
      them confidently
    - Glass region: Corr low (garbage) + Pol low -> both signals say
      "something is wrong here"
    - The GRU can simply concat the two volumes without any extra fusion logic
    """

    def __init__(self, img_left: torch.Tensor, img_right: torch.Tensor,
                 num_levels: int = 4, radius: int = 4):
        B, C, H, W = img_left.shape

        # Convert to grayscale
        if C == 3:
            weights = torch.tensor([0.299, 0.587, 0.114], ...)
            left_gray = (img_left * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
            right_gray = (img_right * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)
        else:
            left_gray = img_left
            right_gray = img_right

        # Downsample to 1/4 (aligned with the fmap resolution of CorrBlock)
        self.left_small = F.avg_pool2d(left_gray, kernel_size=4, stride=4)
        self.right_small = F.avg_pool2d(right_gray, kernel_size=4, stride=4)

        # Record maximum intensity for normalization
        self.max_intensity = max(self.left_small.max(), self.right_small.max())
        self.radius = radius

    def __call__(self, disp: torch.Tensor) -> torch.Tensor:
        """
        Lookup the Pol Volume given the current disparity estimate.

        Returns:
            pol_feat: [B, 2*radius+1, H/4, W/4] polarization consistency feature
        """
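        B, _, H, W = self.left_small.shape
        # disp is already in 1/4-resolution pixels: Phase 4 initializes it as
        # [B, 1, H/4, W/4] and upsamples it with a x4 factor, so it must not
        # be divided by 4 again here
        disp_small = disp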

        pol_samples = []
        dx = torch.linspace(-self.radius, self.radius, 2 * self.radius + 1)

        for offset in dx:
            # Compute sampling position
            disp_offset = disp_small + offset
            grid_x = ...  # omitted; see full code
            grid = torch.stack([grid_x, grid_y], dim=-1)

            # Sample right image
            right_sampled = F.grid_sample(self.right_small, grid,
                                          mode='bilinear',
                                          padding_mode='zeros',
                                          align_corners=True)

            # Polarization difference
            pol_diff = torch.abs(self.left_small - right_sampled)

            # Normalize to [0, 1]
            pol_diff_norm = pol_diff / (self.max_intensity + 1e-6)

            # Invert the semantics: high value = polarization consistent
            # (aligned with Correlation semantics)
            pol_consistency = 1.0 - pol_diff_norm

            pol_samples.append(pol_consistency)

        # Concatenate along channels: [B, 2*radius+1, H/4, W/4]
        return torch.cat(pol_samples, dim=1)

Dimension alignment:

Volume | Input resolution | Output channels | Semantics
CorrBlock | H/4 × W/4 | 2×radius+1 = 9 | High = good match
PolVolumeBlock | H/4 × W/4 | 2×radius+1 = 9 | High = polarization consistent
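PolVolumeBlock leaves the sampling-grid computation out (`grid_x = ...`). The following is a minimal functional sketch of the same lookup so the "high = consistent" semantics can be checked end to end. It assumes (1) positive disparity, with left pixel x matching right pixel x − d, and (2) disp expressed in 1/4-resolution pixels, the units of the Phase 4 loop. The function name pol_lookup is hypothetical.

```python
import torch
import torch.nn.functional as F


def pol_lookup(left_small: torch.Tensor, right_small: torch.Tensor,
               disp: torch.Tensor, radius: int = 4) -> torch.Tensor:
    """Sketch of the Pol Volume lookup at 1/4 resolution.

    left_small, right_small: [B, 1, h, w] grayscale images
    disp: [B, 1, h, w] disparity in 1/4-resolution pixels
    Returns [B, 2*radius+1, h, w]; 1.0 means perfectly consistent.
    """
    B, _, h, w = left_small.shape
    max_intensity = torch.maximum(left_small.max(), right_small.max())

    # Pixel coordinates of the left image
    xs = torch.arange(w, dtype=disp.dtype, device=disp.device).view(1, 1, w)
    ys = torch.arange(h, dtype=disp.dtype, device=disp.device).view(1, h, 1)
    # Normalize y to [-1, 1] (align_corners=True convention); rows are not shifted
    grid_y = (2.0 * ys / max(h - 1, 1) - 1.0).expand(B, h, w)

    samples = []
    for offset in range(-radius, radius + 1):
        # Assumed convention: left pixel x matches right pixel x - d
        x_right = xs - (disp[:, 0] + offset)                  # [B, h, w]
        grid_x = 2.0 * x_right / max(w - 1, 1) - 1.0
        grid = torch.stack([grid_x, grid_y], dim=-1)          # [B, h, w, 2]
        right_sampled = F.grid_sample(right_small, grid, mode='bilinear',
                                      padding_mode='zeros', align_corners=True)
        pol_diff = (left_small - right_sampled).abs() / (max_intensity + 1e-6)
        samples.append(1.0 - pol_diff)                        # high = consistent
    return torch.cat(samples, dim=1)


# Synthetic check: the right view is the left view shifted by 2 pixels
torch.manual_seed(0)
left = torch.rand(1, 1, 8, 16)
right = torch.zeros_like(left)
right[..., :-2] = left[..., 2:]           # right[x] = left[x + 2]
disp = torch.full((1, 1, 8, 16), 2.0)
pol_feat = pol_lookup(left, right, disp, radius=4)
print(pol_feat.shape)                      # torch.Size([1, 9, 8, 16])
print(pol_feat[0, 4, :, 2:].min().item())  # close to 1.0
```

At the true disparity, the offset-0 channel (index `radius`) is ≈ 1 wherever the match stays in view, and every channel stays in [0, 1] for non-negative images, which is what lets the GRU concatenate it with the correlation lookup.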

3.2 UpdateBlock

Structure:

  • Includes a pol_dim parameter.
  • motion_input is [corr_feat, pol_feat, disp].

Core code:

import torch
import torch.nn as nn


class UpdateBlock(nn.Module):
    """
    GRU update module

    Inputs:
    - hidden: GRU hidden state [B, hidden_dim, H/4, W/4]
    - context: fused context [B, context_dim, H/4, W/4]
    - corr_feat: Correlation Volume lookup [B, corr_dim, H/4, W/4]
    - pol_feat: Pol Volume lookup [B, pol_dim, H/4, W/4]
    - disp: current disparity estimate [B, 1, H/4, W/4]
    """

    def __init__(self, hidden_dim: int = 128, context_dim: int = 64,
                 corr_dim: int = 9, pol_dim: int = 9):
        super().__init__()

        # Motion encoder: corr + pol + disp
        input_dim = corr_dim + pol_dim + 1  # 9 + 9 + 1 = 19
        self.motion_encoder = nn.Sequential(
            nn.Conv2d(input_dim, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, 3, padding=1),
        )

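        # GRU (ConvGRU is not defined in this document; a RAFT-style
        # convolutional GRU with forward(h, x) is assumed)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=64 + context_dim)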

        # Disparity head
        self.disp_head = nn.Sequential(
            nn.Conv2d(hidden_dim, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, 3, padding=1),
        )

    def forward(self, hidden, context, corr_feat, pol_feat, disp):
        # Dual Volume fusion (concat)
        motion_input = torch.cat([corr_feat, pol_feat, disp], dim=1)
        motion_features = self.motion_encoder(motion_input)

        # GRU update
        gru_input = torch.cat([motion_features, context], dim=1)
        hidden = self.gru(hidden, gru_input)

        # Predict disparity residual
        delta_disp = self.disp_head(hidden)

        return hidden, delta_disp
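UpdateBlock constructs `ConvGRU(hidden_dim=..., input_dim=...)` and calls `self.gru(hidden, gru_input)`, but ConvGRU is not defined here. A minimal sketch, assuming the single-scale convolutional GRU from RAFT, whose constructor and `forward(h, x)` match those calls (RAFT-Stereo's multi-level GRU takes extra context inputs and is not reproduced):

```python
import torch
import torch.nn as nn


class ConvGRU(nn.Module):
    """RAFT-style convolutional GRU (sketch): h' = (1 - z) * h + z * q."""

    def __init__(self, hidden_dim: int = 128, input_dim: int = 128):
        super().__init__()
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)

    def forward(self, h: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))                          # update gate
        r = torch.sigmoid(self.convr(hx))                          # reset gate
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))   # candidate state
        return (1 - z) * h + z * q


# Matches UpdateBlock defaults: hidden_dim = 128, input_dim = 64 + context_dim = 128
gru = ConvGRU(hidden_dim=128, input_dim=64 + 64)
hidden = torch.zeros(2, 128, 8, 16)
gru_input = torch.randn(2, 128, 8, 16)
hidden = gru(hidden, gru_input)
print(hidden.shape)  # torch.Size([2, 128, 8, 16])
```

Because the new state is a convex mix of the old state and a tanh candidate, a hidden state that starts in [−1, 1] stays there across iterations.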

3.3 ContextModulatedFNet (FiLM)

Design rationale:

  • The context knows “where the glass is” (information from the Pol CNet).
  • FiLM lets the FNet adjust its feature extraction according to the context.
  • Different features are extracted over glass regions, making the Correlation Volume meaningful there too.

Core code:

import torch
import torch.nn as nn


class ContextModulatedFNet(nn.Module):
    """
    Context-Modulated Feature Network

    Uses FiLM (Feature-wise Linear Modulation) to let the context modulate
    feature extraction:
        modulated = gamma * features + beta

    where gamma and beta are generated from the context.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 128,
                 context_dim: int = 64):
        super().__init__()

        # Backbone: extract base features
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, 64, 7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels, 3, padding=1),
        )

        # FiLM generator: context -> (gamma, beta)
        self.film_generator = nn.Sequential(
            nn.Conv2d(context_dim, 128, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels * 2, 1),  # gamma + beta
        )

    def forward(self, img: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Base features
        features = self.backbone(img)  # [B, C, H/4, W/4]

        # FiLM parameters
        film_params = self.film_generator(context)  # [B, 2C, H/4, W/4]
        gamma, beta = film_params.chunk(2, dim=1)

        # FiLM modulation
        modulated = gamma * features + beta

        return modulated

Effect of FiLM modulation (intended behavior):

Region | Context signal | FiLM effect
Non-glass | pol_ctx ≈ 0 | γ ≈ 1, β ≈ 0 → original feature preserved
Glass | pol_ctx high | γ, β adjust → feature modulated
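As written, film_generator outputs γ directly, so the non-glass identity case (γ ≈ 1, β ≈ 0) is not guaranteed by construction; the network has to learn it. A common alternative, shown as a hedged sketch and not as the author's design, predicts Δγ, uses γ = 1 + Δγ, and zero-initializes the last generator layer so the FNet starts out unmodulated. The class name ResidualFiLM is hypothetical.

```python
import torch
import torch.nn as nn


class ResidualFiLM(nn.Module):
    """FiLM with gamma = 1 + dgamma; starts as the identity (hypothetical variant)."""

    def __init__(self, context_dim: int = 64, feat_dim: int = 128):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Conv2d(context_dim, 128, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, feat_dim * 2, 1),  # dgamma + beta
        )
        # Zero-init the last layer: dgamma = beta = 0 before any training
        nn.init.zeros_(self.generator[-1].weight)
        nn.init.zeros_(self.generator[-1].bias)

    def forward(self, features: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        dgamma, beta = self.generator(context).chunk(2, dim=1)
        return (1.0 + dgamma) * features + beta


torch.manual_seed(0)
film = ResidualFiLM()
features = torch.randn(2, 128, 8, 16)
context = torch.randn(2, 64, 8, 16)
out = film(features, context)
print(torch.allclose(out, features))  # True: identity at initialization
out.sum().backward()
print(film.generator[-1].weight.grad.abs().sum().item() > 0)  # True: last layer still gets gradients
```

The zero-initialized layer still receives gradients (they depend on the features and the hidden activations), so glass-specific modulation is learned starting from the identity rather than from random scaling.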

3.4 ContextModulatedStereo.forward()

Core code:

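import torch
import torch.nn as nn
import torch.nn.functional as F


class ContextModulatedStereo(nn.Module):
    # __init__ (not shown) defines rgb_cnet, pol_cnet, context_fusion, fnet,
    # update_block, default_iters, pol_radius, corr_levels and corr_radius.
    # PolVolumeBlock is defined in 3.1; CorrBlock is a RAFT-style
    # correlation block not shown here.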
    def forward(self, left, right, pol_input, iters=None, test_mode=False):
        if iters is None:
            iters = self.default_iters

        # ============================================================
        # Phase 0: Pol Volume (earliest! from raw images)
        # ============================================================
        pol_block = PolVolumeBlock(left, right, radius=self.pol_radius)

        # ============================================================
        # Phase 1: Context Encoding
        # ============================================================
        # RGB Context
        rgb_ctx, hidden = self.rgb_cnet(left)

        # Pol Context (input depends on the training stage)
        pol_ctx = self.pol_cnet(pol_input)

        # Context Fusion
        context = self.context_fusion(rgb_ctx, pol_ctx)

        # ============================================================
        # Phase 2: Context-Modulated Feature Extraction
        # ============================================================
        fmap_left = self.fnet(left, context)    # FiLM modulation!
        fmap_right = self.fnet(right, context)  # same context

        # ============================================================
        # Phase 3: Correlation Volume
        # ============================================================
        corr_block = CorrBlock(fmap_left, fmap_right,
                               num_levels=self.corr_levels,
                               radius=self.corr_radius)

        # ============================================================
        # Phase 4: GRU iteration (Dual Volume)
        # ============================================================
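        # Assumes H and W are divisible by 4, so every 1/4-resolution tensor
        # (pol_block, fmaps, context, disp) has the same H/4 x W/4 size
        B, _, H, W = left.shape
        disp = torch.zeros(B, 1, H // 4, W // 4, device=left.device)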

        disp_predictions = []

        for _ in range(iters):
            # Dual Volume lookup
            corr_feat = corr_block(disp)   # [B, 9, H/4, W/4]
            pol_feat = pol_block(disp)     # [B, 9, H/4, W/4]

            # GRU update
            hidden, delta_disp = self.update_block(
                hidden, context, corr_feat, pol_feat, disp
            )

            # Update disparity
            disp = disp + delta_disp

            # Upsample to the original resolution
            disp_up = 4 * F.interpolate(
                disp, size=(H, W), mode='bilinear', align_corners=True
            )
            disp_predictions.append(disp_up)

        if test_mode:
            return disp_predictions[-1]
        return disp_predictions

4. Knowledge Transfer: Pretrain → Finetune

Problem Analysis

If the Pol CNet inputs of the two stages take the following form:

Pretrain:  pol_input = [GT_mask (binary 0/1), zeros]
Finetune:  pol_input = [pol_max (continuous 0~1), pol_var (continuous)]

Problem: the statistical distributions differ too much.

  • GT_mask: binary, sharp edges.
  • pol_max: continuous, with smooth gradients.
  • zeros vs pol_var: entirely different signals.

Consequence: weights the Pol CNet learns in Pretrain cannot transfer effectively to Finetune.

Solution: make prepare_pol_input_pretrain output a format whose statistical distribution is closer to Finetune.

Core Code

import torch
import torch.nn.functional as F


def prepare_pol_input_pretrain(glass_mask: torch.Tensor,
                               training: bool = True) -> torch.Tensor:
    """
    Pol CNet input for the Pretrain stage

    Args:
        glass_mask: [B, 1, H, W] GT mask (1 = glass, 0 = background)
        training: add noise only in training; the blur is applied in both modes

    Design considerations (to ensure Pretrain -> Finetune knowledge transfer):
    - Channel 0: softened GT_mask, mimicking the continuous nature of pol_max
    - Channel 1: edge information, mimicking pol_var (variance is high at edges)

    This way, the feature extraction the Pol CNet learns can transfer to
    the Finetune stage.
    """
    device = glass_mask.device
    glass_mask = glass_mask.float()

    # === Channel 0: softened mask (mimics the continuity of pol_max) ===
    ch0 = glass_mask
    if training:
        # Add noise (training only)
        noise = torch.randn_like(ch0) * 0.05
        ch0 = (ch0 + noise).clamp(0, 1)

    # Slight blur (3x3 avg pool with replicate padding), applied in both modes
    # so that validation inputs stay continuous like the Finetune input
    ch0 = F.avg_pool2d(
        F.pad(ch0, (1, 1, 1, 1), mode='replicate'),
        kernel_size=3, stride=1
    )

    # === Channel 1: edge detection (mimics pol_var) ===
    # pol_var is typically high at glass edges (different disparities
    # produce different polarization differences)

    # Sobel kernels
    sobel_x = torch.tensor([[-1, 0, 1],
                            [-2, 0, 2],
                            [-1, 0, 1]], dtype=torch.float32, device=device)
    sobel_y = torch.tensor([[-1, -2, -1],
                            [ 0,  0,  0],
                            [ 1,  2,  1]], dtype=torch.float32, device=device)

    sobel_x = sobel_x.view(1, 1, 3, 3)
    sobel_y = sobel_y.view(1, 1, 3, 3)

    # Edge gradients
    gx = F.conv2d(glass_mask, sobel_x, padding=1)
    gy = F.conv2d(glass_mask, sobel_y, padding=1)

    # Gradient magnitude. No epsilon inside the sqrt: with sqrt(gx^2 + gy^2 + eps),
    # flat pixels get sqrt(eps) > 0 and an all-background mask would normalize
    # to ~1 everywhere instead of 0
    ch1 = torch.hypot(gx, gy)

    # Normalize to [0, 1] per sample (an empty mask stays all-zero)
    ch1_max = ch1.amax(dim=(2, 3), keepdim=True)
    ch1 = ch1 / ch1_max.clamp_min(1e-6)

    # Combine: [B, 2, H, W]
    return torch.cat([ch0, ch1], dim=1)
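The edge channel (Channel 1) can be checked in isolation. The sketch below, with the hypothetical helper name edge_channel, applies the same Sobel kernels but takes the magnitude with torch.hypot and clamps the normalizer. The reason is numerical: with sqrt(gx² + gy² + 1e-6), every flat pixel gets magnitude 1e-3, so an all-background mask normalizes to ≈ 0.999 everywhere instead of 0, feeding the Pol CNet a strong false "edge" signal on glass-free crops.

```python
import torch
import torch.nn.functional as F


def edge_channel(mask: torch.Tensor) -> torch.Tensor:
    """Sobel edge magnitude of a [B, 1, H, W] mask, normalized per sample to [0, 1]."""
    sobel_x = torch.tensor([[-1.0, 0.0, 1.0],
                            [-2.0, 0.0, 2.0],
                            [-1.0, 0.0, 1.0]], device=mask.device).view(1, 1, 3, 3)
    sobel_y = sobel_x.transpose(2, 3).contiguous()  # [[-1,-2,-1],[0,0,0],[1,2,1]]
    gx = F.conv2d(mask, sobel_x, padding=1)
    gy = F.conv2d(mask, sobel_y, padding=1)
    mag = torch.hypot(gx, gy)                        # exactly 0 on flat regions
    peak = mag.amax(dim=(2, 3), keepdim=True)
    return mag / peak.clamp_min(1e-6)                # empty masks stay all-zero


empty = torch.zeros(1, 1, 16, 16)
print(edge_channel(empty).abs().max().item())  # 0.0

# Epsilon-inside-sqrt variant on the same empty mask, for contrast
eps_mag = torch.sqrt(torch.zeros_like(empty) ** 2 + 1e-6)
print((eps_mag / (eps_mag.amax(dim=(2, 3), keepdim=True) + 1e-6)).min().item())  # ≈ 0.999

square = torch.zeros(1, 1, 16, 16)
square[..., 4:12, 4:12] = 1.0
edges = edge_channel(square)  # peaks at 1.0 on the square's border, 0 inside and outside
```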

Comparison Table

Stage | Channel 0 | Channel 1 | Statistics
Pretrain (binary input) | GT_mask (binary) | zeros | Discrete, no signal
Pretrain (continuous input) | soft_mask (continuous + noise) | edges (continuous) | Continuous, structured
Finetune | pol_max (continuous) | pol_var (continuous) | Continuous, structured

Visualization

Pretrain (binary):    Pretrain (continuous): Finetune:
┌───────────────┐    ┌───────────────┐    ┌───────────────┐
│ ch0: ██████   │    │ ch0: ▓▓▓▓▓▓   │    │ ch0: ▒▒▒▒▒▒   │
│      ██████   │    │      ▓▓▓▓▓▓   │    │      ▓▓▓▓▓▓   │
│      (binary) │    │  (soft+noise) │    │   (pol_max)   │
├───────────────┤    ├───────────────┤    ├───────────────┤
│ ch1: 000000   │    │ ch1: ░░████░░ │    │ ch1: ░░▓▓▓░░  │
│      000000   │    │      (edges)  │    │   (pol_var)   │
│    (zeros)    │    │               │    │               │
└───────────────┘    └───────────────┘    └───────────────┘
      ↓ transfer fails     ↓ transferable ✓

Corresponding Setting in the Training Script

_get_pol_input gains a training flag:

def _get_pol_input(self, batch, training=True):
    """
    Prepare the Pol CNet input depending on the training stage.

    Args:
        batch: dict containing 'glass_mask', 'left', 'right'
        training: whether in training mode (affects noise in pretrain)
    """
    if self.args.stage == 'pretrain':
        # Pretrain: use the (softened) GT mask
        pol_input = prepare_pol_input_pretrain(
            batch['glass_mask'],
            training=training  # add noise during training, none for validation
        )
    else:
        # Finetune: use Pol Volume statistics
        pol_input = prepare_pol_input_finetune(
            batch['left'],
            batch['right'],
            max_disp=self.args.max_disp
        )
    return pol_input
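_get_pol_input calls prepare_pol_input_finetune, which this document does not define. One plausible reading of "Pol Volume statistics", consistent with the pol_var comment in prepare_pol_input_pretrain (different disparities produce different polarization differences), is the per-pixel maximum and variance of the normalized difference across candidate disparities. The sketch below implements only that assumption; the name pol_stats_sketch, the integer-disparity shifts, the grayscale inputs, and the full-resolution output are all hypothetical choices.

```python
import torch


def pol_stats_sketch(left_gray: torch.Tensor, right_gray: torch.Tensor,
                     max_disp: int) -> torch.Tensor:
    """Hypothetical [pol_max, pol_var] input (NOT the source's definition).

    left_gray, right_gray: [B, 1, H, W] non-negative single-channel images
    Assumed convention: left pixel x matches right pixel x - d, d >= 0.
    Returns [B, 2, H, W]: per-pixel max and variance, over integer
    d in [0, max_disp], of the normalized difference |L(x) - R(x - d)|.
    """
    W = left_gray.shape[-1]
    scale = torch.maximum(left_gray.max(), right_gray.max()) + 1e-6
    diffs = []
    for d in range(min(max_disp, W - 1) + 1):
        shifted = torch.zeros_like(right_gray)       # pixels with x - d < 0 read as 0
        shifted[..., d:] = right_gray[..., :W - d]   # shifted[..., x] = R[..., x - d]
        diffs.append((left_gray - shifted).abs() / scale)
    vol = torch.cat(diffs, dim=1)                     # [B, D+1, H, W], values in [0, 1]
    pol_max = vol.amax(dim=1, keepdim=True)
    mean = vol.mean(dim=1, keepdim=True)
    pol_var = ((vol - mean) ** 2).mean(dim=1, keepdim=True)
    return torch.cat([pol_max, pol_var], dim=1)


torch.manual_seed(0)
left = torch.rand(2, 1, 8, 32)
right = torch.rand(2, 1, 8, 32)
pol_input = pol_stats_sketch(left, right, max_disp=16)
print(pol_input.shape)  # torch.Size([2, 2, 8, 32])
```

Whatever the real definition is, the property the Knowledge Transfer fix relies on holds for this reading: both channels are continuous, with pol_max in [0, 1] and pol_var non-negative.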

5. Tensor Dimensions

Tensor | Dimensions | Description
left / right | [B, 3, H, W] | Raw polarization image pair
left_small / right_small | [B, 1, H/4, W/4] | PolVolumeBlock internal: grayscale + avg_pool2d(k=4, s=4)
pol_feat (pol_block lookup) | [B, 2×radius+1, H/4, W/4] = [B, 9, H/4, W/4] | radius = 4
pol_input | [B, 2, ·, ·] | 2 channels (Pretrain: [soft_mask, edges]; Finetune: [pol_max, pol_var])
rgb_ctx / hidden | rgb_cnet output | hidden: [B, hidden_dim=128, H/4, W/4]
pol_ctx | pol_cnet output | context_dim = 64
context | [B, context_dim, H/4, W/4] | Output of Context Fusion
backbone features (FNet) | [B, out_channels=128, H/4, W/4] | conv(7, s2) → conv(3, s2) → conv(3)
film_params | [B, 2×out_channels, H/4, W/4] | gamma + beta
gamma / beta | each [B, out_channels, H/4, W/4] | film_params.chunk(2, dim=1)
fmap_left / fmap_right | [B, 128, H/4, W/4] | After FiLM modulation
corr_feat (corr_block lookup) | [B, 9, H/4, W/4] | 2×radius+1 = 9
motion_input | [B, 19, H/4, W/4] | corr_dim 9 + pol_dim 9 + disp 1
motion_features | [B, 64, H/4, W/4] | Output of motion_encoder
gru_input | [B, 64 + context_dim, H/4, W/4] | cat(motion_features, context)
disp | [B, 1, H/4, W/4] | Initialized to zeros
delta_disp | [B, 1, H/4, W/4] | Output of disp_head
disp_up | [B, 1, H, W] | 4 × F.interpolate(disp, (H, W))

UpdateBlock Dimension Derivation

input_dim = corr_dim + pol_dim + 1 = 9 + 9 + 1 = 19; GRU input dimension = 64 + context_dim = 128 with the default context_dim = 64.


6. Hyperparameters

Parameter | Default | Description
pol_radius | 4 | Lookup radius of the Pol Volume; determines pol_feat channels = 2×radius+1 = 9

7. Design Decisions and Rationale

7.1 Why Correlation Instead of Cost?

Traditional stereo matching uses Cost Volumes (SAD/SSD), where lower is better; this architecture uses a Correlation Volume (dot product), where higher is better.

Type | Computation | Semantics | Range
Cost (SAD) | Σ|L − R| | Low = good match | [0, ∞)
Cost (SSD) | Σ(L − R)² | Low = good match | [0, ∞)
Correlation | dot(L, R) | High = good match | (−∞, ∞)

Reasons for choosing Correlation: the dot product directly expresses "high = similar", and these semantics are cleanest when the features are normalized; gradients are more stable; and it is consistent with optical-flow architectures such as RAFT, on which RAFT-Stereo builds.
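For unit-norm feature vectors, the case in which the dot-product semantics are cleanest, the two families are tied by a one-line identity: the squared-distance (SSD) cost and the correlation differ only by a sign and a constant, so the disparity minimizing one maximizes the other. The notation f_L(x) for the left feature at x and f_R(x − d) for the right feature at x − d is introduced here for illustration:

```latex
\|f_L(x) - f_R(x-d)\|_2^2
  = \|f_L(x)\|_2^2 + \|f_R(x-d)\|_2^2 - 2\,\langle f_L(x), f_R(x-d) \rangle
  = 2 - 2\,\langle f_L(x), f_R(x-d) \rangle
\quad\Longrightarrow\quad
\arg\min_d \, \mathrm{SSD}(x, d) = \arg\max_d \, \langle f_L(x), f_R(x-d) \rangle
```

RAFT-style correlation additionally divides the dot product by √C (C = number of feature channels); that rescales the values but leaves the ranking over d unchanged. Without normalization, the two norm terms no longer cancel, and the two rankings can differ.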

The Pol Volume semantics are aligned with this:

# Raw polarization difference, normalized to [0, 1] (as in PolVolumeBlock)
pol_diff = torch.abs(left - right) / max_intensity  # high = glass

# Inverted to align with Correlation
pol_consistency = 1.0 - pol_diff  # high = non-glass = good

7.2 Why Must the Pol Volume Be Computed First?

The Pol Volume must reflect the “raw polarization difference” and must not be affected by context modulation. The ordering is:

1. Pol Volume <- raw left, right
2. Context <- RGB CNet + Pol CNet
3. FNet <- modulated by context
4. Corr Volume <- modulated fmap
5. GRU <- Corr + Pol (two complementary signals)

If the Pol Volume were built after the FNet, it would be affected by the context and lose its meaning as a “raw polarization signal”; the two volumes received by the GRU would then be coupled and lose their complementarity.

7.3 Why Use Concat Instead of More Complex Fusion?

Options considered: (1) concat (chosen), (2) cross-attention, (3) gated fusion, (4) weighted sum.

Reasons for choosing Concat:

  • Simple and stable: the expected gains from more complex fusion are unlikely to justify the added complexity.
  • Debug-friendly: per-channel activations can be inspected directly.
  • Let the network learn: the motion encoder learns how to weigh the two signals.
  • Avoid overfitting: complex fusion easily overfits on small datasets.

7.4 Design Validation Checklist

  • Pol Volume computed first (Phase 0).
  • Pol Volume uses raw images.
  • FNet modulated by context (FiLM).
  • GRU receives the Dual Volume (Corr + Pol).
  • Pretrain input statistics are close to Finetune.
  • Correlation and Pol Volume semantics are aligned (high = good).

8. Polarization Injection Points

This architecture has three polarization entry points:

Injection point | Phase | Form | Description
Pol Volume → GRU | Built at Phase 0, used at Phase 4 | pol_feat = pol_block(disp), concatenated with corr_feat into motion_encoder | Directly complements the Cost Volume; looked up every iteration
pol_input → Pol CNet → context | Phase 1 | pol_ctx = pol_cnet(pol_input), then Context Fusion | Polarization enters the context
context → FiLM → FNet | Phase 2 | modulated = γ * features + β, with γ/β generated from the context | Polarization information (via the context) indirectly modulates the FNet’s feature extraction, making the Correlation Volume meaningful on glass

The key injection point is the FiLM in Phase 2: the context (which includes the glass information from the Pol CNet) uses γ, β to modulate FNet features, so the Correlation Volume on glass regions also becomes meaningful for matching.


9. Highlights

  • FiLM feature-level modulation: the context (carrying glass information from the Pol CNet) intervenes during feature extraction via γ·feat + β, making the Correlation Volume on glass meaningful at its source rather than patching an already broken cost.
  • Pol Volume computed at Phase 0: deliberately placed earliest, ensuring it reflects the “raw polarization difference” without contamination from context modulation, preserving complementarity with the Cost Volume.
  • Semantic alignment design: the Pol Volume is inverted via 1 - |L-R| so that “high value = good” matches the Correlation Volume, allowing the GRU to simply concatenate the two volumes without extra fusion logic.
  • Knowledge Transfer fix: switches the Pretrain input from [binary mask, zeros] to [soft_mask, edges], deliberately approximating the continuous statistics of the Finetune input so that Pol CNet weights can transfer effectively.
  • Triple polarization injection: polarization information enters the network through three paths — the GRU loop (Pol Volume), the context (Pol CNet), and FNet modulation (FiLM) — operating at multiple levels.
