stable_pretraining.methods package#

Submodules#

stable_pretraining.methods.ijepa module#

I-JEPA: Image-based Joint-Embedding Predictive Architecture.

Self-supervised learning via predicting target patch representations from context patch representations using a lightweight predictor.

References

Assran et al. “Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture.” CVPR 2023. https://arxiv.org/abs/2301.08243

Example:

from stable_pretraining.methods import IJEPA
from stable_pretraining.callbacks import TeacherStudentCallback
import lightning as pl

# Create model
model = IJEPA(
    model_or_model_name="vit_base_patch16_224",
    predictor_embed_dim=384,
    predictor_depth=6,
    num_targets=4,
)

# Training with PyTorch Lightning
trainer = pl.Trainer(
    max_epochs=300,
    callbacks=[TeacherStudentCallback()],  # Handles EMA updates
)
trainer.fit(model, dataloader)

# Get encoder for downstream tasks
encoder = model.encoder.student
class stable_pretraining.methods.ijepa.IJEPA(model_or_model_name: str | Module = 'vit_base_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 6, num_targets: int = 4, target_scale: Tuple[float, float] = (0.15, 0.2), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), context_scale: Tuple[float, float] = (0.85, 1.0), ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, pretrained: bool = False)[source]#

Bases: Module

I-JEPA: Image-based Joint-Embedding Predictive Architecture.

Architecture:
  • Context Encoder (student): Encodes visible/context patches

  • Target Encoder (teacher): EMA copy, encodes target patches

  • Predictor: Lightweight transformer predicting targets from context

The context encoder is wrapped with TeacherStudentWrapper, enabling automatic EMA updates via TeacherStudentCallback.
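The EMA decay is interpolated from ema_decay_start toward ema_decay_end over training; a minimal sketch of the per-parameter update the callback applies (ema_update is an illustrative helper, not the library's API):

```python
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, decay: float):
    """One EMA step: teacher <- decay * teacher + (1 - decay) * student."""
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(decay).add_(s, alpha=1.0 - decay)
```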

Parameters:
  • model_or_model_name – timm model name string or pre-instantiated nn.Module

  • predictor_embed_dim – Predictor hidden dimension (default: 384)

  • predictor_depth – Number of predictor blocks (default: 6)

  • num_targets – Number of target blocks to sample (default: 4)

  • target_scale – (min, max) fraction of patches per target block

  • target_aspect_ratio – (min, max) aspect ratio of target blocks

  • context_scale – (min, max) fraction of non-target patches as context

  • ema_decay_start – Initial EMA decay (default: 0.996)

  • ema_decay_end – Final EMA decay (default: 1.0)

  • pretrained – Load pretrained encoder weights
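How a (scale, aspect_ratio) pair maps to a rectangular block on the patch grid can be sketched as follows (sample_block is hypothetical; the library's sampler may differ in rounding and clamping):

```python
import math
import random

def sample_block(grid_h, grid_w, scale, aspect_ratio):
    """Sample (top, left, h, w) for a block covering a random grid fraction.

    scale: (min, max) fraction of total patches; aspect_ratio: (min, max) h/w.
    """
    s = random.uniform(*scale)
    ar = random.uniform(*aspect_ratio)
    n = s * grid_h * grid_w                           # desired number of patches
    h = max(1, min(grid_h, round(math.sqrt(n * ar))))
    w = max(1, min(grid_w, round(math.sqrt(n / ar))))
    top = random.randint(0, grid_h - h)
    left = random.randint(0, grid_w - w)
    return top, left, h, w
```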

Example:

# Basic usage
model = IJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

# Training mode: predicts masked targets
model.train()
output = model(images)
output.loss.backward()

# Eval mode: encodes all patches (no masking)
model.eval()
output = model(images)
features = output.predictions  # [B, N, D]

Example with Lightning:

import lightning as pl
from stable_pretraining.callbacks import TeacherStudentCallback


class IJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = IJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        output = self.model(images)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)


trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
trainer.fit(IJEPALightning(), dataloader)

Note

  • Use TeacherStudentCallback to handle EMA updates automatically

  • In eval mode, num_targets=0 and all patches are returned as context

  • Access trained encoder via model.encoder.student

forward(images: Tensor, embedding_source: str = 'teacher') IJEPAOutput[source]#

Forward pass.

In training mode:
  • Samples target blocks and context region via IJEPAMasking

  • Encodes context through student, targets through teacher (EMA)

  • Predicts target representations from context

  • Returns smooth L1 loss between predictions and targets

In eval mode:
  • No masking, all patches treated as context

  • Returns encoded features with zero loss

  • Always uses student encoder
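The training loss can be reproduced in isolation with torch.nn.functional.smooth_l1_loss; shapes here are illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative shapes: B=2 images, N_tgt=8 target patches, D=768 dims
predictions = torch.randn(2, 8, 768, requires_grad=True)
targets = torch.randn(2, 8, 768)

loss = F.smooth_l1_loss(predictions, targets)  # smooth L1 between pred/target
loss.backward()
```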

Parameters:
  • images – Input images [B, C, H, W]

  • embedding_source – Which encoder to use for the embedding output. "teacher" (default) or "student". Only affects training mode; eval mode always uses student.

Returns:

IJEPAOutput with loss and representations

class stable_pretraining.methods.ijepa.IJEPAOutput(loss: Tensor = None, embedding: Tensor = None, predictions: Tensor = None, targets: Tensor = None, num_targets: int = None, num_context: int = None)[source]#

Bases: ModelOutput

Output from IJEPA forward pass.

Variables:
  • loss – Prediction loss (0 in eval mode)

  • embedding – Patch embeddings [B, N, D] for downstream use

  • predictions – Predicted representations [B, N_tgt, D] (or context in eval)

  • targets – Target representations [B, N_tgt, D] (or context in eval)

  • num_targets – Number of target patches (0 in eval)

  • num_context – Number of context patches (all patches in eval)

embedding: Tensor = None#
loss: Tensor = None#
num_context: int = None#
num_targets: int = None#
predictions: Tensor = None#
targets: Tensor = None#

stable_pretraining.methods.lejepa module#

LeJEPA: Latent Embedding Joint-Embedding Predictive Architecture.

Self-supervised learning via multi-view invariance combined with a sliced goodness-of-fit test (SIGReg) that pushes embeddings toward an isotropic Gaussian.

References

Balestriero & LeCun. “LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics.” 2025. https://arxiv.org/abs/2511.08544

Example:

from stable_pretraining.methods.lejepa import LeJEPA

model = LeJEPA("vit_small_patch16_224")

global_images = [torch.randn(4, 3, 224, 224)] * 2
all_images = [torch.randn(4, 3, 224, 224)] * 6
model.train()
output = model(global_images, all_images)
output.loss.backward()

model.eval()
output = model(images=torch.randn(4, 3, 224, 224))
features = output.embedding  # [N, D]
class stable_pretraining.methods.lejepa.EppsPulley(t_max: float = 3.0, n_points: int = 17)[source]#

Bases: Module

Epps-Pulley goodness-of-fit test for univariate normality.

Projects data onto a grid of points and computes the Epps-Pulley statistic.

Parameters:
  • t_max – Integration upper bound.

  • n_points – Number of integration points.
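A minimal sketch of such a test, assuming the module discretizes the classical Epps-Pulley integral (squared distance between the empirical characteristic function and the standard-normal characteristic function exp(-t²/2)) on a fixed grid; the library's exact weighting and normalization may differ:

```python
import torch

def epps_pulley_sketch(x: torch.Tensor, t_max: float = 3.0,
                       n_points: int = 17) -> torch.Tensor:
    """Per-slice EP-style statistic for x of shape [N, S] (N samples, S slices)."""
    t = torch.linspace(-t_max, t_max, n_points)      # integration grid [T]
    tx = x.unsqueeze(-1) * t                         # [N, S, T]
    ecf_re = torch.cos(tx).mean(dim=0)               # empirical CF, real part
    ecf_im = torch.sin(tx).mean(dim=0)               # empirical CF, imag part
    target = torch.exp(-0.5 * t**2)                  # standard-normal CF (real)
    diff2 = (ecf_re - target) ** 2 + ecf_im**2
    return torch.trapz(diff2, t, dim=-1)             # [S]
```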

forward(x: Tensor) Tensor[source]#
Parameters:

x – Samples [N, S] (N samples, S slices).

Returns:

Per-slice statistic [S].

class stable_pretraining.methods.lejepa.LeJEPA(encoder_name: str = 'vit_base_patch16_224', projector: Module | None = None, n_slices: int = 1024, t_max: float = 3.0, n_points: int = 17, lamb: float = 0.02, pretrained: bool = False, drop_path_rate: float = 0.1)[source]#

Bases: Module

LeJEPA: multi-view invariance + sliced Epps-Pulley SIGReg.

Architecture:
  • Backbone: timm ViT (CLS-pooled, num_classes=0)

  • Projector: MLP projection head

  • Loss: invariance + λ · SIGReg

Centers are computed from global-view projections only. The invariance term penalises the MSE between each view’s projection and the center. The SIGReg term is a sliced goodness-of-fit test that pushes projected embeddings toward an isotropic Gaussian, averaged over views.
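The combination described above can be sketched as follows (lejepa_loss and sigreg_fn are illustrative; the library's reduction details may differ):

```python
import torch

def lejepa_loss(global_proj, all_proj, sigreg_fn, lamb=0.02):
    """Sketch: centers from global views only; invariance MSE + lamb * SIGReg.

    global_proj / all_proj: lists of [N, D] projections; sigreg_fn maps a
    [N, D] batch of embeddings to a scalar goodness-of-fit statistic.
    """
    center = torch.stack(global_proj).mean(dim=0)                        # [N, D]
    inv = torch.stack([((z - center) ** 2).mean() for z in all_proj]).mean()
    sigreg = torch.stack([sigreg_fn(z) for z in all_proj]).mean()
    return inv + lamb * sigreg, inv, sigreg
```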

Parameters:
  • encoder_name – timm model name (e.g., "vit_base_patch16_224")

  • projector – Optional projection head. When None, a 3-layer BN+ReLU MLP (embed_dim → 2048 → 2048 → 512) is created.

  • n_slices – Random projection directions for the goodness-of-fit test (default: 1024)

  • t_max – EP integration upper bound (default: 3.0)

  • n_points – EP quadrature nodes (default: 17)

  • lamb – SIGReg weight λ (default: 0.02)

  • pretrained – Load pretrained timm weights

Example:

model = LeJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

model.train()
output = model(
    global_views=[images, images],
    all_views=[images, images, images, images],
)
output.loss.backward()

model.eval()
output = model(images=images)
features = output.embedding  # [4, 768]

Example with Lightning:

import lightning as pl
from stable_pretraining.methods.lejepa import LeJEPA


class LeJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = LeJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        views = [v["image"] for v in batch["views"]]
        output = self.model(global_views=views, all_views=views)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
forward(global_views: list[Tensor] | None = None, local_views: list[Tensor] | None = None, images: Tensor | None = None) LeJEPAOutput[source]#

Forward pass.

Training mode: pass global_views (and optionally local_views); all views are encoded and projected, and the invariance + SIGReg loss is computed.

Eval mode: pass images; returns backbone embeddings with zero loss.

Parameters:
  • global_views – List of global-view batches [N, C, H, W] (training mode)

  • local_views – Optional list of additional view batches (training mode)

  • images – Single image batch [N, C, H, W] (eval mode)

Returns:

LeJEPAOutput with loss and embeddings

class stable_pretraining.methods.lejepa.LeJEPAOutput(loss: Tensor = None, embedding: Tensor = None, inv_loss: Tensor = None, sigreg_loss: Tensor = None)[source]#

Bases: ModelOutput

Output from LeJEPA forward pass.

Variables:
  • loss – Combined invariance + SIGReg loss (0 in eval mode).

  • embedding – Backbone embeddings [V*N, D] (train) or [N, D] (eval).

  • inv_loss – Invariance component.

  • sigreg_loss – Epps-Pulley goodness-of-fit component.

embedding: Tensor = None#
inv_loss: Tensor = None#
loss: Tensor = None#
sigreg_loss: Tensor = None#
class stable_pretraining.methods.lejepa.SlicedEppsPulley(num_slices: int = 1024, t_max: float = 3.0, n_points: int = 17)[source]#

Bases: Module

Sliced Epps-Pulley goodness-of-fit test for multivariate normality.

Projects data onto random 1-D directions and averages the univariate Epps-Pulley statistics. A synchronised step counter seeds the random projections so all DDP ranks sample identical directions.

Parameters:
  • num_slices – Number of random 1-D projections.

  • t_max – EP integration upper bound.

  • n_points – EP quadrature nodes.
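The synchronized seeding can be sketched as follows (sample_slices is a hypothetical helper):

```python
import torch

def sample_slices(dim: int, num_slices: int, step: int) -> torch.Tensor:
    """Unit projection directions [dim, num_slices], seeded by a shared step
    counter so every DDP rank draws identical directions (sketch)."""
    g = torch.Generator().manual_seed(step)
    dirs = torch.randn(dim, num_slices, generator=g)
    return dirs / dirs.norm(dim=0, keepdim=True)
```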

forward(x: Tensor) Tensor[source]#
Parameters:

x – Embeddings [N, D].

Returns:

Scalar mean EP statistic.

stable_pretraining.methods.mae module#

MAE: Masked Autoencoders Are Scalable Vision Learners.

Self-supervised learning via reconstructing masked patches from visible patches.

References

He et al. “Masked Autoencoders Are Scalable Vision Learners.” CVPR 2022. https://arxiv.org/abs/2111.06377

Example:

from stable_pretraining.methods import MAE
import lightning as pl

# Create model
model = MAE("vit_base_patch16_224", mask_ratio=0.75)

# Training
model.train()
output = model(images)
output.loss.backward()

# Get encoder for downstream
encoder = model.encoder
class stable_pretraining.methods.mae.MAE(model_or_model_name: str | Module = 'vit_base_patch16_224', decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mask_ratio: float = 0.75, block_size: int = 1, norm_pix_loss: bool = True, loss_type: str = 'mse', pretrained: bool = False, masking: Module | None = None)[source]#

Bases: Module

MAE: Masked Autoencoders Are Scalable Vision Learners.

Architecture:
  • Encoder: ViT processing only visible (unmasked) patches

  • Decoder: Lightweight transformer reconstructing masked patches

  • Target: Normalized pixel values of masked patches

Parameters:
  • model_or_model_name – timm model name string or pre-instantiated nn.Module

  • decoder_embed_dim – Decoder hidden dimension (default: 512)

  • decoder_depth – Number of decoder blocks (default: 8)

  • decoder_num_heads – Decoder attention heads (default: 16)

  • mask_ratio – Fraction of patches to mask (default: 0.75)

  • block_size – Masking block size, 1=random (default: 1)

  • norm_pix_loss – Normalize target pixels per patch (default: True)

  • loss_type – Loss type for MAELoss (default: ‘mse’)

  • pretrained – Load pretrained encoder weights

  • masking – Custom masking module (e.g., MultiBlockMasking). When provided, overrides mask_ratio and block_size.
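With norm_pix_loss=True, the reconstruction target is each patch's pixels standardized per patch, following the original MAE recipe; a sketch (patchify and norm_pix_target are illustrative helpers):

```python
import torch

def patchify(images: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """[B, C, H, W] -> [B, N, patch_size * patch_size * C] pixel patches."""
    B, C, H, W = images.shape
    p = patch_size
    x = images.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)                  # [B, Hp, Wp, p, p, C]
    return x.reshape(B, (H // p) * (W // p), p * p * C)

def norm_pix_target(patches: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standardize each patch to zero mean / unit variance."""
    mean = patches.mean(dim=-1, keepdim=True)
    var = patches.var(dim=-1, keepdim=True)
    return (patches - mean) / (var + eps).sqrt()
```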

Example:

# Basic usage
model = MAE("vit_base_patch16_224", mask_ratio=0.75)
images = torch.randn(4, 3, 224, 224)

model.train()
output = model(images)
output.loss.backward()

model.eval()
output = model(images)  # Full reconstruction, zero loss

Example with Lightning:

class MAELightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MAE("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.model(images).loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)
forward(images: Tensor) MAEOutput[source]#

Forward pass.

Training: masks patches, encodes visible, decodes all, loss on masked. Eval: no masking, full encode/decode, zero loss.

Parameters:

images – Input images [B, C, H, W]

Returns:

MAEOutput with loss and reconstructions
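The "loss on masked" reduction can be sketched as follows, assuming mask follows the MAEOutput convention (1 = masked, 0 = visible):

```python
import torch

torch.manual_seed(0)
# Illustrative shapes: B=2 images, N=196 patches, patch_dim=768
pred = torch.randn(2, 196, 768)
target = torch.randn(2, 196, 768)
mask = (torch.rand(2, 196) < 0.75).float()        # 1 = masked, 0 = visible

per_patch = ((pred - target) ** 2).mean(dim=-1)   # [B, N] MSE per patch
loss = (per_patch * mask).sum() / mask.sum()      # mean over masked patches only
```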

class stable_pretraining.methods.mae.MAEOutput(loss: Tensor = None, predictions: Tensor = None, mask: Tensor = None, num_masked: int = None, num_visible: int = None)[source]#

Bases: ModelOutput

Output from MAE forward pass.

Variables:
  • loss – Reconstruction loss (MSE on masked patches)

  • predictions – Reconstructed patches [B, N, patch_dim]

  • mask – Binary mask where 1=masked, 0=visible [B, N]

  • num_masked – Number of masked patches

  • num_visible – Number of visible patches

loss: Tensor = None#
mask: Tensor = None#
num_masked: int = None#
num_visible: int = None#
predictions: Tensor = None#

stable_pretraining.methods.nepa module#

NEPA: Next-Embedding Predictive Autoregression.

class stable_pretraining.methods.nepa.NEPA(img_size: int = 224, patch_size: int = 14, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_rope: bool = True, use_qk_norm: bool = True, use_swiglu: bool = True, layer_scale_init: float = 1e-05)[source]#

Bases: Module

NEPA: Next-Embedding Predictive Autoregression.

Uses standard TransformerBlock with modern options enabled:
  • use_rope=True: 2D Rotary Position Embedding

  • use_qk_norm=True: Query-Key normalization

  • mlp_type='swiglu': Gated MLP activation

  • use_layer_scale=True: Residual scaling

Causal masking is applied via attn_mask during training.
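A causal mask over N patch tokens can be sketched as follows (the True = disallowed convention shown matches torch.nn.MultiheadAttention's attn_mask; other APIs invert it):

```python
import torch

# Causal attention mask over N patch tokens: token i may attend only to
# tokens j <= i. True marks disallowed pairs.
N = 6
attn_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
```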

forward(images: Tensor) NEPAOutput[source]#

Forward pass.

Training: applies a causal attn_mask over patch tokens and computes the next-embedding prediction loss.

Parameters:

images – Input images [B, C, H, W]

Returns:

NEPAOutput with loss, embeddings, and patch grid size

forward_features(images: Tensor, causal: bool = False) Tensor[source]#
freeze_patch_embed()[source]#
get_classifier_features(images: Tensor) Tensor[source]#
get_dense_features(images: Tensor) Tensor[source]#
class stable_pretraining.methods.nepa.NEPAOutput(loss: Tensor = None, embeddings: Tensor = None, grid_size: Tuple[int, int] = None)[source]#

Bases: ModelOutput

Output from NEPA forward pass.

Variables:
  • loss – Next-embedding prediction loss

  • embeddings – Patch embeddings [B, N, D]

  • grid_size – Patch grid size (height, width)

embeddings: Tensor = None#
grid_size: Tuple[int, int] = None#
loss: Tensor = None#
stable_pretraining.methods.nepa.nepa_base_patch14(**kwargs) NEPA[source]#
stable_pretraining.methods.nepa.nepa_large_patch14(**kwargs) NEPA[source]#

stable_pretraining.methods.salt module#

SALT: Static-teacher Asymmetric Latent Training.

SALT combines V-JEPA-style multi-block masking with MAE pixel reconstruction (Stage 1), followed by latent target prediction against a frozen teacher (Stage 2).

References

Li, Xianhang, et al. “Rethinking JEPA: Compute-Efficient Video SSL with Frozen Teachers.” 2025. https://arxiv.org/pdf/2509.24317

Example:

from stable_pretraining.methods import SALT, MAE
from stable_pretraining.backbone import MultiBlockMasking

# Stage 1: MAE with multi-block masking
stage1 = MAE("vit_tiny_patch16_224", masking=MultiBlockMasking())

# Stage 2: SALT from Stage 1 checkpoint
stage2 = SALT.from_checkpoint(
    "stage1.ckpt",
    encoder_name="vit_tiny_patch16_224",
    predictor_embed_dim=384,
    predictor_depth=12,
)

class stable_pretraining.methods.salt.SALT(encoder_name: str = 'vit_tiny_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 12, predictor_num_heads: int = 16, num_targets: int = 4, context_scale: Tuple[float, float] = (0.85, 1.0), target_scale: Tuple[float, float] = (0.15, 0.2), context_aspect_ratio: Tuple[float, float] = (1.0, 1.0), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), teacher_state_dict: dict = None, pretrained: bool = False)[source]#

Bases: Module

SALT Stage 2: Static-teacher Asymmetric Latent Training.

Architecture:
  • Teacher (frozen): Encodes full unmasked image via EvalOnly(MaskedEncoder)

  • Student (trainable): Encodes only context (visible) patches

  • Predictor: Lightweight transformer predicting teacher latents at target positions
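Freezing the teacher amounts to disabling gradients and pinning eval mode; a minimal sketch of what an EvalOnly-style wrapper enforces (freeze is an illustrative helper):

```python
import torch

def freeze(module: torch.nn.Module) -> torch.nn.Module:
    """Disable gradients and put the module in eval mode (frozen teacher)."""
    for p in module.parameters():
        p.requires_grad_(False)
    return module.eval()
```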

Parameters:
  • encoder_name – timm model name (e.g., “vit_tiny_patch16_224”)

  • predictor_embed_dim – Predictor hidden dimension (default: 384)

  • predictor_depth – Number of predictor blocks (default: 12)

  • predictor_num_heads – Number of predictor attention heads (default: 16)

  • num_targets – Number of target blocks for masking (default: 4)

  • context_scale – (min, max) scale for context block

  • target_scale – (min, max) scale for each target block

  • context_aspect_ratio – (min, max) aspect ratio for context block

  • target_aspect_ratio – (min, max) aspect ratio for target blocks

  • teacher_state_dict – Optional state dict to load into teacher encoder

  • pretrained – Load pretrained encoder weights

Example:

model = SALT("vit_tiny_patch16_224")
images = torch.randn(4, 3, 224, 224)

model.train()
output = model(images)
output.loss.backward()

model.eval()
output = model(images)
features = output.embedding  # [B, D]
forward(images: Tensor) SALTOutput[source]#

Forward pass.

Training: teacher encodes full image, student encodes context only, predictor predicts teacher latents at target positions, L1 loss.

Eval: student encodes full image, returns CLS token embedding, zero loss.

Parameters:

images – Input images [B, C, H, W]

Returns:

SALTOutput

classmethod from_checkpoint(ckpt_path: str, encoder_name: str = 'vit_tiny_patch16_224', **kwargs) SALT[source]#

Create SALT Stage 2 from a Stage 1 (MAE/VPixel) checkpoint.

Loads the encoder weights from Stage 1 as the frozen teacher.
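Extracting the Stage 1 encoder weights can be sketched as follows (extract_encoder_state and the "encoder." key prefix are assumptions about the checkpoint layout, not the library's API):

```python
import torch

def extract_encoder_state(ckpt: dict, prefix: str = "encoder.") -> dict:
    """Pull encoder weights out of a Lightning-style checkpoint dict (sketch)."""
    state = ckpt.get("state_dict", ckpt)   # Lightning nests under "state_dict"
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
```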

Parameters:
  • ckpt_path – Path to Stage 1 checkpoint

  • encoder_name – timm model name matching Stage 1

  • kwargs – Additional arguments for SALT.__init__

Returns:

SALT instance with teacher initialized from checkpoint

class stable_pretraining.methods.salt.SALTOutput(loss: Tensor = None, embedding: Tensor = None, predictions: Tensor | None = None, targets: Tensor | None = None, num_targets: int = None, num_context: int = None)[source]#

Bases: ModelOutput

Output from SALT forward pass.

Variables:
  • loss – Prediction loss (L1 between predicted and teacher latents, 0 in eval)

  • embedding – CLS token embedding [B, D]

  • predictions – Predicted representations [B, N_tgt, D] (or None in eval)

  • targets – Teacher target representations [B, N_tgt, D] (or None in eval)

  • num_targets – Number of target patches (0 in eval)

  • num_context – Number of context patches (all patches in eval)

embedding: Tensor = None#
loss: Tensor = None#
num_context: int = None#
num_targets: int = None#
predictions: Tensor | None = None#
targets: Tensor | None = None#

Module contents#

class stable_pretraining.methods.IJEPA(model_or_model_name: str | Module = 'vit_base_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 6, num_targets: int = 4, target_scale: Tuple[float, float] = (0.15, 0.2), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), context_scale: Tuple[float, float] = (0.85, 1.0), ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, pretrained: bool = False)[source]#

Bases: Module

I-JEPA: Image-based Joint-Embedding Predictive Architecture.

Architecture:
  • Context Encoder (student): Encodes visible/context patches

  • Target Encoder (teacher): EMA copy, encodes target patches

  • Predictor: Lightweight transformer predicting targets from context

The context encoder is wrapped with TeacherStudentWrapper, enabling automatic EMA updates via TeacherStudentCallback.

Parameters:
  • model_or_model_name – timm model name string or pre-instantiated nn.Module

  • predictor_embed_dim – Predictor hidden dimension (default: 384)

  • predictor_depth – Number of predictor blocks (default: 6)

  • num_targets – Number of target blocks to sample (default: 4)

  • target_scale – (min, max) fraction of patches per target block

  • target_aspect_ratio – (min, max) aspect ratio of target blocks

  • context_scale – (min, max) fraction of non-target patches as context

  • ema_decay_start – Initial EMA decay (default: 0.996)

  • ema_decay_end – Final EMA decay (default: 1.0)

  • pretrained – Load pretrained encoder weights

Example:

# Basic usage
model = IJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

# Training mode: predicts masked targets
model.train()
output = model(images)
output.loss.backward()

# Eval mode: encodes all patches (no masking)
model.eval()
output = model(images)
features = output.predictions  # [B, N, D]

Example with Lightning:

import lightning as pl
from stable_pretraining.callbacks import TeacherStudentCallback


class IJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = IJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        output = self.model(images)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)


trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
trainer.fit(IJEPALightning(), dataloader)

Note

  • Use TeacherStudentCallback to handle EMA updates automatically

  • In eval mode, num_targets=0 and all patches are returned as context

  • Access trained encoder via model.encoder.student

forward(images: Tensor, embedding_source: str = 'teacher') IJEPAOutput[source]#

Forward pass.

In training mode:
  • Samples target blocks and context region via IJEPAMasking

  • Encodes context through student, targets through teacher (EMA)

  • Predicts target representations from context

  • Returns smooth L1 loss between predictions and targets

In eval mode:
  • No masking, all patches treated as context

  • Returns encoded features with zero loss

  • Always uses student encoder

Parameters:
  • images – Input images [B, C, H, W]

  • embedding_source – Which encoder to use for the embedding output. "teacher" (default) or "student". Only affects training mode; eval mode always uses student.

Returns:

IJEPAOutput with loss and representations

class stable_pretraining.methods.LeJEPA(encoder_name: str = 'vit_base_patch16_224', projector: Module | None = None, n_slices: int = 1024, t_max: float = 3.0, n_points: int = 17, lamb: float = 0.02, pretrained: bool = False, drop_path_rate: float = 0.1)[source]#

Bases: Module

LeJEPA: multi-view invariance + sliced Epps-Pulley SIGReg.

Architecture:
  • Backbone: timm ViT (CLS-pooled, num_classes=0)

  • Projector: MLP projection head

  • Loss: invariance + λ · SIGReg

Centers are computed from global-view projections only. The invariance term penalises the MSE between each view’s projection and the center. The SIGReg term is a sliced goodness-of-fit test that pushes projected embeddings toward an isotropic Gaussian, averaged over views.

Parameters:
  • encoder_name – timm model name (e.g., "vit_base_patch16_224")

  • projector – Optional projection head. When None, a 3-layer BN+ReLU MLP (embed_dim → 2048 → 2048 → 512) is created.

  • n_slices – Random projection directions for the goodness-of-fit test (default: 1024)

  • t_max – EP integration upper bound (default: 3.0)

  • n_points – EP quadrature nodes (default: 17)

  • lamb – SIGReg weight λ (default: 0.02)

  • pretrained – Load pretrained timm weights

Example:

model = LeJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

model.train()
output = model(
    global_views=[images, images],
    all_views=[images, images, images, images],
)
output.loss.backward()

model.eval()
output = model(images=images)
features = output.embedding  # [4, 768]

Example with Lightning:

import lightning as pl
from stable_pretraining.methods.lejepa import LeJEPA


class LeJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = LeJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        views = [v["image"] for v in batch["views"]]
        output = self.model(global_views=views, all_views=views)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
forward(global_views: list[Tensor] | None = None, local_views: list[Tensor] | None = None, images: Tensor | None = None) LeJEPAOutput[source]#

Forward pass.

Training mode: pass global_views (and optionally local_views); all views are encoded and projected, and the invariance + SIGReg loss is computed.

Eval mode: pass images; returns backbone embeddings with zero loss.

Parameters:
  • global_views – List of global-view batches [N, C, H, W] (training mode)

  • local_views – Optional list of additional view batches (training mode)

  • images – Single image batch [N, C, H, W] (eval mode)

Returns:

LeJEPAOutput with loss and embeddings

class stable_pretraining.methods.MAE(model_or_model_name: str | Module = 'vit_base_patch16_224', decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mask_ratio: float = 0.75, block_size: int = 1, norm_pix_loss: bool = True, loss_type: str = 'mse', pretrained: bool = False, masking: Module | None = None)[source]#

Bases: Module

MAE: Masked Autoencoders Are Scalable Vision Learners.

Architecture:
  • Encoder: ViT processing only visible (unmasked) patches

  • Decoder: Lightweight transformer reconstructing masked patches

  • Target: Normalized pixel values of masked patches

Parameters:
  • model_or_model_name – timm model name string or pre-instantiated nn.Module

  • decoder_embed_dim – Decoder hidden dimension (default: 512)

  • decoder_depth – Number of decoder blocks (default: 8)

  • decoder_num_heads – Decoder attention heads (default: 16)

  • mask_ratio – Fraction of patches to mask (default: 0.75)

  • block_size – Masking block size, 1=random (default: 1)

  • norm_pix_loss – Normalize target pixels per patch (default: True)

  • loss_type – Loss type for MAELoss (default: ‘mse’)

  • pretrained – Load pretrained encoder weights

  • masking – Custom masking module (e.g., MultiBlockMasking). When provided, overrides mask_ratio and block_size.

Example:

# Basic usage
model = MAE("vit_base_patch16_224", mask_ratio=0.75)
images = torch.randn(4, 3, 224, 224)

model.train()
output = model(images)
output.loss.backward()

model.eval()
output = model(images)  # Full reconstruction, zero loss

Example with Lightning:

class MAELightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MAE("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.model(images).loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)
forward(images: Tensor) MAEOutput[source]#

Forward pass.

Training: masks patches, encodes visible, decodes all, loss on masked. Eval: no masking, full encode/decode, zero loss.

Parameters:

images – Input images [B, C, H, W]

Returns:

MAEOutput with loss and reconstructions

class stable_pretraining.methods.NEPA(img_size: int = 224, patch_size: int = 14, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_rope: bool = True, use_qk_norm: bool = True, use_swiglu: bool = True, layer_scale_init: float = 1e-05)[source]#

Bases: Module

NEPA: Next-Embedding Predictive Autoregression.

Uses standard TransformerBlock with modern options enabled:
  • use_rope=True: 2D Rotary Position Embedding

  • use_qk_norm=True: Query-Key normalization

  • mlp_type='swiglu': Gated MLP activation

  • use_layer_scale=True: Residual scaling

Causal masking is applied via attn_mask during training.

forward(images: Tensor) NEPAOutput[source]#

Forward pass.

Training: applies a causal attn_mask over patch tokens and computes the next-embedding prediction loss.

Parameters:

images – Input images [B, C, H, W]

Returns:

NEPAOutput with loss, embeddings, and patch grid size

forward_features(images: Tensor, causal: bool = False) Tensor[source]#
freeze_patch_embed()[source]#
get_classifier_features(images: Tensor) Tensor[source]#
get_dense_features(images: Tensor) Tensor[source]#
class stable_pretraining.methods.SALT(encoder_name: str = 'vit_tiny_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 12, predictor_num_heads: int = 16, num_targets: int = 4, context_scale: Tuple[float, float] = (0.85, 1.0), target_scale: Tuple[float, float] = (0.15, 0.2), context_aspect_ratio: Tuple[float, float] = (1.0, 1.0), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), teacher_state_dict: dict = None, pretrained: bool = False)[source]#

Bases: Module

SALT Stage 2: Static-teacher Asymmetric Latent Training.

Architecture:
  • Teacher (frozen): Encodes full unmasked image via EvalOnly(MaskedEncoder)

  • Student (trainable): Encodes only context (visible) patches

  • Predictor: Lightweight transformer predicting teacher latents at target positions

Parameters:
  • encoder_name – timm model name (e.g., “vit_tiny_patch16_224”)

  • predictor_embed_dim – Predictor hidden dimension (default: 384)

  • predictor_depth – Number of predictor blocks (default: 12)

  • predictor_num_heads – Number of predictor attention heads (default: 16)

  • num_targets – Number of target blocks for masking (default: 4)

  • context_scale – (min, max) scale for context block

  • target_scale – (min, max) scale for each target block

  • context_aspect_ratio – (min, max) aspect ratio for context block

  • target_aspect_ratio – (min, max) aspect ratio for target blocks

  • teacher_state_dict – Optional state dict to load into teacher encoder

  • pretrained – Load pretrained encoder weights

Example:

model = SALT("vit_tiny_patch16_224")
images = torch.randn(4, 3, 224, 224)

model.train()
output = model(images)
output.loss.backward()

model.eval()
output = model(images)
features = output.embedding  # [B, D]
forward(images: Tensor) SALTOutput[source]#

Forward pass.

Training: teacher encodes full image, student encodes context only, predictor predicts teacher latents at target positions, L1 loss.

Eval: student encodes full image, returns CLS token embedding, zero loss.

Parameters:

images – Input images [B, C, H, W]

Returns:

SALTOutput

classmethod from_checkpoint(ckpt_path: str, encoder_name: str = 'vit_tiny_patch16_224', **kwargs) SALT[source]#

Create SALT Stage 2 from a Stage 1 (MAE/VPixel) checkpoint.

Loads the encoder weights from Stage 1 as the frozen teacher.

Parameters:
  • ckpt_path – Path to Stage 1 checkpoint

  • encoder_name – timm model name matching Stage 1

  • kwargs – Additional arguments for SALT.__init__

Returns:

SALT instance with teacher initialized from checkpoint