LeJEPA
- class stable_pretraining.methods.LeJEPA(encoder_name: str = 'vit_base_patch16_224', projector: Module | None = None, n_slices: int = 1024, t_max: float = 3.0, n_points: int = 17, lamb: float = 0.02, pretrained: bool = False, drop_path_rate: float = 0.1)
Bases: Module
LeJEPA: multi-view invariance + sliced Epps-Pulley SIGReg.
- Architecture:
Backbone: timm ViT (CLS-pooled, num_classes=0)
Projector: MLP projection head
Loss:
invariance + (λ * SIGReg)
Centers are computed from the global-view projections only. The invariance term penalises the MSE between each view's projection and the center. The SIGReg term, averaged over views, is a sliced goodness-of-fit test that pushes the projected embeddings toward an isotropic Gaussian.
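The loss described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the library's implementation: the function name `lejepa_loss` and the `sigreg` callable are hypothetical, and the sketch assumes that the center is the per-sample mean of the global-view projections and that the invariance MSE is averaged over all views (global and local).

```python
from typing import Callable, Sequence

import numpy as np


def lejepa_loss(
    global_proj: Sequence[np.ndarray],
    local_proj: Sequence[np.ndarray],
    lamb: float,
    sigreg: Callable[[np.ndarray], float],  # hypothetical per-view statistic
) -> float:
    """Invariance + lamb * SIGReg; each array is a (batch, dim) view projection."""
    # Center: assumed to be the per-sample mean over global views only.
    center = np.mean(np.stack(global_proj), axis=0)  # (batch, dim)
    all_proj = list(global_proj) + list(local_proj)
    # Invariance: MSE between each view's projection and the center,
    # averaged over every view (global and local).
    invariance = float(np.mean([np.mean((z - center) ** 2) for z in all_proj]))
    # SIGReg: goodness-of-fit statistic computed per view, averaged over views.
    reg = float(np.mean([sigreg(z) for z in all_proj]))
    return invariance + lamb * reg
```

Passing the statistic as a callable keeps the invariance part independent of how SIGReg is estimated; any per-view goodness-of-fit function with the same signature can be plugged in.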
- Parameters:
encoder_name – timm model name (e.g., "vit_base_patch16_224")
projector – Optional projection head. When None, a 3-layer BN+ReLU MLP (embed_dim → 2048 → 2048 → 512) is created.
n_slices – Number of random projection directions (slices) for the goodness-of-fit test (default: 1024)
t_max – Upper bound of the Epps-Pulley (EP) integration over t (default: 3.0)
n_points – Number of EP quadrature nodes (default: 17)
lamb – SIGReg weight λ (default: 0.02)
pretrained – Load pretrained timm weights (default: False)
drop_path_rate – Stochastic depth (drop path) rate for the timm backbone (default: 0.1)
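To make n_slices, t_max and n_points concrete, here is a minimal NumPy sketch of a sliced Epps-Pulley statistic. It is an illustration under stated assumptions rather than the library's code: it draws unit-norm directions from a hypothetical `seed` argument, integrates over a uniform grid on [-t_max, t_max] with the trapezoidal rule, weights the squared gap between the empirical characteristic function and that of N(0, 1) by exp(-t²/2), scales by the batch size, and averages over slices. The actual implementation may resample directions every step, integrate over [0, t_max] using symmetry, or normalise differently.

```python
import numpy as np


def sliced_epps_pulley(
    z: np.ndarray,
    n_slices: int = 1024,
    t_max: float = 3.0,
    n_points: int = 17,
    seed: int = 0,  # hypothetical: fixes the random directions for reproducibility
) -> float:
    """Sliced Epps-Pulley statistic of a (batch, dim) array against N(0, I)."""
    n, d = z.shape
    rng = np.random.default_rng(seed)
    # Random unit-norm projection directions, one per column: (dim, n_slices).
    a = rng.standard_normal((d, n_slices))
    a /= np.linalg.norm(a, axis=0, keepdims=True)
    x = z @ a  # 1-D projections of every sample: (batch, n_slices)

    # Quadrature nodes, assumed to be a uniform grid on [-t_max, t_max].
    t = np.linspace(-t_max, t_max, n_points)
    # Empirical characteristic function of each slice at each node.
    ecf = np.exp(1j * x[:, :, None] * t).mean(axis=0)  # (n_slices, n_points)
    # Characteristic function of N(0, 1), reused here as the weight function.
    gauss_cf = np.exp(-0.5 * t**2)
    integrand = np.abs(ecf - gauss_cf) ** 2 * gauss_cf

    # Trapezoidal rule on the uniform grid, written out explicitly.
    dt = t[1] - t[0]
    integral = dt * (integrand.sum(axis=1) - 0.5 * (integrand[:, 0] + integrand[:, -1]))
    # Scale by the sample size, then average over slices.
    return float((n * integral).mean())
```

Because a batch drawn from an isotropic standard Gaussian projects to N(0, 1) along every unit direction, its statistic stays small, while a rescaled or shifted batch yields a much larger value; that gap is what the SIGReg term penalises.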
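Example:
```python
import torch
from stable_pretraining.methods import LeJEPA

model = LeJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

# Training: two global views plus two local views (here all the same batch).
model.train()
output = model(
    global_views=[images, images],
    local_views=[images, images],
)
output.loss.backward()

# Inference: pass a single batch to get backbone embeddings.
model.eval()
output = model(images=images)
features = output.embedding  # [4, 768]
```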
Example with Lightning:
```python
import lightning as pl
import torch
from stable_pretraining.methods import LeJEPA


class LeJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = LeJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        # batch["views"] holds one dict per view, each with an "image" tensor.
        views = [v["image"] for v in batch["views"]]
        output = self.model(global_views=views)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
```
- forward(global_views: list[Tensor] | None = None, local_views: list[Tensor] | None = None, images: Tensor | None = None) → LeJEPAOutput
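Compute the LeJEPA loss from multiple views (training) or embeddings from a single batch (inference).
- Parameters:
global_views – Global-view image batches; the invariance center is computed from their projections only.
local_views – Optional additional view batches; they enter the invariance and SIGReg terms but not the center.
images – A single image batch for inference, e.g. model(images=images) in eval mode.
- Returns:
LeJEPAOutput. When training on views, output.loss holds the combined objective; when called with images in eval mode, output.embedding holds the backbone features (e.g. [4, 768] for a batch of 4 with vit_base_patch16_224).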