stable_pretraining.methods package#
Submodules#
stable_pretraining.methods.ijepa module#
I-JEPA: Image-based Joint-Embedding Predictive Architecture.
Self-supervised learning via predicting target patch representations from context patch representations using a lightweight predictor.
References
Assran et al. “Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture.” CVPR 2023. https://arxiv.org/abs/2301.08243
Example:
from stable_pretraining.methods import IJEPA
from stable_pretraining.callbacks import TeacherStudentCallback
import lightning as pl
# Create model
model = IJEPA(
model_or_model_name="vit_base_patch16_224",
predictor_embed_dim=384,
predictor_depth=6,
num_targets=4,
)
# Training with PyTorch Lightning
trainer = pl.Trainer(
max_epochs=300,
callbacks=[TeacherStudentCallback()], # Handles EMA updates
)
trainer.fit(model, dataloader)
# Get encoder for downstream tasks
encoder = model.encoder.student
- class stable_pretraining.methods.ijepa.IJEPA(model_or_model_name: str | Module = 'vit_base_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 6, num_targets: int = 4, target_scale: Tuple[float, float] = (0.15, 0.2), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), context_scale: Tuple[float, float] = (0.85, 1.0), ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, pretrained: bool = False)[source]#
Bases: Module
I-JEPA: Image-based Joint-Embedding Predictive Architecture.
- Architecture:
Context Encoder (student): Encodes visible/context patches
Target Encoder (teacher): EMA copy, encodes target patches
Predictor: Lightweight transformer predicting targets from context
The context encoder is wrapped with TeacherStudentWrapper, enabling automatic EMA updates via TeacherStudentCallback.
- Parameters:
model_or_model_name – timm model name string or pre-instantiated nn.Module
predictor_embed_dim – Predictor hidden dimension (default: 384)
predictor_depth – Number of predictor blocks (default: 6)
num_targets – Number of target blocks to sample (default: 4)
target_scale – (min, max) fraction of patches per target block
target_aspect_ratio – (min, max) aspect ratio of target blocks
context_scale – (min, max) fraction of non-target patches as context
ema_decay_start – Initial EMA decay (default: 0.996)
ema_decay_end – Final EMA decay (default: 1.0)
pretrained – Load pretrained encoder weights
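The target_scale and target_aspect_ratio parameters jointly determine each target block's shape on the patch grid: scale fixes the block's area as a fraction of all patches, and aspect ratio fixes its height/width. A minimal sketch of that geometry (hypothetical helper, not the library's IJEPAMasking implementation):

```python
import math
import random

def sample_block_shape(grid_h, grid_w, scale_range=(0.15, 0.2),
                       aspect_range=(0.75, 1.5), rng=random.Random(0)):
    """Sample (height, width) of a block on the patch grid.

    scale = fraction of all patches covered by the block;
    aspect = height / width of the block.
    """
    num_patches = grid_h * grid_w
    scale = rng.uniform(*scale_range)
    aspect = rng.uniform(*aspect_range)
    target_area = scale * num_patches           # patches in the block
    h = int(round(math.sqrt(target_area * aspect)))
    w = int(round(math.sqrt(target_area / aspect)))
    # clamp to the patch grid
    h = max(1, min(h, grid_h))
    w = max(1, min(w, grid_w))
    return h, w

h, w = sample_block_shape(14, 14)  # 224px image / 16px patches -> 14x14 grid
```

With the defaults, each of the num_targets blocks covers roughly 15-20% of the grid.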
Example:
# Basic usage
model = IJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)
# Training mode: predicts masked targets
model.train()
output = model(images)
output.loss.backward()
# Eval mode: encodes all patches (no masking)
model.eval()
output = model(images)
features = output.predictions  # [B, N, D]
Example with Lightning:
import lightning as pl
from stable_pretraining.callbacks import TeacherStudentCallback

class IJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = IJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        output = self.model(images)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)

trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
trainer.fit(IJEPALightning(), dataloader)
Note
- Use TeacherStudentCallback to handle EMA updates automatically
- In eval mode, num_targets=0 and all patches are returned as context
- Access trained encoder via model.encoder.student
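The EMA schedule behind ema_decay_start and ema_decay_end can be sketched as a linear ramp of the decay plus the usual parameter-wise teacher update. This is a simplified stand-in for what TeacherStudentCallback does, not its actual code:

```python
import torch

@torch.no_grad()
def ema_update(teacher, student, decay):
    """teacher <- decay * teacher + (1 - decay) * student, parameter-wise."""
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(decay).add_(s, alpha=1 - decay)

def decay_at(step, total_steps, start=0.996, end=1.0):
    """Linearly ramp the EMA decay from start to end over training."""
    frac = min(step / max(total_steps, 1), 1.0)
    return start + (end - start) * frac

student = torch.nn.Linear(4, 4)
teacher = torch.nn.Linear(4, 4)
teacher.load_state_dict(student.state_dict())  # teacher starts as a copy
ema_update(teacher, student, decay_at(step=100, total_steps=1000))
```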
- forward(images: Tensor, embedding_source: str = 'teacher') IJEPAOutput[source]#
Forward pass.
- In training mode:
Samples target blocks and context region via IJEPAMasking
Encodes context through student, targets through teacher (EMA)
Predicts target representations from context
Returns smooth L1 loss between predictions and targets
- In eval mode:
No masking, all patches treated as context
Returns encoded features with zero loss
Always uses student encoder
- Parameters:
images – Input images [B, C, H, W]
embedding_source – Which encoder to use for the embedding output. "teacher" (default) or "student". Only affects training mode; eval mode always uses student.
- Returns:
IJEPAOutput with loss and representations
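The training loss named above is plain smooth L1 between predicted and target patch representations; a sketch with stand-in tensors (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

B, n_target_patches, D = 4, 36, 768
predictions = torch.randn(B, n_target_patches, D)  # predictor output
targets = torch.randn(B, n_target_patches, D)      # teacher (EMA) output
loss = F.smooth_l1_loss(predictions, targets)      # scalar training loss
```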
- class stable_pretraining.methods.ijepa.IJEPAOutput(loss: Tensor = None, embedding: Tensor = None, predictions: Tensor = None, targets: Tensor = None, num_targets: int = None, num_context: int = None)[source]#
Bases: ModelOutput
Output from IJEPA forward pass.
- Variables:
loss – Prediction loss (0 in eval mode)
embedding – Patch embeddings [B, N, D] for downstream use
predictions – Predicted representations [B, N_tgt, D] (or context in eval)
targets – Target representations [B, N_tgt, D] (or context in eval)
num_targets – Number of target patches (0 in eval)
num_context – Number of context patches (all patches in eval)
stable_pretraining.methods.lejepa module#
LeJEPA: Latent Embedding Joint-Embedding Predictive Architecture.
Self-supervised learning via multi-view invariance combined with a sliced goodness-of-fit test (SIGReg) that pushes embeddings toward an isotropic Gaussian.
References
Balestriero & LeCun. “LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics.” 2025. https://arxiv.org/abs/2511.08544
Example:
from stable_pretraining.methods.lejepa import LeJEPA
model = LeJEPA("vit_small_patch16_224")
global_images = [torch.randn(4, 3, 224, 224)] * 2
all_images = [torch.randn(4, 3, 224, 224)] * 6
model.train()
output = model(global_images, all_images)
output.loss.backward()
model.eval()
output = model(images=torch.randn(4, 3, 224, 224))
features = output.embedding # [N, D]
- class stable_pretraining.methods.lejepa.EppsPulley(t_max: float = 3.0, n_points: int = 17)[source]#
Bases: Module
Epps-Pulley goodness-of-fit test for univariate normality.
Projects data onto a grid of points and computes the Epps-Pulley statistic.
- Parameters:
t_max – Integration upper bound.
n_points – Number of integration points.
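One way to read the t_max and n_points parameters: the test compares the empirical characteristic function of standardized samples against the standard normal's characteristic function, exp(-t²/2), on a grid of n_points in [0, t_max]. A rough numerical sketch of such a statistic (not the module's exact weighting or normalization):

```python
import torch

def epps_pulley_sketch(x, t_max=3.0, n_points=17):
    """Crude EP-style statistic: integrated squared distance between the
    empirical characteristic function of x and the standard normal CF."""
    x = (x - x.mean()) / x.std()                 # standardize the sample
    t = torch.linspace(0.0, t_max, n_points)     # integration grid
    tx = t[:, None] * x[None, :]                 # [n_points, n_samples]
    ecf_real = torch.cos(tx).mean(dim=1)         # Re of empirical CF
    ecf_imag = torch.sin(tx).mean(dim=1)         # Im of empirical CF
    normal_cf = torch.exp(-0.5 * t ** 2)         # CF of N(0, 1) is real
    sq_dist = (ecf_real - normal_cf) ** 2 + ecf_imag ** 2
    return torch.trapz(sq_dist, t)               # trapezoidal integration

torch.manual_seed(0)
gaussian = torch.randn(10_000)
heavy_tailed = torch.randn(10_000) ** 3          # clearly non-normal
```

Gaussian data should score near zero; non-Gaussian data scores higher.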
- class stable_pretraining.methods.lejepa.LeJEPA(encoder_name: str = 'vit_base_patch16_224', projector: Module | None = None, n_slices: int = 1024, t_max: float = 3.0, n_points: int = 17, lamb: float = 0.02, pretrained: bool = False, drop_path_rate: float = 0.1)[source]#
Bases: Module
LeJEPA: multi-view invariance + sliced Epps-Pulley SIGReg.
- Architecture:
Backbone: timm ViT (CLS-pooled, num_classes=0)
Projector: MLP projection head
Loss: invariance + (λ * SIGReg)
Centers are computed from global-view projections only. The invariance term penalises the MSE between each view’s projection and the center. The SIGReg term is a sliced goodness-of-fit test that pushes projected embeddings toward an isotropic Gaussian, averaged over views.
- Parameters:
encoder_name – timm model name (e.g., "vit_base_patch16_224")
projector – Optional projection head. When None, a 3-layer BN+ReLU MLP (embed_dim → 2048 → 2048 → 512) is created.
n_slices – Random projection directions for the goodness-of-fit test (default: 1024)
t_max – EP integration upper bound (default: 3.0)
n_points – EP quadrature nodes (default: 17)
lamb – SIGReg weight λ (default: 0.02)
pretrained – Load pretrained timm weights
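The loss composition described in the class docstring (centers from global views only, MSE invariance to the center, plus λ-weighted SIGReg averaged over views) can be sketched with stand-in projections. The helper and the toy sigreg_fn are hypothetical, not the library's API:

```python
import torch

def lejepa_loss_sketch(view_projections, num_global, sigreg_fn, lamb=0.02):
    """view_projections: list of [N, D] tensors, global views first."""
    # Centers come from global-view projections only
    center = torch.stack(view_projections[:num_global]).mean(dim=0)
    # Invariance: MSE between every view's projection and the center
    inv_loss = torch.stack([
        ((z - center) ** 2).mean() for z in view_projections
    ]).mean()
    # SIGReg: goodness-of-fit penalty, averaged over views
    sigreg_loss = torch.stack([sigreg_fn(z) for z in view_projections]).mean()
    return inv_loss + lamb * sigreg_loss, inv_loss, sigreg_loss

views = [torch.randn(8, 16) for _ in range(4)]
total, inv, sig = lejepa_loss_sketch(views, num_global=2,
                                     sigreg_fn=lambda z: z.mean() ** 2)
```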
Example:
model = LeJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)
model.train()
output = model(
    global_views=[images, images],
    all_views=[images, images, images, images],
)
output.loss.backward()
model.eval()
output = model(images=images)
features = output.embedding  # [4, 768]
Example with Lightning:
import lightning as pl
from stable_pretraining.methods.lejepa import LeJEPA

class LeJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = LeJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        views = [v["image"] for v in batch["views"]]
        output = self.model(global_views=views, all_views=views)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
- forward(global_views: list[Tensor] | None = None, local_views: list[Tensor] | None = None, images: Tensor | None = None) LeJEPAOutput[source]#
Compute the LeJEPA training loss or, in eval mode, plain embeddings.
- Parameters:
global_views – Batches of global views; centers are computed from these (training mode)
local_views – Batches of additional local views (training mode)
images – Image batch to embed directly (eval mode)
- Returns:
LeJEPAOutput with loss and embeddings
- class stable_pretraining.methods.lejepa.LeJEPAOutput(loss: Tensor = None, embedding: Tensor = None, inv_loss: Tensor = None, sigreg_loss: Tensor = None)[source]#
Bases: ModelOutput
Output from LeJEPA forward pass.
- Variables:
loss – Combined invariance + SIGReg loss (0 in eval mode).
embedding – Backbone embeddings [V*N, D] (train) or [N, D] (eval).
inv_loss – Invariance component.
sigreg_loss – Epps-Pulley goodness-of-fit component.
- class stable_pretraining.methods.lejepa.SlicedEppsPulley(num_slices: int = 1024, t_max: float = 3.0, n_points: int = 17)[source]#
Bases: Module
Sliced Epps-Pulley goodness-of-fit test for multivariate normality.
Projects data onto random 1-D directions and averages the univariate Epps-Pulley statistics. A synchronised step counter seeds the random projections so all DDP ranks sample identical directions.
- Parameters:
num_slices – Number of random 1-D projections.
t_max – EP integration upper bound.
n_points – EP quadrature nodes.
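The synchronised step counter mentioned above can be mimicked with a seeded generator, so every DDP rank draws identical projection directions for a given step. A sketch (hypothetical helpers, not the module's internals):

```python
import torch

def sample_slices(dim, num_slices, step):
    """Draw num_slices random unit directions, identical for a given step."""
    gen = torch.Generator().manual_seed(step)      # same seed on every rank
    directions = torch.randn(dim, num_slices, generator=gen)
    return directions / directions.norm(dim=0, keepdim=True)

def sliced_stat(x, step, univariate_stat, num_slices=16):
    """Project x [N, D] onto random 1-D directions, average univariate stats."""
    dirs = sample_slices(x.shape[1], num_slices, step)
    projected = x @ dirs                           # [N, num_slices]
    return torch.stack([univariate_stat(projected[:, i])
                        for i in range(projected.shape[1])]).mean()
```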
stable_pretraining.methods.mae module#
MAE: Masked Autoencoders Are Scalable Vision Learners.
Self-supervised learning via reconstructing masked patches from visible patches.
References
He et al. “Masked Autoencoders Are Scalable Vision Learners.” CVPR 2022. https://arxiv.org/abs/2111.06377
Example:
from stable_pretraining.methods import MAE
import lightning as pl
# Create model
model = MAE("vit_base_patch16_224", mask_ratio=0.75)
# Training
model.train()
output = model(images)
output.loss.backward()
# Get encoder for downstream
encoder = model.encoder
- class stable_pretraining.methods.mae.MAE(model_or_model_name: str | Module = 'vit_base_patch16_224', decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mask_ratio: float = 0.75, block_size: int = 1, norm_pix_loss: bool = True, loss_type: str = 'mse', pretrained: bool = False, masking: Module | None = None)[source]#
Bases: Module
MAE: Masked Autoencoders Are Scalable Vision Learners.
- Architecture:
Encoder: ViT processing only visible (unmasked) patches
Decoder: Lightweight transformer reconstructing masked patches
Target: Normalized pixel values of masked patches
- Parameters:
model_or_model_name – timm model name string or pre-instantiated nn.Module
decoder_embed_dim – Decoder hidden dimension (default: 512)
decoder_depth – Number of decoder blocks (default: 8)
decoder_num_heads – Decoder attention heads (default: 16)
mask_ratio – Fraction of patches to mask (default: 0.75)
block_size – Masking block size, 1=random (default: 1)
norm_pix_loss – Normalize target pixels per patch (default: True)
loss_type – Loss type for MAELoss (default: ‘mse’)
pretrained – Load pretrained encoder weights
masking – Custom masking module (e.g., MultiBlockMasking). When provided, overrides mask_ratio and block_size.
Example:
# Basic usage
model = MAE("vit_base_patch16_224", mask_ratio=0.75)
images = torch.randn(4, 3, 224, 224)
model.train()
output = model(images)
output.loss.backward()
model.eval()
output = model(images)  # Full reconstruction, zero loss
Example with Lightning:
class MAELightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MAE("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.model(images).loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)
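The norm_pix_loss=True target listed in the parameters standardizes each patch by its own pixel statistics before the reconstruction loss. A simplified sketch of that target (stand-in for the library's internals):

```python
import torch

def normalized_patch_target(patches, eps=1e-6):
    """patches: [B, N, patch_dim] raw pixel values per patch.

    With norm_pix_loss=True, each patch is standardized by its own
    mean and variance before being used as the reconstruction target.
    """
    mean = patches.mean(dim=-1, keepdim=True)
    var = patches.var(dim=-1, keepdim=True)
    return (patches - mean) / (var + eps).sqrt()

patches = torch.randn(2, 196, 768)   # 14x14 patches, 16*16*3 pixels each
target = normalized_patch_target(patches)
```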
- class stable_pretraining.methods.mae.MAEOutput(loss: Tensor = None, predictions: Tensor = None, mask: Tensor = None, num_masked: int = None, num_visible: int = None)[source]#
Bases: ModelOutput
Output from MAE forward pass.
- Variables:
loss – Reconstruction loss (MSE on masked patches)
predictions – Reconstructed patches [B, N, patch_dim]
mask – Binary mask where 1=masked, 0=visible [B, N]
num_masked – Number of masked patches
num_visible – Number of visible patches
stable_pretraining.methods.nepa module#
NEPA: Next-Embedding Predictive Autoregression.
- class stable_pretraining.methods.nepa.NEPA(img_size: int = 224, patch_size: int = 14, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_rope: bool = True, use_qk_norm: bool = True, use_swiglu: bool = True, layer_scale_init: float = 1e-05)[source]#
Bases: Module
NEPA: Next-Embedding Predictive Autoregression.
- Uses standard TransformerBlock with modern options enabled:
use_rope=True: 2D Rotary Position Embedding
use_qk_norm=True: Query-Key normalization
mlp_type='swiglu': Gated MLP activation
use_layer_scale=True: Residual scaling
Causal masking is applied via attn_mask during training.
- forward(images: Tensor) NEPAOutput[source]#
Autoregressive forward pass with causal masking.
- Parameters:
images – Input images [B, C, H, W]
- Returns:
NEPAOutput
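A minimal sketch of the additive causal attn_mask mentioned above, in the standard lower-triangular form (not NEPA's exact code):

```python
import torch

def causal_attn_mask(num_tokens, dtype=torch.float32):
    """Additive attention mask: 0 where attention is allowed (j <= i),
    -inf where a token would attend to a future position (j > i)."""
    mask = torch.full((num_tokens, num_tokens), float("-inf"), dtype=dtype)
    return torch.triu(mask, diagonal=1)   # zero on and below the diagonal

mask = causal_attn_mask(4)
# Usable as attn_mask in torch.nn.functional.scaled_dot_product_attention
```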
stable_pretraining.methods.salt module#
SALT: Static-teacher Asymmetric Latent Training.
SALT combines ideas from V-JEPA masking with MAE pixel reconstruction (Stage 1) and latent target prediction with a frozen teacher (Stage 2).
References
Li, Xianhang, et al. “Rethinking JEPA: Compute-Efficient Video SSL with Frozen Teachers.” 2025. https://arxiv.org/pdf/2509.24317
Example:
from stable_pretraining.methods import SALT, MAE
from stable_pretraining.backbone import MultiBlockMasking
# Stage 1: MAE with multi-block masking
stage1 = MAE("vit_tiny_patch16_224", masking=MultiBlockMasking())
# Stage 2: SALT from Stage 1 checkpoint
stage2 = SALT.from_checkpoint(
    "stage1.ckpt",
    encoder_name="vit_tiny_patch16_224",
    predictor_embed_dim=384,
    predictor_depth=12,
)
- class stable_pretraining.methods.salt.SALT(encoder_name: str = 'vit_tiny_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 12, predictor_num_heads: int = 16, num_targets: int = 4, context_scale: Tuple[float, float] = (0.85, 1.0), target_scale: Tuple[float, float] = (0.15, 0.2), context_aspect_ratio: Tuple[float, float] = (1.0, 1.0), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), teacher_state_dict: dict = None, pretrained: bool = False)[source]#
Bases: Module
SALT Stage 2: Static-teacher Asymmetric Latent Training.
- Architecture:
Teacher (frozen): Encodes full unmasked image via EvalOnly(MaskedEncoder)
Student (trainable): Encodes only context (visible) patches
Predictor: Lightweight transformer predicting teacher latents at target positions
- Parameters:
encoder_name – timm model name (e.g., “vit_tiny_patch16_224”)
predictor_embed_dim – Predictor hidden dimension (default: 384)
predictor_depth – Number of predictor blocks (default: 12)
predictor_num_heads – Number of predictor attention heads (default: 16)
num_targets – Number of target blocks for masking (default: 4)
context_scale – (min, max) scale for context block
target_scale – (min, max) scale for each target block
context_aspect_ratio – (min, max) aspect ratio for context block
target_aspect_ratio – (min, max) aspect ratio for target blocks
teacher_state_dict – Optional state dict to load into teacher encoder
pretrained – Load pretrained encoder weights
Example:
model = SALT("vit_tiny_patch16_224")
images = torch.randn(4, 3, 224, 224)
model.train()
output = model(images)
output.loss.backward()
model.eval()
output = model(images)
features = output.embedding  # [B, D]
- forward(images: Tensor) SALTOutput[source]#
Forward pass.
Training: teacher encodes full image, student encodes context only, predictor predicts teacher latents at target positions, L1 loss.
Eval: student encodes full image, returns CLS token embedding, zero loss.
- Parameters:
images – Input images [B, C, H, W]
- Returns:
SALTOutput
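The training objective described above (L1 between predictor outputs and frozen-teacher latents at target positions) can be sketched with stand-in tensors; the Linear teacher and shapes are hypothetical:

```python
import torch
import torch.nn.functional as F

teacher = torch.nn.Linear(16, 16)
for p in teacher.parameters():
    p.requires_grad_(False)              # teacher stays frozen in Stage 2

predictor_out = torch.randn(4, 36, 16, requires_grad=True)  # [B, N_tgt, D]
with torch.no_grad():
    teacher_latents = teacher(torch.randn(4, 36, 16))       # target latents

loss = F.l1_loss(predictor_out, teacher_latents)            # L1, as documented
loss.backward()                          # gradients reach only the student side
```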
- classmethod from_checkpoint(ckpt_path: str, encoder_name: str = 'vit_tiny_patch16_224', **kwargs) SALT[source]#
Create SALT Stage 2 from a Stage 1 (MAE/VPixel) checkpoint.
Loads the encoder weights from Stage 1 as the frozen teacher.
- Parameters:
ckpt_path – Path to Stage 1 checkpoint
encoder_name – timm model name matching Stage 1
kwargs – Additional arguments for SALT.__init__
- Returns:
SALT instance with teacher initialized from checkpoint
- class stable_pretraining.methods.salt.SALTOutput(loss: Tensor = None, embedding: Tensor = None, predictions: Tensor | None = None, targets: Tensor | None = None, num_targets: int = None, num_context: int = None)[source]#
Bases: ModelOutput
Output from SALT forward pass.
- Variables:
loss – Prediction loss (L1 between predicted and teacher latents, 0 in eval)
embedding – CLS token embedding [B, D]
predictions – Predicted representations [B, N_tgt, D] (or None in eval)
targets – Teacher target representations [B, N_tgt, D] (or None in eval)
num_targets – Number of target patches (0 in eval)
num_context – Number of context patches (all patches in eval)
Module contents#
- stable_pretraining.methods.IJEPA: re-export of stable_pretraining.methods.ijepa.IJEPA (documented above)
- stable_pretraining.methods.LeJEPA: re-export of stable_pretraining.methods.lejepa.LeJEPA (documented above)
- stable_pretraining.methods.MAE: re-export of stable_pretraining.methods.mae.MAE (documented above)
- stable_pretraining.methods.NEPA: re-export of stable_pretraining.methods.nepa.NEPA (documented above)
- stable_pretraining.methods.SALT: re-export of stable_pretraining.methods.salt.SALT (documented above)