Source code for stable_pretraining.methods.salt

"""SALT: Static-teacher Asymmetric Latent Training.

SALT combines ideas from V-JEPA masking with MAE pixel reconstruction
(Stage 1) and latent target prediction with a frozen teacher (Stage 2).

References:
    Li, Xianhang, et al. "Rethinking JEPA: Compute-Efficient Video SSL
    with Frozen Teachers." 2025.
    https://arxiv.org/pdf/2509.24317

Example:
    from stable_pretraining.methods import SALT, MAE
    from stable_pretraining.backbone import MultiBlockMasking

    # Stage 1: MAE with multi-block masking
    stage1 = MAE("vit_tiny_patch16_224", masking=MultiBlockMasking())

    # Stage 2: SALT from Stage 1 checkpoint
    stage2 = SALT.from_checkpoint(
        "stage1.ckpt",
        encoder_name="vit_tiny_patch16_224",
        predictor_embed_dim=384,
        predictor_depth=12,
    )

"""

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from stable_pretraining.backbone import (
    EvalOnly,
    FlexibleTransformer,
    MaskedEncoder,
)
from stable_pretraining.data.masking import multi_block_mask
from stable_pretraining import Module
from transformers.utils import ModelOutput


@dataclass
class SALTOutput(ModelOutput):
    """Output from SALT forward pass.

    Every field defaults to ``None``, so all of them are annotated as
    ``Optional``.

    :ivar loss: Prediction loss (L1 between predicted and teacher latents, 0 in eval)
    :ivar embedding: CLS token embedding [B, D]
    :ivar predictions: Predicted representations [B, N_tgt, D] (or None in eval)
    :ivar targets: Teacher target representations [B, N_tgt, D] (or None in eval)
    :ivar num_targets: Number of target patches (0 in eval)
    :ivar num_context: Number of context patches (all patches in eval)
    """

    loss: Optional[torch.Tensor] = None
    embedding: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
    targets: Optional[torch.Tensor] = None
    num_targets: Optional[int] = None
    num_context: Optional[int] = None
