# Source code for stable_pretraining.methods.lejepa

"""LeJEPA: Latent Embedding Joint-Embedding Predictive Architecture.

Self-supervised learning via multi-view invariance combined with a
sliced goodness-of-fit test (SIGReg) that pushes embeddings toward
an isotropic Gaussian.

References:
    Balestriero & LeCun. "LeJEPA: Provable and Scalable
    Self-Supervised Learning Without the Heuristics." 2025.
    https://arxiv.org/abs/2511.08544

Example::

    from stable_pretraining.methods.lejepa import LeJEPA

    model = LeJEPA("vit_small_patch16_224")

    global_images = [torch.randn(4, 3, 224, 224)] * 2
    local_images = [torch.randn(4, 3, 224, 224)] * 6
    model.train()
    output = model(global_views=global_images, local_views=local_images)
    output.loss.backward()

    model.eval()
    output = model(images=torch.randn(4, 3, 224, 224))
    features = output.embedding  # [N, D]
"""

from dataclasses import dataclass
from transformers.utils import ModelOutput
from typing import Optional

import timm
import torch
import torch.nn as nn
from torch.distributed.nn import all_reduce

from stable_pretraining import Module
from stable_pretraining.backbone import MLP


class EppsPulley(nn.Module):
    """Epps-Pulley goodness-of-fit test for univariate standard normality.

    For every column of the input, the empirical characteristic function
    (ECF) is evaluated on a uniform grid ``t`` in ``[0, t_max]`` and compared
    with the standard-normal characteristic function ``exp(-t^2 / 2)``. The
    squared error is integrated with trapezoid weights. The weights are
    doubled because only ``t >= 0`` is evaluated and the integrand is
    symmetric in ``t``. They are also windowed by ``exp(-t^2 / 2)``. The
    result is scaled by the number of samples.

    When a ``torch.distributed`` process group is active, the ECF sums and
    the sample counts are summed across ranks. The statistic is therefore
    computed on the global batch, even when ranks hold different batch
    sizes. The process group is looked up on every forward call, so the
    module may be created before the process group is initialised.

    :param t_max: Integration upper bound.
    :param n_points: Number of integration points (must be odd).
    """

    def __init__(self, t_max: float = 3.0, n_points: int = 17):
        super().__init__()
        assert n_points % 2 == 1
        t = torch.linspace(0, t_max, n_points)
        dt = t_max / (n_points - 1)
        self.register_buffer("t", t)
        # Standard-normal characteristic function at the grid points.
        phi = (-0.5 * t**2).exp()
        self.register_buffer("phi", phi)
        # Trapezoid weights, doubled to account for the mirrored t < 0 half.
        weights = torch.full((n_points,), 2 * dt)
        weights[[0, -1]] = dt
        # Fold the Gaussian window into the quadrature weights.
        self.register_buffer("weights", weights * phi)

    @property
    def _is_ddp(self) -> bool:
        """Whether a ``torch.distributed`` process group is active right now."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    @property
    def world_size(self) -> int:
        """Number of ranks in the default process group (1 without DDP)."""
        return torch.distributed.get_world_size() if self._is_ddp else 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """:param x: Samples [N, S] (N samples, S slices).
        :return: Per-slice statistic [S].
        """
        n_total = x.size(0)
        x_t = x.unsqueeze(-1) * self.t  # [N, S, n_points]
        # Keep sums rather than means: after the cross-rank reduction, each
        # rank is then weighted by its own sample count.
        ecf_sums = torch.stack((x_t.cos().sum(0), x_t.sin().sum(0)))
        if self._is_ddp:
            # torch.distributed.nn.all_reduce is out-of-place and
            # autograd-aware: the reduced tensor is its *return value*.
            # SUM (unlike AVG) is supported by every backend.
            ecf_sums = all_reduce(ecf_sums, op=torch.distributed.ReduceOp.SUM)
            count = torch.tensor(n_total, device=x.device, dtype=torch.long)
            torch.distributed.all_reduce(count, op=torch.distributed.ReduceOp.SUM)
            n_total = int(count.item())
        cos_mean, sin_mean = ecf_sums / n_total
        err = (cos_mean - self.phi).square() + sin_mean.square()
        return (err @ self.weights) * n_total
class SlicedEppsPulley(nn.Module):
    """Sliced Epps-Pulley goodness-of-fit test for multivariate normality.

    Projects data onto random unit-norm 1-D directions and averages the
    univariate Epps-Pulley statistics. The random projections are seeded by
    a step counter (the ``global_step`` buffer), which is incremented on
    every call. Ranks that call the module in lockstep therefore sample
    identical directions. Under DDP, rank 0's counter is also broadcast.
    Whether a process group is active is checked on every call, so the
    module may be created before the process group is initialised.

    :param num_slices: Number of random 1-D projections.
    :param t_max: EP integration upper bound.
    :param n_points: EP quadrature nodes.
    """

    def __init__(self, num_slices: int = 1024, t_max: float = 3.0, n_points: int = 17):
        super().__init__()
        self.num_slices = num_slices
        self.ep = EppsPulley(t_max=t_max, n_points=n_points)
        self.register_buffer("global_step", torch.zeros((), dtype=torch.long))

    @property
    def _is_ddp(self) -> bool:
        """Whether a ``torch.distributed`` process group is active right now."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """:param x: Embeddings [N, D].
        :return: Scalar mean EP statistic.
        """
        with torch.no_grad():
            step = self.global_step.clone()
            if self._is_ddp:
                # All ranks increment global_step in lockstep, so this
                # broadcast is redundant under normal synchronous training.
                # It is kept as a safety net against step drift from
                # uneven batches (e.g. drop_last=False).
                torch.distributed.broadcast(step, src=0)
            g = torch.Generator(device=x.device).manual_seed(int(step.item()))
            directions = torch.randn(
                x.size(-1), self.num_slices, device=x.device, generator=g
            )
            directions = directions / directions.norm(p=2, dim=0)
            self.global_step.add_(1)
        # Directions are drawn in float32 so the seed gives the same slices
        # regardless of input dtype. They are then cast so that half/double
        # inputs do not hit a matmul dtype mismatch.
        proj = x @ directions.to(dtype=x.dtype)
        return self.ep(proj).mean()
@dataclass
class LeJEPAOutput(ModelOutput):
    """Output from LeJEPA forward pass.

    :ivar loss: Combined invariance + SIGReg loss (0 in eval mode).
    :ivar embedding: In train mode, the detached backbone embeddings of the
        global views only, [G*N, D] (G = number of global views). In eval
        mode, the backbone embeddings of ``images``, [N, D].
    :ivar inv_loss: Invariance component (0 in eval mode).
    :ivar sigreg_loss: Sliced Epps-Pulley goodness-of-fit component
        (0 in eval mode).
    """

    loss: Optional[torch.Tensor] = None
    embedding: Optional[torch.Tensor] = None
    inv_loss: Optional[torch.Tensor] = None
    sigreg_loss: Optional[torch.Tensor] = None
class LeJEPA(Module):
    """LeJEPA: multi-view invariance + sliced Epps-Pulley SIGReg.

    Architecture:
        - **Backbone**: timm model created with ``num_classes=0`` (no
          classification head). Models whose name contains ``"vit"`` are
          created with ``dynamic_img_size=True``.
        - **Projector**: MLP projection head
        - **Loss**: ``invariance + (λ * SIGReg)``

    Centers are computed from global-view projections only. The invariance
    term penalises the MSE between each view's projection (global and local)
    and the center. The SIGReg term is a sliced goodness-of-fit test that
    pushes projected embeddings toward an isotropic Gaussian. It is evaluated
    once on the projections of all views, pooled into a single batch.

    :param encoder_name: timm model name (e.g., ``"vit_base_patch16_224"``)
    :param projector: Optional projection head. When ``None``, a linear layer
        (``embed_dim → 512``) followed by a batch-norm + ReLU ``MLP``
        (``512 → 2048 → 2048 → 512``) is created.
    :param n_slices: Random projection directions for the goodness-of-fit
        test (default: 1024)
    :param t_max: EP integration upper bound (default: 3.0)
    :param n_points: EP quadrature nodes (default: 17)
    :param lamb: SIGReg weight λ (default: 0.02)
    :param pretrained: Load pretrained timm weights
    :param drop_path_rate: Stochastic-depth rate passed to
        ``timm.create_model`` (default: 0.1)

    Example::

        model = LeJEPA("vit_base_patch16_224")
        images = torch.randn(4, 3, 224, 224)

        model.train()
        output = model(
            global_views=[images, images],
            local_views=[images, images, images, images],
        )
        output.loss.backward()

        model.eval()
        output = model(images=images)
        features = output.embedding  # [4, 768]

    Example with Lightning::

        import lightning as pl
        from stable_pretraining.methods.lejepa import LeJEPA

        class LeJEPALightning(pl.LightningModule):
            def __init__(self):
                super().__init__()
                self.model = LeJEPA("vit_base_patch16_224")

            def training_step(self, batch, batch_idx):
                views = [v["image"] for v in batch["views"]]
                # Assumes the first two views are the global crops; adapt
                # the split to your augmentation pipeline.
                output = self.model(global_views=views[:2], local_views=views[2:])
                self.log("loss", output.loss)
                return output.loss

            def configure_optimizers(self):
                return torch.optim.AdamW(self.parameters(), lr=1e-3)
    """

    def __init__(
        self,
        encoder_name: str = "vit_base_patch16_224",
        projector: Optional[nn.Module] = None,
        n_slices: int = 1024,
        t_max: float = 3.0,
        n_points: int = 17,
        lamb: float = 0.02,
        pretrained: bool = False,
        drop_path_rate: float = 0.1,
    ):
        super().__init__()
        self.backbone = timm.create_model(
            encoder_name,
            pretrained=pretrained,
            num_classes=0,
            **({"dynamic_img_size": True} if "vit" in encoder_name else {}),
            drop_path_rate=drop_path_rate,
        )
        # ViTs expose ``embed_dim``. Other timm models (e.g. CNNs) only
        # expose ``num_features``, so fall back to that.
        embed_dim = getattr(self.backbone, "embed_dim", None)
        if embed_dim is None:
            embed_dim = self.backbone.num_features

        if projector is None:
            projector = nn.Sequential(
                nn.Linear(embed_dim, 512, bias=True),
                MLP(
                    in_channels=512,
                    hidden_channels=[2048, 2048, 512],
                    norm_layer="batch_norm",
                    activation_layer=nn.ReLU,
                    inplace=True,
                    dropout=0.0,
                ),
            )
        self.projector = projector
        self.sigreg = SlicedEppsPulley(
            num_slices=n_slices, t_max=t_max, n_points=n_points
        )
        self.lamb = lamb
        self.embed_dim = embed_dim

    @staticmethod
    def _compute_loss(
        all_projected: torch.Tensor,
        n_global: int,
        sigreg: SlicedEppsPulley,
        lamb: float,
    ):
        """Compute the LeJEPA loss.

        :param all_projected: All view projections [V, N, K]; the first
            ``n_global`` views are the global ones.
        :param n_global: Number of global views.
        :param sigreg: SlicedEppsPulley module, applied once to all views
            flattened to [V*N, K].
        :param lamb: SIGReg weight λ.
        :return: Tuple of (total_loss, inv_loss, sigreg_loss).
        """
        centers = all_projected[:n_global].mean(0)  # [N, K]
        inv_loss = (centers.unsqueeze(0) - all_projected).square().mean()
        sigreg_loss = sigreg(all_projected.reshape(-1, all_projected.size(-1)))
        loss = inv_loss + lamb * sigreg_loss
        return loss, inv_loss, sigreg_loss

    def forward(
        self,
        global_views: Optional[list[torch.Tensor]] = None,
        local_views: Optional[list[torch.Tensor]] = None,
        images: Optional[torch.Tensor] = None,
    ) -> LeJEPAOutput:
        """Run LeJEPA.

        In training mode, global and local views go through the backbone in
        two separate calls, are projected jointly, and the loss is computed.
        In eval mode only ``images`` is embedded and all losses are zero.

        :param global_views: Global crops (training mode). Every view must
            have the same batch size N and the same spatial size (they are
            concatenated).
        :param local_views: Local crops (training mode), with the same batch
            size N and a common spatial size. Pass an empty list to train on
            global views only.
        :param images: Images to embed (eval mode).
        :return: :class:`LeJEPAOutput`.
        """
        if not self.training:
            assert images is not None, "images must be provided in eval mode"
            embedding = self.backbone(images)
            zero = torch.tensor(0.0, device=images.device)
            return LeJEPAOutput(
                loss=zero,
                embedding=embedding,
                inv_loss=zero,
                sigreg_loss=zero,
            )

        assert global_views is not None and local_views is not None, (
            "global_views and local_views must be provided in training mode"
        )
        assert len(global_views) > 0, "global_views must contain at least one view"

        n_global = len(global_views)
        bs = global_views[0].shape[0]

        g_features = self.backbone(torch.cat(global_views))
        features = [g_features]
        if local_views:
            # Local crops may have a different resolution from the global
            # ones, so they go through the backbone in a separate call.
            features.append(self.backbone(torch.cat(local_views)))

        # Views are concatenated view-by-view along the batch dim, so the
        # projections reshape cleanly to [V, N, K].
        n_views = n_global + len(local_views)
        all_projected = self.projector(torch.cat(features)).view(n_views, bs, -1)

        loss, inv_loss, sigreg_loss = self._compute_loss(
            all_projected, n_global, self.sigreg, self.lamb
        )
        return LeJEPAOutput(
            loss=loss,
            embedding=g_features.detach(),
            inv_loss=inv_loss,
            sigreg_loss=sigreg_loss,
        )