# Source code for stable_pretraining.methods.ijepa

"""I-JEPA: Image-based Joint-Embedding Predictive Architecture.

Self-supervised learning via predicting target patch representations
from context patch representations using a lightweight predictor.

References:
    Assran et al. "Self-Supervised Learning from Images with a
    Joint-Embedding Predictive Architecture." CVPR 2023.
    https://arxiv.org/abs/2301.08243

Example::

    from stable_pretraining.methods.ijepa import IJEPA
    from stable_pretraining.callbacks import TeacherStudentCallback
    import lightning as pl

    # Create model
    model = IJEPA(
        model_or_model_name="vit_base_patch16_224",
        predictor_embed_dim=384,
        predictor_depth=6,
        num_targets=4,
    )

    # Training with PyTorch Lightning
    trainer = pl.Trainer(
        max_epochs=300,
        callbacks=[TeacherStudentCallback()],  # Handles EMA updates
    )
    trainer.fit(model, dataloader)

    # Get encoder for downstream tasks
    encoder = model.encoder.student
"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining.backbone import (
    FlexibleTransformer,
    IJEPAMasking,
    MaskedEncoder,
    TeacherStudentWrapper,
)
from stable_pretraining import Module


@dataclass
class IJEPAOutput(ModelOutput):
    """Output from IJEPA forward pass.

    Every field defaults to ``None``; :meth:`IJEPA.forward` populates all of them.

    :ivar loss: Prediction loss (0 in eval mode)
    :ivar embedding: Patch embeddings [B, N, D] for downstream use
    :ivar predictions: Predicted representations [B, N_tgt, D] (or context in eval)
    :ivar targets: Target representations [B, N_tgt, D] (or context in eval)
    :ivar num_targets: Number of target patches (0 in eval)
    :ivar num_context: Number of context patches (all patches in eval)
    """

    loss: Optional[torch.Tensor] = None
    embedding: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
    targets: Optional[torch.Tensor] = None
    num_targets: Optional[int] = None
    num_context: Optional[int] = None
class IJEPA(Module):
    """I-JEPA: Image-based Joint-Embedding Predictive Architecture.

    Architecture:
        - **Context Encoder** (student): Encodes visible/context patches
        - **Target Encoder** (teacher): EMA copy; encodes the full patch
          sequence, from which the target patches are selected
        - **Predictor**: Lightweight transformer predicting targets from context

    The context encoder is wrapped with :class:`TeacherStudentWrapper`, enabling
    automatic EMA updates via :class:`TeacherStudentCallback`.

    :param model_or_model_name: timm model name string or pre-instantiated nn.Module
    :param predictor_embed_dim: Predictor hidden dimension (default: 384)
    :param predictor_depth: Number of predictor blocks (default: 6)
    :param num_targets: Number of target blocks to sample (default: 4)
    :param target_scale: (min, max) fraction of patches per target block
    :param target_aspect_ratio: (min, max) aspect ratio of target blocks
    :param context_scale: (min, max) fraction of non-target patches as context
    :param ema_decay_start: Initial EMA decay (default: 0.996)
    :param ema_decay_end: Final EMA decay (default: 1.0)
    :param pretrained: Load pretrained encoder weights. When ``True`` the
        depth-dependent init rescaling is applied to the predictor only, so the
        loaded encoder weights are left untouched.

    Example::

        # Basic usage
        model = IJEPA("vit_base_patch16_224")
        images = torch.randn(4, 3, 224, 224)

        # Training mode: predicts masked targets
        model.train()
        output = model(images)
        output.loss.backward()

        # Eval mode: encodes all patches (no masking)
        model.eval()
        output = model(images)
        features = output.predictions  # [B, N, D]

    Example with Lightning::

        import lightning as pl
        from stable_pretraining.callbacks import TeacherStudentCallback

        class IJEPALightning(pl.LightningModule):
            def __init__(self):
                super().__init__()
                self.model = IJEPA("vit_base_patch16_224")

            def training_step(self, batch, batch_idx):
                images = batch[0] if isinstance(batch, (list, tuple)) else batch
                output = self.model(images)
                self.log("loss", output.loss)
                return output.loss

            def configure_optimizers(self):
                return torch.optim.AdamW(self.parameters(), lr=1.5e-4)

        trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
        trainer.fit(IJEPALightning(), dataloader)

    Note:
        - Use :class:`TeacherStudentCallback` to handle EMA updates automatically
        - In eval mode, ``num_targets=0`` and all patches are returned as context
        - Access trained encoder via ``model.encoder.student``
    """

    def __init__(
        self,
        model_or_model_name: Union[str, nn.Module] = "vit_base_patch16_224",
        predictor_embed_dim: int = 384,
        predictor_depth: int = 6,
        num_targets: int = 4,
        target_scale: Tuple[float, float] = (0.15, 0.2),
        target_aspect_ratio: Tuple[float, float] = (0.75, 1.5),
        context_scale: Tuple[float, float] = (0.85, 1.0),
        ema_decay_start: float = 0.996,
        ema_decay_end: float = 1.0,
        pretrained: bool = False,
    ):
        super().__init__()

        # Encoder with EMA wrapper (enables TeacherStudentCallback)
        base_encoder = MaskedEncoder(
            model_or_model_name,
            masking=None,
            pretrained=pretrained,
        )
        self.encoder = TeacherStudentWrapper(
            base_encoder,
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )

        embed_dim = base_encoder.embed_dim
        num_patches = base_encoder.default_grid_h * base_encoder.default_grid_w

        # Lightweight predictor: joint self-attention over the context tokens
        # and one mask token per target position (cross-attention is disabled).
        self.predictor = FlexibleTransformer(
            input_dim=embed_dim,
            hidden_dim=predictor_embed_dim,
            output_dim=embed_dim,
            num_patches=num_patches,
            depth=predictor_depth,
            num_heads=max(1, predictor_embed_dim // 64),
            self_attn=True,
            cross_attn=False,
            add_mask_token=True,
            use_adaln=False,
            num_prefix_tokens=0,
            zero_init_output=False,
        )

        # I-JEPA multi-block masking
        self.masking = IJEPAMasking(
            num_targets=num_targets,
            target_scale=target_scale,
            target_aspect_ratio=target_aspect_ratio,
            context_scale=context_scale,
        )

        self.embed_dim = embed_dim

        # The depth rescale is an initialisation trick; applying it to loaded
        # pretrained encoder weights would distort them, so skip the encoder then.
        self._fix_init_weight(rescale_encoder=not pretrained)

    def _encode(
        self,
        patches: torch.Tensor,
        indices: torch.Tensor,
        grid_h: int,
        grid_w: int,
        encoder: MaskedEncoder,
    ) -> torch.Tensor:
        """Encode patches at specified indices.

        :param patches: All patch embeddings [B, N, D]
        :param indices: Indices to encode [B, K]
        :param grid_h: Patch grid height
        :param grid_w: Patch grid width
        :param encoder: Encoder to use (student or teacher)
        :return: Encoded representations [B, K, D]
        """
        B, _, D = patches.shape

        # patch -> posemb -> mask -> block -> norm
        _, pos = encoder._get_pos_embed(grid_h, grid_w)
        x = patches + pos.expand(B, -1, -1)
        # Positional embeddings are added before gathering so every kept token
        # carries the position of its original grid location.
        x = torch.gather(x, 1, indices.unsqueeze(-1).expand(-1, -1, D))
        x = encoder.vit.pos_drop(x)
        x = encoder.vit.blocks(x)
        return encoder.vit.norm(x)

    def forward(
        self, images: torch.Tensor, embedding_source: str = "teacher"
    ) -> IJEPAOutput:
        """Forward pass.

        In training mode:
            - Samples target blocks and context region via :class:`IJEPAMasking`
            - Encodes context through student, targets through teacher (EMA)
            - Predicts target representations from context
            - Returns smooth L1 loss between predictions and targets

        In eval mode:
            - No masking, all patches treated as context
            - Returns encoded features with zero loss
            - Always uses student encoder

        :param images: Input images [B, C, H, W]
        :param embedding_source: Which encoder to use for the embedding output.
            ``"teacher"`` (default) or ``"student"``. Only affects training mode;
            eval mode always uses student.
        :return: :class:`IJEPAOutput` with loss and representations
        :raises ValueError: If ``embedding_source`` is neither ``"teacher"`` nor
            ``"student"``.
        """
        if embedding_source not in ("teacher", "student"):
            raise ValueError(
                f"embedding_source must be 'teacher' or 'student', got '{embedding_source}'"
            )

        B = images.shape[0]
        grid_h, grid_w = self.encoder.student._get_grid_size(images)

        student_patches = self.encoder.student.patch_embed(images)

        # Apply masking (returns all patches as context in eval mode)
        mask_out = self.masking(student_patches, grid_h, grid_w)

        if self.training:
            # Context: student sees only context patches
            context = self._encode(
                student_patches,
                mask_out.context_idx,
                grid_h,
                grid_w,
                self.encoder.student,
            )

            with torch.no_grad():
                # The teacher is only needed in training, so its patch embedding
                # is computed here (gradient-free) instead of on every call.
                teacher_patches = self.encoder.teacher.patch_embed(images)

                # Teacher: full forward with ALL patches visible, then select targets
                all_idx = (
                    torch.arange(grid_h * grid_w, device=images.device)
                    .unsqueeze(0)
                    .expand(B, -1)
                )
                teacher_full = self._encode(
                    teacher_patches, all_idx, grid_h, grid_w, self.encoder.teacher
                )

                # teacher's vit.norm already applied; this is an extra,
                # non-affine layer norm (no weight/bias) on the targets
                D = teacher_full.size(-1)
                teacher_full_normed = F.layer_norm(
                    teacher_full,
                    [D],
                    weight=None,
                    bias=None,
                )

                # Select target patches from the full encoding
                targets = torch.gather(
                    teacher_full_normed,
                    1,
                    mask_out.target_idx.unsqueeze(-1).expand(-1, -1, D),
                )

                # Embedding: reuse teacher_full (unnormed, full sequence) for the probe
                if embedding_source == "teacher":
                    embedding = teacher_full
                else:
                    # Full-sequence student features, computed without gradients
                    # like the teacher and eval embeddings.
                    embedding = self._encode(
                        student_patches, all_idx, grid_h, grid_w, self.encoder.student
                    )

            # Predict target representations via joint self-attention on
            # [context + mask tokens]
            N_tgt = mask_out.target_idx.shape[1]

            # Create dummy queries and just mask them all out
            queries = torch.zeros(
                B, N_tgt, self.embed_dim, device=images.device, dtype=context.dtype
            )
            query_mask = torch.ones(B, N_tgt, device=images.device, dtype=torch.bool)

            predictions = self.predictor(
                context=context,
                queries=queries,
                context_idx=mask_out.context_idx,
                query_idx=mask_out.target_idx,
                query_mask=query_mask,
            )

            loss = F.smooth_l1_loss(predictions, targets, beta=1.0)
        else:
            # Eval: encode all patches through student
            with torch.no_grad():
                context = self._encode(
                    student_patches,
                    mask_out.context_idx,
                    grid_h,
                    grid_w,
                    self.encoder.student,
                )
            predictions = context
            targets = context
            embedding = context
            loss = torch.tensor(0.0, device=images.device)

        return IJEPAOutput(
            loss=loss,
            embedding=embedding,
            predictions=predictions,
            targets=targets,
            num_targets=mask_out.target_idx.shape[1],
            num_context=mask_out.context_idx.shape[1],
        )

    def _fix_init_weight(self, rescale_encoder: bool = True):
        """Rescale attention proj and MLP output weights by depth, matching I-JEPA init from the repo.

        Block ``i`` (1-based) has ``attn.proj.weight`` and ``mlp.fc2.weight``
        divided in place by ``sqrt(2 * i)``.

        :param rescale_encoder: Also rescale the student and teacher encoder
            blocks (default). Pass ``False`` to rescale only the predictor, e.g.
            when the encoder holds pretrained weights.
        """
        encoders = []
        if rescale_encoder:
            for candidate in (self.encoder.student, self.encoder.teacher):
                # Guard against a wrapper that aliases teacher and student to
                # the same module, which would otherwise be rescaled twice.
                if all(candidate is not seen for seen in encoders):
                    encoders.append(candidate)

        block_stacks = [enc.vit.blocks for enc in encoders]
        block_stacks.append(self.predictor.blocks)

        with torch.no_grad():
            for blocks in block_stacks:
                for layer_id, block in enumerate(blocks, start=1):
                    scale = math.sqrt(2.0 * layer_id)
                    block.attn.proj.weight.div_(scale)
                    block.mlp.fc2.weight.div_(scale)