SALT

class stable_pretraining.methods.SALT(encoder_name: str = 'vit_tiny_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 12, predictor_num_heads: int = 16, num_targets: int = 4, context_scale: Tuple[float, float] = (0.85, 1.0), target_scale: Tuple[float, float] = (0.15, 0.2), context_aspect_ratio: Tuple[float, float] = (1.0, 1.0), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), teacher_state_dict: dict = None, pretrained: bool = False)

Bases: Module

SALT Stage 2: Static-teacher Asymmetric Latent Training.

Architecture:
  • Teacher (frozen): Encodes full unmasked image via EvalOnly(MaskedEncoder)

  • Student (trainable): Encodes only context (visible) patches

  • Predictor: Lightweight transformer predicting teacher latents at target positions

Parameters:
  • encoder_name – timm model name (e.g., "vit_tiny_patch16_224")

  • predictor_embed_dim – Predictor hidden dimension (default: 384)

  • predictor_depth – Number of predictor blocks (default: 12)

  • predictor_num_heads – Number of predictor attention heads (default: 16)

  • num_targets – Number of target blocks for masking (default: 4)

  • context_scale – (min, max) scale range for the context block (default: (0.85, 1.0))

  • target_scale – (min, max) scale range for each target block (default: (0.15, 0.2))

  • context_aspect_ratio – (min, max) aspect ratio range for the context block (default: (1.0, 1.0))

  • target_aspect_ratio – (min, max) aspect ratio range for each target block (default: (0.75, 1.5))

  • teacher_state_dict – Optional state dict to load into the teacher encoder (default: None)

  • pretrained – Whether to load pretrained encoder weights (default: False)
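To make the scale and aspect-ratio parameters concrete, here is a minimal sketch of how such ranges could be turned into rectangular blocks on the patch grid. It is modeled on I-JEPA-style multi-block masking and is an assumption, not SALT's actual sampler: the helper names `sample_block_size` and `sample_block`, the reading of scale as a fraction of the patch grid, and the step that removes target patches from the context are all illustrative.

```python
import math
import random


def sample_block_size(grid_h, grid_w, scale, aspect_ratio, rng):
    """Sample the (height, width) of one rectangular block, in patches."""
    s = rng.uniform(*scale)            # assumed: fraction of the grid to cover
    ar = rng.uniform(*aspect_ratio)    # assumed: height / width of the block
    num_patches = s * grid_h * grid_w
    h = int(round(math.sqrt(num_patches * ar)))
    w = int(round(math.sqrt(num_patches / ar)))
    # Clamp so the block always fits inside the grid.
    return max(1, min(h, grid_h)), max(1, min(w, grid_w))


def sample_block(grid_h, grid_w, scale, aspect_ratio, rng):
    """Return the set of flat patch indices covered by one random block."""
    h, w = sample_block_size(grid_h, grid_w, scale, aspect_ratio, rng)
    top = rng.randint(0, grid_h - h)    # randint is inclusive on both ends
    left = rng.randint(0, grid_w - w)
    return {r * grid_w + c
            for r in range(top, top + h)
            for c in range(left, left + w)}


rng = random.Random(0)
grid = 224 // 16  # 14 x 14 patches for a patch16, 224-pixel input

# Defaults from the signature: num_targets=4, target_scale, target_aspect_ratio.
targets = [sample_block(grid, grid, (0.15, 0.2), (0.75, 1.5), rng)
           for _ in range(4)]
context = sample_block(grid, grid, (0.85, 1.0), (1.0, 1.0), rng)

# Assumption: target patches are hidden from the student's context.
for t in targets:
    context -= t
```

With the default ranges, each target block covers roughly a sixth of the grid, while the context block starts out covering most of it before the targets are carved out.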

Example:

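import torch
from stable_pretraining.methods import SALT

model = SALT("vit_tiny_patch16_224")
images = torch.randn(4, 3, 224, 224)  # [B, C, H, W]

# Training: the output carries the L1 prediction loss
model.train()
output = model(images)
output.loss.backward()

# Eval: the student encodes the full image; the loss is zero
model.eval()
output = model(images)
features = output.embedding  # [B, D]

forward(images: Tensor) → SALTOutput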

Forward pass.

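Training: the frozen teacher encodes the full image, the student encodes only the context patches, and the predictor predicts the teacher latents at the target positions. The loss is the L1 distance between the predicted and teacher latents.

Eval: the student encodes the full image and returns the CLS token embedding; the loss is zero.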

Parameters:
  • images – Input images [B, C, H, W]

Returns:

SALTOutput
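
The training objective described above can be illustrated with a small, framework-free sketch. It uses plain Python lists instead of tensors for readability. The function name `l1_loss_at_targets` and the mean reduction over all elements are assumptions, not SALT's confirmed implementation.

```python
def l1_loss_at_targets(predicted, teacher_latents, target_indices):
    """Mean absolute error between predictions and teacher latents.

    predicted:        list of D-dim vectors, one per target position
    teacher_latents:  list of D-dim vectors, one per patch of the full image
    target_indices:   patch index corresponding to each row of `predicted`
    """
    total, count = 0.0, 0
    for pred_vec, idx in zip(predicted, target_indices):
        # The frozen teacher's output is a fixed regression target.
        target_vec = teacher_latents[idx]
        for p, t in zip(pred_vec, target_vec):
            total += abs(p - t)
            count += 1
    return total / count  # assumed: mean over all target elements


# Toy data: 4 patches with 2-dim latents; patches 1 and 3 are targets.
teacher_latents = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]
target_indices = [1, 3]
predicted = [[2.5, 3.0], [6.0, 6.0]]

loss = l1_loss_at_targets(predicted, teacher_latents, target_indices)
```

Only the target positions contribute to the loss. Because the teacher is frozen, gradients would flow only through the student and the predictor.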

classmethod from_checkpoint(ckpt_path: str, encoder_name: str = 'vit_tiny_patch16_224', **kwargs) → SALT

Create SALT Stage 2 from a Stage 1 (MAE/VPixel) checkpoint.

Loads the encoder weights from Stage 1 as the frozen teacher.

Parameters:
  • ckpt_path – Path to Stage 1 checkpoint

  • encoder_name – timm model name matching Stage 1

  • kwargs – Additional arguments for SALT.__init__

Returns:

SALT instance with teacher initialized from checkpoint
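
As a rough illustration of what loading Stage 1 encoder weights can involve, the sketch below filters a checkpoint dictionary down to the encoder entries. The helper `extract_encoder_state`, the "state_dict" wrapper key, and the "encoder." key prefix are all hypothetical. The real Stage 1 checkpoint layout may differ, and `from_checkpoint` handles this step internally.

```python
def extract_encoder_state(checkpoint, prefix="encoder."):
    """Keep only the weights under `prefix`, with the prefix stripped.

    `checkpoint` is a plain dict; the key layout here is an assumption.
    """
    # Some training frameworks nest weights under a "state_dict" key;
    # otherwise treat the checkpoint itself as the state dict.
    state = checkpoint.get("state_dict", checkpoint)
    return {key[len(prefix):]: value
            for key, value in state.items()
            if key.startswith(prefix)}


# Toy checkpoint with placeholder values standing in for tensors.
ckpt = {
    "state_dict": {
        "encoder.patch_embed.proj.weight": 1,
        "encoder.cls_token": 2,
        "decoder.head.weight": 3,  # not part of the encoder, so dropped
    }
}

teacher_state = extract_encoder_state(ckpt)
```

A dictionary of this shape is the kind of value the teacher_state_dict constructor argument is documented to accept.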