class SALT(Module):
    """SALT Stage 2: Static-teacher Asymmetric Latent Training.

    Three sub-modules are created, in this order:

    - ``teacher``: a ``MaskedEncoder`` wrapped in ``EvalOnly``. During training
      it encodes the full image under ``torch.no_grad()`` and supplies the
      regression targets.
    - ``student``: a second ``MaskedEncoder`` that, during training, is run on
      the context (visible) patches only.
    - ``predictor``: a ``FlexibleTransformer`` with a mask token that predicts
      the teacher latents at the target positions.

    :param encoder_name: timm model name (e.g., "vit_tiny_patch16_224")
    :param predictor_embed_dim: Predictor hidden dimension (default: 384)
    :param predictor_depth: Number of predictor blocks (default: 12)
    :param predictor_num_heads: Number of predictor attention heads (default: 16)
    :param num_targets: Number of target blocks for masking (default: 4)
    :param context_scale: (min, max) scale for context block
    :param target_scale: (min, max) scale for each target block
    :param context_aspect_ratio: (min, max) aspect ratio for context block
    :param target_aspect_ratio: (min, max) aspect ratio for target blocks
    :param teacher_state_dict: Optional state dict to load into teacher encoder
    :param pretrained: Load pretrained encoder weights

    Example::

        model = SALT("vit_tiny_patch16_224")
        images = torch.randn(4, 3, 224, 224)

        model.train()
        output = model(images)
        output.loss.backward()

        model.eval()
        output = model(images)
        features = output.embedding  # [B, D]
    """

    def __init__(
        self,
        encoder_name: str = "vit_tiny_patch16_224",
        predictor_embed_dim: int = 384,
        predictor_depth: int = 12,
        predictor_num_heads: int = 16,
        num_targets: int = 4,
        context_scale: Tuple[float, float] = (0.85, 1.0),
        target_scale: Tuple[float, float] = (0.15, 0.2),
        context_aspect_ratio: Tuple[float, float] = (1.0, 1.0),
        target_aspect_ratio: Tuple[float, float] = (0.75, 1.5),
        teacher_state_dict: dict = None,
        pretrained: bool = False,
    ):
        super().__init__()

        # Teacher: weights are loaded into the raw encoder first, then the
        # encoder is wrapped in EvalOnly.
        frozen_encoder = MaskedEncoder(
            encoder_name, masking=None, pretrained=pretrained
        )
        if teacher_state_dict is not None:
            frozen_encoder.load_state_dict(teacher_state_dict)
        self.teacher = EvalOnly(frozen_encoder)

        # Student: built without a masking module because forward() selects
        # the visible patches itself.
        self.student = MaskedEncoder(
            encoder_name, masking=None, pretrained=pretrained
        )

        embed_dim = self.student.embed_dim
        num_patches = self.student.default_grid_h * self.student.default_grid_w

        # Predictor: self-attention only, with a mask token for target queries.
        self.predictor = FlexibleTransformer(
            input_dim=embed_dim,
            hidden_dim=predictor_embed_dim,
            output_dim=embed_dim,
            num_patches=num_patches,
            depth=predictor_depth,
            num_heads=predictor_num_heads,
            self_attn=True,
            cross_attn=False,
            add_mask_token=True,
            use_adaln=False,
            num_prefix_tokens=0,
            zero_init_output=False,
        )

        # Block-sampling configuration consumed by _generate_masks().
        self.num_targets = num_targets
        self.context_scale = context_scale
        self.target_scale = target_scale
        self.context_aspect_ratio = context_aspect_ratio
        self.target_aspect_ratio = target_aspect_ratio
        self.embed_dim = embed_dim

    def _generate_masks(
        self,
        grid_h: int,
        grid_w: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample one context block and ``num_targets`` target blocks.

        The context block has every target block removed from it, and the
        target blocks are merged into a single union.

        :param grid_h: Patch grid height
        :param grid_w: Patch grid width
        :param device: Device the returned index tensors are moved to
        :return: (context_idx [N_ctx], target_idx [N_tgt]) — 1D index tensors
        """
        n_blocks = self.num_targets
        block_scales = [self.context_scale, *([self.target_scale] * n_blocks)]
        aspect_ratios = [
            self.context_aspect_ratio,
            *([self.target_aspect_ratio] * n_blocks),
        ]

        blocks = multi_block_mask(
            grid_h,
            grid_w,
            block_scales=block_scales,
            aspect_ratios=aspect_ratios,
        )
        # First mask is the context block ([H, W], 1=in block); the rest are
        # the target blocks.
        context_mask = blocks[0]
        target_blocks = blocks[1:]

        target_union = torch.zeros(grid_h * grid_w, dtype=torch.bool)
        for block in target_blocks:
            # Remove this target from the context ...
            context_mask = context_mask * (1 - block)
            # ... and add it to the union of target positions.
            target_union = target_union | block.flatten().bool()

        context_flat = context_mask.flatten().bool()
        context_idx = torch.nonzero(context_flat, as_tuple=True)[0].to(device)
        target_idx = torch.nonzero(target_union, as_tuple=True)[0].to(device)
        return context_idx, target_idx

    def _encode(
        self,
        patches: torch.Tensor,
        indices: torch.Tensor,
        grid_h: int,
        grid_w: int,
        encoder: MaskedEncoder,
    ) -> torch.Tensor:
        """Run a subset of patch tokens through ``encoder``'s transformer.

        Positional embeddings are added to all patches before the subset is
        gathered. Prefix tokens (CLS, registers), when the encoder has any,
        are prepended afterwards.

        :param patches: All patch embeddings [B, N, D]
        :param indices: Indices to encode [B, K] or [K] (will be expanded)
        :param grid_h: Patch grid height
        :param grid_w: Patch grid width
        :param encoder: MaskedEncoder instance
        :return: Encoded representations [B, num_prefix + K, D]
        """
        batch_size, _, dim = patches.shape

        # A shared 1D index vector is broadcast over the batch.
        if indices.dim() == 1:
            indices = indices.unsqueeze(0).expand(batch_size, -1)

        prefix_pos, patch_pos = encoder._get_pos_embed(grid_h, grid_w)
        positioned = patches + patch_pos.expand(batch_size, -1, -1)

        # Keep only the requested patches.
        gather_index = indices.unsqueeze(-1).expand(-1, -1, dim)
        tokens = torch.gather(positioned, 1, gather_index)

        prefix = encoder._get_prefix_tokens(batch_size)
        if prefix is not None:
            if prefix_pos is not None and not encoder.no_embed_class:
                prefix = prefix + prefix_pos
            tokens = torch.cat([prefix, tokens], dim=1)

        vit = encoder.vit
        tokens = vit.pos_drop(tokens)
        tokens = vit.blocks(tokens)
        tokens = vit.norm(tokens)
        return tokens

    def forward(self, images: torch.Tensor) -> SALTOutput:
        """Forward pass.

        Training: the teacher encodes the full image, the student encodes the
        context patches only, the predictor predicts teacher latents at the
        target positions, and an L1 loss compares the two.

        Eval: the student encodes the full image; the token at index 0 is
        returned as the embedding together with a zero loss.

        :param images: Input images [B, C, H, W]
        :return: SALTOutput
        """
        batch_size = images.shape[0]

        if not self.training:
            with torch.no_grad():
                eval_out = self.student(images)
                cls_embedding = eval_out.encoded[:, 0, :].detach()
            zero_loss = torch.tensor(0.0, device=images.device)
            n_patches = eval_out.encoded.shape[1] - self.student.num_prefix_tokens
            return SALTOutput(
                loss=zero_loss,
                embedding=cls_embedding,
                predictions=None,
                targets=None,
                num_targets=0,
                num_context=n_patches,
            )

        grid_h, grid_w = self.student._get_grid_size(images)
        # One set of indices is sampled per call and shared by the whole batch.
        context_idx, target_idx = self._generate_masks(grid_h, grid_w, images.device)
        n_target = target_idx.shape[0]

        # --- Teacher: full image, no gradients ---
        with torch.no_grad():
            teacher_tokens = self.teacher(images).encoded
            teacher_patches = teacher_tokens[:, self.teacher.num_prefix_tokens :, :]
            # Pick the teacher latents at the target positions.
            target_gather = target_idx[None, :, None].expand(
                batch_size, -1, teacher_patches.shape[-1]
            )
            teacher_targets = teacher_patches.gather(1, target_gather)

        # --- Student: context patches only ---
        patch_tokens = self.student.patch_embed(images)
        encoded = self._encode(
            patch_tokens, context_idx, grid_h, grid_w, self.student
        )
        student_context = encoded[:, self.student.num_prefix_tokens :, :]

        # --- Predictor ---
        # Queries are all-zero and fully masked, so the predictor's mask token
        # stands in for every target position.
        queries = torch.zeros(
            batch_size,
            n_target,
            self.embed_dim,
            device=images.device,
            dtype=student_context.dtype,
        )
        query_mask = torch.ones(
            batch_size, n_target, device=images.device, dtype=torch.bool
        )
        predictions = self.predictor(
            context=student_context,
            queries=queries,
            context_idx=context_idx.unsqueeze(0).expand(batch_size, -1),
            query_idx=target_idx.unsqueeze(0).expand(batch_size, -1),
            query_mask=query_mask,
        )

        # --- Loss ---
        loss = F.l1_loss(predictions, teacher_targets)

        return SALTOutput(
            loss=loss,
            embedding=encoded[:, 0, :].detach(),
            predictions=predictions,
            targets=teacher_targets,
            num_targets=n_target,
            num_context=context_idx.shape[0],
        )

    @classmethod
    def from_checkpoint(
        cls,
        ckpt_path: str,
        encoder_name: str = "vit_tiny_patch16_224",
        **kwargs,
    ) -> "SALT":
        """Create SALT Stage 2 from a Stage 1 (MAE/VPixel) checkpoint.

        The checkpoint is loaded on CPU with ``weights_only=True``. If it has
        a ``"state_dict"`` entry, that entry is used. Keys starting with
        ``"encoder."`` are kept with the prefix stripped; if no key has that
        prefix, the whole state dict is passed on unchanged. The result is
        handed to the constructor as ``teacher_state_dict``.

        :param ckpt_path: Path to Stage 1 checkpoint
        :param encoder_name: timm model name matching Stage 1
        :param kwargs: Additional arguments for SALT.__init__
        :return: SALT instance with teacher initialized from checkpoint
        """
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

        prefix = "encoder."
        encoder_state = {
            key.removeprefix(prefix): value
            for key, value in weights.items()
            if key.startswith(prefix)
        }
        if not encoder_state:
            # Nothing carried the prefix: treat the checkpoint as a bare
            # encoder state dict.
            encoder_state = weights

        return cls(
            encoder_name=encoder_name,
            teacher_state_dict=encoder_state,
            **kwargs,
        )