IJEPA


class stable_pretraining.methods.IJEPA(model_or_model_name: str | Module = 'vit_base_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 6, num_targets: int = 4, target_scale: Tuple[float, float] = (0.15, 0.2), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), context_scale: Tuple[float, float] = (0.85, 1.0), ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, pretrained: bool = False)

Bases: Module

I-JEPA: Image-based Joint-Embedding Predictive Architecture.

Architecture:
  • Context Encoder (student): Encodes visible/context patches

  • Target Encoder (teacher): EMA copy, encodes target patches

  • Predictor: Lightweight transformer predicting targets from context

The encoder is wrapped with TeacherStudentWrapper (student: context encoder, teacher: EMA target encoder), so TeacherStudentCallback can apply the EMA updates to the teacher automatically.
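The EMA update that keeps the target encoder trailing the context encoder can be sketched as below. This is a minimal, dependency-free illustration, not the callback's code: it assumes a linear schedule from ema_decay_start to ema_decay_end over training (TeacherStudentCallback may use a different schedule), plain Python lists stand in for parameter tensors, and the helpers ema_decay and ema_update are hypothetical.

```python
from typing import List


def ema_decay(step: int, total_steps: int,
              start: float = 0.996, end: float = 1.0) -> float:
    """Hypothetical decay schedule: linear from `start` to `end` (assumption)."""
    if total_steps <= 0:
        return end
    progress = min(max(step / total_steps, 0.0), 1.0)
    return start + (end - start) * progress


def ema_update(teacher: List[float], student: List[float], decay: float) -> List[float]:
    """Element-wise teacher <- decay * teacher + (1 - decay) * student."""
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher, student)]


# Toy "parameters": the teacher drifts slowly toward the student.
teacher = [0.0, 0.0]
student = [1.0, -1.0]
for step in range(3):
    teacher = ema_update(teacher, student, ema_decay(step, total_steps=3))
```

With a decay of 1.0 (the default ema_decay_end) the teacher stops changing, which is why the schedule ends there: the target encoder is frozen at the end of training.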

Parameters:
  • model_or_model_name – timm model name string or pre-instantiated nn.Module (default: 'vit_base_patch16_224')

  • predictor_embed_dim – Predictor hidden dimension (default: 384)

  • predictor_depth – Number of predictor blocks (default: 6)

  • num_targets – Number of target blocks to sample (default: 4)

  • target_scale – (min, max) fraction of all patches covered by each target block (default: (0.15, 0.2))

  • target_aspect_ratio – (min, max) aspect ratio of each target block (default: (0.75, 1.5))

  • context_scale – (min, max) fraction of non-target patches kept as context (default: (0.85, 1.0))

  • ema_decay_start – Initial EMA decay (default: 0.996)

  • ema_decay_end – Final EMA decay (default: 1.0)

  • pretrained – Whether to load pretrained encoder weights (default: False)
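To make target_scale and target_aspect_ratio concrete, the sketch below samples one rectangular target block on the patch grid (14 × 14 for vit_base_patch16_224, since 224 / 16 = 14). It follows the multi-block recipe described in the I-JEPA paper, where the block area is scale × num_patches and the aspect ratio is taken as height / width; treat both conventions as assumptions, since the exact sampling and rounding inside IJEPAMasking may differ. sample_target_block is a hypothetical helper.

```python
import math
import random
from typing import List, Optional, Tuple


def sample_target_block(
    grid: int = 14,
    scale: Tuple[float, float] = (0.15, 0.2),
    aspect_ratio: Tuple[float, float] = (0.75, 1.5),
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Return flat patch indices (row * grid + col) of one rectangular block.

    Hypothetical helper mirroring the paper's block sampling, not
    necessarily IJEPAMasking's implementation.
    """
    rng = rng or random.Random()
    num_patches = grid * grid
    s = rng.uniform(*scale)         # fraction of the grid to cover
    r = rng.uniform(*aspect_ratio)  # assumed to mean height / width
    area = s * num_patches
    h = max(1, min(grid, round(math.sqrt(area * r))))
    w = max(1, min(grid, round(math.sqrt(area / r))))
    top = rng.randint(0, grid - h)  # randint includes both endpoints
    left = rng.randint(0, grid - w)
    return [(top + i) * grid + (left + j) for i in range(h) for j in range(w)]


rng = random.Random(0)
targets = [sample_target_block(rng=rng) for _ in range(4)]  # num_targets=4
```

With the defaults, each block covers roughly 29 to 39 of the 196 patches; blocks are sampled independently and may overlap.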

Example:

# Basic usage
import torch
from stable_pretraining.methods import IJEPA

model = IJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

# Training mode: predicts masked targets
model.train()
output = model(images)
output.loss.backward()

# Eval mode: encodes all patches (no masking)
model.eval()
with torch.no_grad():
    output = model(images)
features = output.predictions  # [B, N, D]

Example with Lightning:

import torch
import lightning as pl
from stable_pretraining.callbacks import TeacherStudentCallback
from stable_pretraining.methods import IJEPA


class IJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = IJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        output = self.model(images)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)


# dataloader: any DataLoader yielding image tensors or (images, labels) tuples
trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
trainer.fit(IJEPALightning(), dataloader)

Note

  • Use TeacherStudentCallback to handle EMA updates automatically

  • In eval mode, num_targets=0 and all patches are returned as context

  • Access trained encoder via model.encoder.student

forward(images: Tensor, embedding_source: str = 'teacher') → IJEPAOutput

Forward pass.

In training mode:
  • Samples target blocks and context region via IJEPAMasking

  • Encodes context through student, targets through teacher (EMA)

  • Predicts target representations from context

  • Returns smooth L1 loss between predictions and targets
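The training-mode steps above can be sketched at the level of patch indices and loss values. The sketch assumes the context is a uniformly random subset of the patches not covered by any target block, sized by a fraction drawn from context_scale (IJEPAMasking may instead sample a contiguous region), and that the loss is the mean element-wise smooth L1 with β = 1, the default of torch.nn.functional.smooth_l1_loss. Plain Python lists stand in for predictor outputs and teacher features; sample_context and smooth_l1 are hypothetical helpers.

```python
import random
from typing import List, Optional, Sequence, Tuple


def sample_context(
    num_patches: int,
    target_blocks: Sequence[Sequence[int]],
    context_scale: Tuple[float, float] = (0.85, 1.0),
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Keep a random fraction of the patches not covered by any target block."""
    rng = rng or random.Random()
    covered = {idx for block in target_blocks for idx in block}
    free = [i for i in range(num_patches) if i not in covered]
    keep = round(rng.uniform(*context_scale) * len(free))
    return sorted(rng.sample(free, keep))


def smooth_l1(pred: Sequence[float], target: Sequence[float], beta: float = 1.0) -> float:
    """Mean smooth L1 loss, same formula as torch.nn.functional.smooth_l1_loss."""
    total = 0.0
    for p, t in zip(pred, target):
        diff = abs(p - t)
        # Quadratic for small errors, linear for large ones.
        total += 0.5 * diff * diff / beta if diff < beta else diff - 0.5 * beta
    return total / len(pred)


# Two toy target blocks on a 14 x 14 grid (196 patches).
targets = [[0, 1, 14, 15], [100, 101]]
context = sample_context(196, targets, rng=random.Random(0))

# Stand-ins for predictor outputs vs. teacher features at target positions.
loss = smooth_l1([0.5, 2.0], [0.0, 0.0])  # → 0.8125
```

Keeping context and target patches disjoint is what forces the predictor to infer the targets rather than copy them from the context.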

In eval mode:
  • No masking, all patches treated as context

  • Returns encoded features with zero loss

  • Always uses student encoder

Parameters:
  • images – Input images [B, C, H, W]

  • embedding_source – Which encoder to use for the embedding output. "teacher" (default) or "student". Only affects training mode; eval mode always uses student.
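The rule for which encoder supplies the returned embedding can be restated as a small function. pick_embedding_encoder is hypothetical and not part of the library; it only encodes the documented behavior, and the ValueError on unknown values is an illustrative assumption rather than documented library behavior.

```python
def pick_embedding_encoder(training: bool, embedding_source: str = "teacher") -> str:
    """Return which encoder ("teacher" or "student") produces the embedding output.

    embedding_source only matters in training mode; eval mode always
    uses the student encoder.
    """
    # Assumption: invalid values are rejected (not documented for IJEPA).
    if embedding_source not in ("teacher", "student"):
        raise ValueError(
            f"embedding_source must be 'teacher' or 'student', got {embedding_source!r}"
        )
    if not training:
        return "student"
    return embedding_source
```

For example, pick_embedding_encoder(False, "teacher") returns "student", matching the note that eval mode ignores this argument.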

Returns:

IJEPAOutput with loss and representations