SALT
- class stable_pretraining.methods.SALT(encoder_name: str = 'vit_tiny_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 12, predictor_num_heads: int = 16, num_targets: int = 4, context_scale: Tuple[float, float] = (0.85, 1.0), target_scale: Tuple[float, float] = (0.15, 0.2), context_aspect_ratio: Tuple[float, float] = (1.0, 1.0), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), teacher_state_dict: dict = None, pretrained: bool = False)
Bases: Module
SALT Stage 2: Static-teacher Asymmetric Latent Training.
- Architecture:
Teacher (frozen): Encodes full unmasked image via EvalOnly(MaskedEncoder)
Student (trainable): Encodes only context (visible) patches
Predictor: Lightweight transformer predicting teacher latents at target positions
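The asymmetry above comes down to which patch tokens each component processes. The following is a minimal standard-library sketch of that routing, not the library's code. It assumes flat row-major patch indices on a 14×14 grid (a 224×224 input with 16×16 patches). It also assumes that target patches are removed from the student's context. The helper `route_tokens` is hypothetical.

```python
from typing import Dict, List, Sequence


def route_tokens(num_patches: int,
                 context_idx: Sequence[int],
                 target_blocks: Sequence[Sequence[int]]) -> Dict[str, List[int]]:
    """Report which patch indices each SALT component processes (illustrative only)."""
    targets = sorted({i for block in target_blocks for i in block})
    # Assumption: target patches are removed from the student's context.
    context = sorted(set(context_idx) - set(targets))
    return {
        "teacher": list(range(num_patches)),  # frozen teacher: every patch of the unmasked image
        "student": context,                   # trainable student: visible context patches only
        "predictor_queries": targets,         # predictor: one output per target position
    }


# 224x224 input with 16x16 patches -> a 14x14 grid of 196 patches.
flow = route_tokens(196, context_idx=range(196), target_blocks=[[0, 1, 14, 15], [15, 16]])
for name, idx in flow.items():
    print(name, len(idx))
```

The student handles fewer tokens than the teacher, and the teacher processes the full image without receiving gradients. This is what makes the training asymmetric.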
- Parameters:
encoder_name – timm model name (default: "vit_tiny_patch16_224")
predictor_embed_dim – Predictor hidden dimension (default: 384)
predictor_depth – Number of predictor blocks (default: 12)
predictor_num_heads – Number of predictor attention heads (default: 16)
num_targets – Number of target blocks for masking (default: 4)
context_scale – (min, max) scale range for the context block (default: (0.85, 1.0))
target_scale – (min, max) scale range for each target block (default: (0.15, 0.2))
context_aspect_ratio – (min, max) aspect-ratio range for the context block (default: (1.0, 1.0))
target_aspect_ratio – (min, max) aspect-ratio range for each target block (default: (0.75, 1.5))
teacher_state_dict – Optional state dict to load into the teacher encoder (default: None)
pretrained – If True, load pretrained encoder weights (default: False)
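The masking parameters appear to describe I-JEPA-style multi-block masking. Several target rectangles and one large context rectangle are sampled on the patch grid. This is a minimal standard-library sketch of such a sampler, not SALT's actual implementation. Several details are assumptions:

* scale is treated as a fraction of all patches;
* aspect ratio is treated as height / width;
* target patches are subtracted from the context.

The functions `sample_block` and `sample_masks` are hypothetical.

```python
import math
import random
from typing import List, Set, Tuple

Range = Tuple[float, float]


def sample_block(grid: int, scale: Range, aspect: Range, rng: random.Random) -> Set[int]:
    """Sample one rectangular block of flat patch indices on a grid x grid patch map."""
    s = rng.uniform(*scale)    # assumed: fraction of all patches the block should cover
    r = rng.uniform(*aspect)   # assumed: aspect ratio as height / width
    area = s * grid * grid
    h = max(1, min(grid, round(math.sqrt(area * r))))
    w = max(1, min(grid, round(math.sqrt(area / r))))
    top = rng.randint(0, grid - h)
    left = rng.randint(0, grid - w)
    return {row * grid + col
            for row in range(top, top + h)
            for col in range(left, left + w)}


def sample_masks(grid: int = 14, num_targets: int = 4,
                 context_scale: Range = (0.85, 1.0), target_scale: Range = (0.15, 0.2),
                 context_aspect_ratio: Range = (1.0, 1.0),
                 target_aspect_ratio: Range = (0.75, 1.5),
                 seed: int = 0) -> Tuple[List[int], List[List[int]]]:
    """Return (context indices, list of target-block indices) for one image."""
    rng = random.Random(seed)
    targets = [sample_block(grid, target_scale, target_aspect_ratio, rng)
               for _ in range(num_targets)]
    context = sample_block(grid, context_scale, context_aspect_ratio, rng)
    # Assumed (I-JEPA-style): remove every target patch so the student cannot see it.
    context -= set().union(*targets)
    return sorted(context), [sorted(t) for t in targets]


context, targets = sample_masks()
print(len(context), [len(t) for t in targets])
```

With the defaults, each target block covers roughly 15 to 20% of the 196 patches. The context starts as a nearly full-image square before the targets are carved out of it.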
Example:
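    import torch
    from stable_pretraining.methods import SALT

    model = SALT("vit_tiny_patch16_224")
    images = torch.randn(4, 3, 224, 224)

    # Training mode: predict frozen-teacher latents at target positions
    model.train()
    output = model(images)
    output.loss.backward()

    # Eval mode: student CLS-token embedding
    model.eval()
    output = model(images)
    features = output.embedding  # [B, D]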
- forward(images: Tensor) → SALTOutput
Forward pass.
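Training mode: the teacher encodes the full image, and the student encodes only the context patches. The predictor then predicts the teacher's latents at the target positions, and the loss is an L1 loss between the predicted and teacher latents.
Eval mode: the student encodes the full image. The output holds its CLS-token embedding, and the loss is zero.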
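The training objective is described only as an L1 loss. As a sketch, the following assumes a mean absolute error taken over every latent element at the target positions. This is what PyTorch's `torch.nn.functional.l1_loss` computes with its default mean reduction. The reduction SALT actually uses is not stated here, and the helper `l1_latent_loss` is hypothetical. Plain lists stand in for tensors.

```python
from typing import List, Sequence

Vector = List[float]


def l1_latent_loss(pred: Sequence[Vector], teacher: Sequence[Vector],
                   target_idx: Sequence[int]) -> float:
    """Mean absolute error between predicted and teacher latents at target positions.

    pred[k] is the predictor output for patch target_idx[k]; teacher holds one
    latent vector per patch of the full image.
    """
    if len(pred) != len(target_idx):
        raise ValueError("expected one prediction per target position")
    total, count = 0.0, 0
    for p, idx in zip(pred, target_idx):
        t = teacher[idx]  # teacher latent at this target patch (a constant: teacher is frozen)
        total += sum(abs(a - b) for a, b in zip(p, t))
        count += len(t)
    return total / count


teacher = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]   # 3 patches, latent dim 2
pred = [[1.5, 2.0], [3.0, 3.0]]                  # predictions for patches 1 and 2
print(l1_latent_loss(pred, teacher, target_idx=[1, 2]))  # 0.375
```

Only target positions contribute to the loss. Teacher latents for context patches are computed but never compared.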
- Parameters:
images – Input images [B, C, H, W]
- Returns:
SALTOutput
- classmethod from_checkpoint(ckpt_path: str, encoder_name: str = 'vit_tiny_patch16_224', **kwargs) → SALT
Create a SALT Stage 2 model from a Stage 1 (MAE/VPixel) checkpoint.
The encoder weights from the Stage 1 checkpoint are loaded into the frozen teacher.
- Parameters:
ckpt_path – Path to Stage 1 checkpoint
encoder_name – timm model name matching Stage 1
kwargs – Additional keyword arguments for SALT.__init__
- Returns:
SALT instance with teacher initialized from checkpoint
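Loading a Stage 1 encoder as the teacher usually means pulling the encoder's entries out of a larger checkpoint and stripping their name prefix. This is a minimal sketch of that step under a hypothetical layout. It assumes the weights sit under a Lightning-style `"state_dict"` key or at the top level, and that encoder parameters are prefixed with `"encoder."`. Both assumptions may not match real Stage 1 checkpoints. The helper `extract_encoder_state` is also hypothetical; `from_checkpoint` presumably performs an equivalent step internally.

```python
from typing import Any, Dict, Mapping


def extract_encoder_state(ckpt: Mapping[str, Any], prefix: str = "encoder.") -> Dict[str, Any]:
    """Collect encoder weights from a Stage 1 checkpoint dict and strip their prefix.

    Hypothetical layout: weights live under ckpt["state_dict"] (Lightning style)
    or at the top level, with encoder parameters named "<prefix><param>".
    """
    state = ckpt.get("state_dict", ckpt)
    encoder_state = {key[len(prefix):]: value
                     for key, value in state.items() if key.startswith(prefix)}
    if not encoder_state:
        raise KeyError(f"no keys starting with {prefix!r} in checkpoint")
    return encoder_state


# Toy checkpoint with plain numbers standing in for tensors.
ckpt = {"state_dict": {"encoder.cls_token": 1, "encoder.blocks.0.attn.qkv.weight": 2,
                       "decoder.pred.weight": 3}}
print(extract_encoder_state(ckpt))
```

With PyTorch, the checkpoint dict would come from `torch.load(ckpt_path, map_location="cpu")`. The result could then be passed as `teacher_state_dict` to `SALT`. Inspect the checkpoint's keys first, since the `"encoder."` prefix is only a guess at the Stage 1 module naming.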