IJEPA
- class stable_pretraining.methods.IJEPA(model_or_model_name: str | Module = 'vit_base_patch16_224', predictor_embed_dim: int = 384, predictor_depth: int = 6, num_targets: int = 4, target_scale: Tuple[float, float] = (0.15, 0.2), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), context_scale: Tuple[float, float] = (0.85, 1.0), ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, pretrained: bool = False)
Bases: Module
I-JEPA: Image-based Joint-Embedding Predictive Architecture.
- Architecture:
Context Encoder (student): Encodes visible/context patches
Target Encoder (teacher): EMA copy, encodes target patches
Predictor: Lightweight transformer predicting targets from context
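The teacher is described only as an EMA copy of the student whose decay moves from ema_decay_start to ema_decay_end. The following is a minimal sketch of that update on plain Python lists, assuming a linear decay schedule over training steps; the schedule shape actually used by TeacherStudentCallback is not stated on this page, and `ema_decay` / `ema_update` are hypothetical helpers, not library functions.

```python
from typing import List


def ema_decay(step: int, total_steps: int,
              start: float = 0.996, end: float = 1.0) -> float:
    """Hypothetical schedule: linear interpolation from `start` to `end`."""
    if total_steps <= 1:
        return end
    progress = min(max(step / (total_steps - 1), 0.0), 1.0)
    return start + (end - start) * progress


def ema_update(teacher: List[float], student: List[float],
               decay: float) -> List[float]:
    """Elementwise teacher <- decay * teacher + (1 - decay) * student."""
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher, student)]


# Toy run: 3 steps over the default IJEPA decay range
teacher = [0.0, 0.0]
student = [1.0, -1.0]  # student weights held fixed for illustration
for step in range(3):
    teacher = ema_update(teacher, student, ema_decay(step, total_steps=3))
```

With ema_decay_end = 1.0 the final updates leave the teacher unchanged, so the teacher effectively freezes as training ends.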
The context encoder is wrapped with TeacherStudentWrapper, enabling automatic EMA updates via TeacherStudentCallback.
- Parameters:
model_or_model_name – timm model name string or pre-instantiated nn.Module (default: 'vit_base_patch16_224')
predictor_embed_dim – Predictor hidden dimension (default: 384)
predictor_depth – Number of predictor blocks (default: 6)
num_targets – Number of target blocks to sample (default: 4)
target_scale – (min, max) fraction of patches per target block (default: (0.15, 0.2))
target_aspect_ratio – (min, max) aspect ratio of target blocks (default: (0.75, 1.5))
context_scale – (min, max) fraction of non-target patches used as context (default: (0.85, 1.0))
ema_decay_start – Initial EMA decay (default: 0.996)
ema_decay_end – Final EMA decay (default: 1.0)
pretrained – Whether to load pretrained encoder weights (default: False)
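The num_targets, target_scale, target_aspect_ratio and context_scale parameters describe rectangular blocks on the patch grid (14 × 14 for a 224-pixel image with 16-pixel patches). The following is a minimal sketch of such block sampling, loosely following the scheme in the I-JEPA paper; it is not the IJEPAMasking implementation. The names `sample_block` and `sample_masks`, the square context block, and the rule "context = one large block minus all target patches" are assumptions made for illustration.

```python
import math
import random
from typing import List, Set, Tuple


def sample_block(grid: int, scale: Tuple[float, float],
                 aspect: Tuple[float, float], rng: random.Random) -> Set[int]:
    """Sample one rectangular block of flat patch indices on a grid x grid layout."""
    num_patches = grid * grid
    s = rng.uniform(*scale)   # fraction of all patches covered by the block
    r = rng.uniform(*aspect)  # height / width ratio
    h = max(1, min(grid, round(math.sqrt(s * num_patches * r))))
    w = max(1, min(grid, round(math.sqrt(s * num_patches / r))))
    top = rng.randint(0, grid - h)
    left = rng.randint(0, grid - w)
    return {(top + i) * grid + (left + j) for i in range(h) for j in range(w)}


def sample_masks(grid: int = 14, num_targets: int = 4,
                 target_scale: Tuple[float, float] = (0.15, 0.2),
                 target_aspect_ratio: Tuple[float, float] = (0.75, 1.5),
                 context_scale: Tuple[float, float] = (0.85, 1.0),
                 seed: int = 0) -> Tuple[List[int], List[List[int]]]:
    """Hypothetical masking: num_targets target blocks plus one context block."""
    rng = random.Random(seed)
    targets = [sample_block(grid, target_scale, target_aspect_ratio, rng)
               for _ in range(num_targets)]
    # Assumed context rule: one large square block, minus every target patch
    context = sample_block(grid, context_scale, (1.0, 1.0), rng)
    context -= set().union(*targets)
    return sorted(context), [sorted(t) for t in targets]


context, targets = sample_masks()
print(len(context), [len(t) for t in targets])
```

Removing target patches from the context keeps the prediction task non-trivial: the predictor never sees the patches it is asked to reconstruct in representation space.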
Example:
```python
import torch
from stable_pretraining.methods import IJEPA

# Basic usage
model = IJEPA("vit_base_patch16_224")
images = torch.randn(4, 3, 224, 224)

# Training mode: predicts masked targets
model.train()
output = model(images)
output.loss.backward()

# Eval mode: encodes all patches (no masking)
model.eval()
with torch.no_grad():
    output = model(images)
features = output.predictions  # [B, N, D]
```
Example with Lightning:
```python
import lightning as pl
import torch
from stable_pretraining.callbacks import TeacherStudentCallback
from stable_pretraining.methods import IJEPA


class IJEPALightning(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = IJEPA("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        # Accept either raw image tensors or (images, labels) tuples
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        output = self.model(images)
        self.log("loss", output.loss)
        return output.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)


# TeacherStudentCallback handles the EMA updates of the teacher encoder
trainer = pl.Trainer(callbacks=[TeacherStudentCallback()])
# `dataloader`: any DataLoader yielding images or (images, labels) batches
trainer.fit(IJEPALightning(), dataloader)
```
Note
- Use TeacherStudentCallback to handle EMA updates automatically.
- In eval mode, num_targets=0 and all patches are returned as context.
- Access the trained encoder via model.encoder.student.
- forward(images: Tensor, embedding_source: str = 'teacher') → IJEPAOutput
Forward pass.
- In training mode:
Samples target blocks and a context region via IJEPAMasking
Encodes context through the student and targets through the teacher (EMA)
Predicts target representations from context
Returns smooth L1 loss between predictions and targets
- In eval mode:
No masking, all patches treated as context
Returns encoded features with zero loss
Always uses student encoder
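In training mode the loss is a smooth L1 distance between predicted and teacher target representations. The following is a minimal sketch of that loss on flat lists, assuming the defaults of PyTorch's `torch.nn.functional.smooth_l1_loss` (beta = 1.0, mean reduction); whether IJEPA uses exactly these settings is not stated on this page, and `smooth_l1` is a hypothetical helper.

```python
from typing import Sequence


def smooth_l1(pred: Sequence[float], target: Sequence[float],
              beta: float = 1.0) -> float:
    """Mean smooth L1: quadratic for |d| < beta, linear beyond it."""
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("pred and target must be non-empty and equally long")
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        # 0.5 * d^2 / beta near zero, d - 0.5 * beta for large errors
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)


print(smooth_l1([0.0, 2.0], [0.5, 0.0]))
```

Because the penalty grows only linearly for large errors, a few badly predicted patches influence the gradient less than they would under an L2 loss.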
- Parameters:
images – Input images [B, C, H, W]
embedding_source – Which encoder to use for the embedding output: "teacher" (default) or "student". Only affects training mode; eval mode always uses the student.
- Returns:
IJEPAOutput with loss and representations