LeJEPA


class stable_pretraining.methods.LeJEPA(encoder_name: str = 'vit_base_patch16_224', projector: Module | None = None, n_slices: int = 1024, t_max: float = 3.0, n_points: int = 17, lamb: float = 0.02, pretrained: bool = False, drop_path_rate: float = 0.1)

Bases: Module

LeJEPA: multi-view invariance + sliced Epps-Pulley SIGReg.

Architecture:
  • Backbone: timm ViT (CLS-pooled, num_classes=0)

  • Projector: MLP projection head

  • Loss: invariance + λ · SIGReg

Centers are computed from global-view projections only. The invariance term penalises the MSE between each view’s projection and the center. The SIGReg term is a sliced goodness-of-fit test that pushes projected embeddings toward an isotropic Gaussian, averaged over views.
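The invariance term can be sketched in NumPy as follows. This is a minimal illustration that assumes the center is the per-sample mean of the global-view projections and that the MSE is averaged over views, samples, and dimensions. The function name invariance_loss is hypothetical and is not part of stable_pretraining.

```python
import numpy as np


def invariance_loss(global_proj: list[np.ndarray], all_proj: list[np.ndarray]) -> float:
    """Hypothetical sketch: MSE between every view's projection and the global-view center.

    global_proj: list of [batch, dim] projections of the global views only.
    all_proj:    list of [batch, dim] projections of every view (global and local).
    """
    # Center: assumed to be the per-sample mean over the global views only.
    center = np.mean(np.stack(global_proj), axis=0)  # [batch, dim]
    # Squared error of each view against the shared center, averaged over
    # samples and dimensions, then averaged over views.
    errors = [np.mean((z - center) ** 2) for z in all_proj]
    return float(np.mean(errors))
```

Because the center comes only from global views, local views are pulled toward a target they do not influence, which is what makes them useful as harder, partial inputs.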

Parameters:
  • encoder_name – timm model name (e.g., "vit_base_patch16_224")

  • projector – Optional projection head. When None, a 3-layer BN+ReLU MLP (embed_dim → 2048 → 2048 → 512) is created.

  • n_slices – Number of random projection directions (slices) for the goodness-of-fit test (default: 1024)

  • t_max – Upper bound of the Epps-Pulley (EP) integration range (default: 3.0)

  • n_points – Number of quadrature nodes for the EP integral (default: 17)

  • lamb – SIGReg weight λ (default: 0.02)

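  • pretrained – Load pretrained timm weights (default: False)

  • drop_path_rate – Stochastic depth (drop path) rate of the timm backbone (default: 0.1)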
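The sliced goodness-of-fit test controlled by n_slices, t_max, and n_points can be sketched in NumPy. Embeddings are projected onto random unit directions, and the empirical characteristic function of each 1-D projection is compared with that of N(0, 1) on n_points nodes in [0, t_max]. The Gaussian weight exp(-t²/2), the trapezoidal rule, and the scaling by the batch size follow the standard Epps-Pulley statistic and are assumptions about this library's implementation. The names epps_pulley and sliced_sigreg are hypothetical, not the library's API.

```python
import numpy as np


def epps_pulley(x: np.ndarray, t_max: float = 3.0, n_points: int = 17) -> np.ndarray:
    """Hypothetical sketch: Epps-Pulley statistic of each column of x ([n, k]) against N(0, 1)."""
    n = x.shape[0]
    t = np.linspace(0.0, t_max, n_points)        # quadrature nodes on [0, t_max]
    gauss_cf = np.exp(-0.5 * t**2)               # CF of N(0, 1), reused as the weight
    xt = x[:, :, None] * t                       # [n, k, n_points]
    ecf_re = np.cos(xt).mean(axis=0)             # real part of the empirical CF, [k, n_points]
    ecf_im = np.sin(xt).mean(axis=0)             # imaginary part of the empirical CF
    # Weighted squared distance between the empirical CF and the Gaussian CF.
    err = ((ecf_re - gauss_cf) ** 2 + ecf_im**2) * gauss_cf
    # Trapezoidal rule on [0, t_max]; doubled below because the integrand is even in t.
    half = np.sum((err[:, 1:] + err[:, :-1]) * np.diff(t), axis=-1) / 2.0
    return n * 2.0 * half                        # one statistic per column, [k]


def sliced_sigreg(
    z: np.ndarray,
    n_slices: int = 1024,
    t_max: float = 3.0,
    n_points: int = 17,
    rng=None,
) -> float:
    """Hypothetical sketch: mean EP statistic over random unit-norm projections of z ([n, d])."""
    rng = np.random.default_rng() if rng is None else rng
    directions = rng.standard_normal((z.shape[1], n_slices))
    directions /= np.linalg.norm(directions, axis=0, keepdims=True)  # unit-norm slices
    projected = z @ directions                   # [n, n_slices]
    return float(epps_pulley(projected, t_max, n_points).mean())
```

Projecting an isotropic Gaussian onto any unit direction yields N(0, 1), so every slice's statistic stays small for well-spread embeddings. A collapsed embedding (for example, all zeros) produces a large value, which the λ-weighted term then penalises.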

Example:

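import torch
from stable_pretraining.methods import LeJEPA

model = LeJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

# Training: two global views and two local views
model.train()
output = model(
    global_views=[images, images],
    local_views=[images, images],
)
output.loss.backward()

# Inference: a single batch of images
model.eval()
with torch.no_grad():
    output = model(images=images)
features = output.embedding  # [4, 768]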

Example with Lightning:

import lightning as pl
import torch
from stable_pretraining.methods import LeJEPA


class LeJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = LeJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        views = [v["image"] for v in batch["views"]]
        output = self.model(global_views=views)  # every view is treated as a global view
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)

forward(global_views: list[Tensor] | None = None, local_views: list[Tensor] | None = None, images: Tensor | None = None) → LeJEPAOutput

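Runs a training forward pass (multi-view loss) or an inference forward pass (embeddings only).

Parameters:
  • global_views – List of global-view image batches, each [B, 3, H, W]. Their projections define the centers used by the invariance term. Used in training.

  • local_views – Optional list of local-view image batches. They contribute to the invariance and SIGReg terms but not to the centers.

  • images – A single image batch for inference (e.g., in eval mode).

Returns:

LeJEPAOutput carrying loss (the training objective) and embedding (the CLS-pooled backbone features, e.g. [B, 768] for vit_base_patch16_224).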