"""Multi-layer probe for vision models.

Note: go to the end to download the full example code.
"""
import argparse
from typing import Dict, List, Tuple
import hydra
import lightning as pl
import torch
import torchmetrics
import torchvision
from datasets import load_dataset
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger # type: ignore
from omegaconf import DictConfig
from torch import nn
from transformers import (
AutoImageProcessor,
AutoModel,
AutoModelForZeroShotImageClassification,
AutoProcessor,
)
import stable_pretraining as spt
from stable_pretraining.data import transforms
# -----------------------------
# Model registry
# -----------------------------
# Maps a short model key to everything needed to load and probe it:
#   processor_cls / processor_name: HF class + checkpoint for the image processor
#   model_cls / model_name:         HF class + checkpoint for the backbone
#   pooling:    how per-layer token embeddings are pooled ("cls" or "mean")
#   probe_skip: stride over transformer blocks when attaching linear probes
MODEL_ZOO = {
    "DINOv2": {
        "processor_cls": AutoImageProcessor,
        "processor_name": "facebook/dinov2-base",
        "model_cls": AutoModel,
        "model_name": "facebook/dinov2-base",
        "pooling": "cls",
        "probe_skip": 1,
    },
    "DINOv3": {
        "processor_cls": AutoImageProcessor,
        "processor_name": "facebook/dinov3-vitb16-pretrain-lvd1689m",
        "model_cls": AutoModel,
        "model_name": "facebook/dinov3-vitb16-pretrain-lvd1689m",
        "pooling": "cls",
        "probe_skip": 1,
    },
    "MetaCLIP": {
        "processor_cls": AutoProcessor,
        "processor_name": "facebook/metaclip-b16-400m",
        "model_cls": AutoModelForZeroShotImageClassification,
        "model_name": "facebook/metaclip-b16-400m",
        "pooling": "mean",
        "probe_skip": 1,
    },
    "IJEPA-1k": {
        "processor_cls": AutoImageProcessor,
        "processor_name": "facebook/ijepa_vith14_1k",
        "model_cls": AutoModel,
        "model_name": "facebook/ijepa_vith14_1k",
        "pooling": "mean",
        "probe_skip": 1,
    },
    "IJEPA-22k": {
        "processor_cls": AutoImageProcessor,
        "processor_name": "facebook/ijepa_vith14_22k",
        "model_cls": AutoModel,
        "model_name": "facebook/ijepa_vith14_22k",
        "pooling": "mean",
        "probe_skip": 1,
    },
}
# -----------------------------
# Utilities
# -----------------------------
def build_datasets(
    dataset_name: str = "clane9/imagenet-100",
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    """Load the train/validation splits of a Hugging Face image dataset.

    Generalized from the hard-coded ImageNet-100 repo id: callers may pass any
    HF dataset with ``train``/``validation`` splits; the default preserves the
    original behavior.

    Args:
        dataset_name: Hugging Face dataset repository id.

    Returns:
        ``(train_dataset, val_dataset)`` wrapped as ``spt.data.HFDataset``,
        each with an RGB-conversion transform applied per image.
    """
    train_dataset = spt.data.HFDataset(
        dataset_name,
        split="train",
        transform=transforms.RGB(),
    )
    val_dataset = spt.data.HFDataset(
        dataset_name,
        split="validation",
        transform=transforms.RGB(),
    )
    return train_dataset, val_dataset
def make_collate_fn(processor):
    """Build a DataLoader collate function backed by a HF image processor.

    The returned callable turns a list of ``{"image", "label"}`` examples into
    ``{"images": <processor output>, "label": LongTensor}``.
    """

    def collate_fn(examples):
        pixel_batch = processor(
            images=[sample["image"] for sample in examples],
            return_tensors="pt",
        )
        targets = torch.tensor(
            [sample["label"] for sample in examples], dtype=torch.long
        )
        return {"images": pixel_batch, "label": targets}

    return collate_fn
def build_dataloaders(
    train_dataset,
    val_dataset,
    processor,
    batch_size: int = 128,
    num_workers: int = 6,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Wrap the two datasets in DataLoaders that batch through the processor.

    Train uses a repeated random sampler and drops the last partial batch;
    validation iterates sequentially over the full split.
    """
    collate = make_collate_fn(processor)
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        sampler=spt.data.sampler.RepeatedRandomSampler(train_dataset),
        drop_last=True,
        **shared_kwargs,
    )
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, **shared_kwargs)
    return train_loader, val_loader
def load_backbone(model_name: str):
    """Instantiate the processor and backbone described by ``MODEL_ZOO``.

    For CLIP-style checkpoints the vision tower is unwrapped and its config
    (under ``.vision_config``) is used for sizing.

    Returns:
        ``(model, processor, emb_dim, num_hidden_layers, pooling, probe_skip)``
    """
    spec = MODEL_ZOO[model_name]
    processor = spec["processor_cls"].from_pretrained(spec["processor_name"])  # type: ignore
    model = spec["model_cls"].from_pretrained(
        spec["model_name"], output_hidden_states=True
    )  # type: ignore

    is_clip = "CLIP" in model_name
    config = model.config.vision_config if is_clip else model.config
    if is_clip:
        model = model.vision_model

    return (
        model,
        processor,
        config.hidden_size,
        config.num_hidden_layers,
        spec["pooling"],
        spec.get("probe_skip", 1),
    )
# -----------------------------
# Lightning-compatible `spt.Module`
# -----------------------------
def build_module(
    model, processor, transformer_block_indices: List[int], pooling: str
) -> spt.Module:
    """Wrap the backbone in an ``spt.Module`` that emits per-layer embeddings.

    The module's forward returns ``{"embedding_layer_{i}": (B, D) tensor}`` for
    every block index in ``transformer_block_indices``, pooled per ``pooling``.
    Embeddings are detached, so the backbone receives no gradients.
    """
    # NOTE(review): device is chosen once at build time; on multi-GPU this
    # always picks the default CUDA device — confirm against the Trainer setup.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, batch: Dict, stage: str):  # noqa: ARG001 (stage provided by spt)
        moved = {
            key: tensor.to(device=device, non_blocking=True)
            for key, tensor in batch["images"].items()
        }
        outputs = self.model(**moved, output_hidden_states=True)
        # hidden_states tuple layout: [embeddings, block1, block2, ...]
        hidden_states = outputs["hidden_states"]

        features: Dict[str, torch.Tensor] = {}
        for idx in transformer_block_indices:
            tokens = hidden_states[1 + idx]  # +1 skips the embedding output
            if pooling == "cls":
                pooled = tokens[:, 0]
            elif pooling == "mean":
                pooled = tokens.mean(dim=1)
            else:
                raise ValueError(f"Unknown pooling type: {pooling}")
            features[f"embedding_layer_{idx}"] = pooled.detach()
        return features

    return spt.Module(
        model=model,
        forward=forward,
        processor=processor,
        optim=None,  # each probe brings its own optimizer
    )
# -----------------------------
# Probes
# -----------------------------
def build_probes(
    module, emb_dim: int, num_classes: int, transformer_block_indices: List[int]
):
    """Create one online linear probe (BatchNorm + Linear) per selected block.

    Each probe reads ``embedding_layer_{i}`` from the module's forward output,
    is trained with cross-entropy against ``label``, and logs top-1/top-5
    accuracy. Probes optimize independently via SGD with cosine annealing.
    """

    def _probe_for(block_idx: int):
        head = nn.Sequential(
            nn.BatchNorm1d(emb_dim),
            nn.Linear(emb_dim, num_classes),
        )
        return spt.callbacks.OnlineProbe(
            module,
            target="label",
            name=f"linear_probe_block_{block_idx}",
            input=f"embedding_layer_{block_idx}",
            probe=head,
            loss_fn=nn.CrossEntropyLoss(),
            metrics={
                "top1": torchmetrics.classification.MulticlassAccuracy(num_classes),
                "top5": torchmetrics.classification.MulticlassAccuracy(
                    num_classes, top_k=5
                ),
            },
            optimizer={"type": "SGD", "lr": 1e-3},
            scheduler={"type": "CosineAnnealingLR", "T_max": 100},
        )

    return [_probe_for(i) for i in transformer_block_indices]
# -----------------------------
# Main
# -----------------------------
@hydra.main(config_path="config_examples", config_name="multi_probe")
def main(cfg: DictConfig):
    """Train per-layer linear probes on a frozen pretrained vision backbone.

    Expects the hydra config to provide: seed, model (a MODEL_ZOO key),
    batch_size, num_workers, epochs, use_wandb, and project.
    """
    pl.seed_everything(cfg.seed, workers=True)

    # Backbone, feature-extraction module, and the block indices to probe.
    backbone, processor, emb_dim, num_layers, pooling, probe_skip = load_backbone(
        cfg.model
    )
    block_indices = list(range(0, num_layers, probe_skip))
    module = build_module(backbone, processor, block_indices, pooling)

    # Data pipeline (ImageNet-100 by default).
    train_ds, val_ds = build_datasets()
    train_loader, val_loader = build_dataloaders(
        train_ds,
        val_ds,
        processor,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
    )
    data = spt.data.DataModule(train=train_loader, val=val_loader)

    # One probe per selected transformer block; 100 classes for ImageNet-100.
    probes = build_probes(
        module,
        emb_dim=emb_dim,
        num_classes=100,
        transformer_block_indices=block_indices,
    )

    # Trainer: mixed precision on GPU, optional Weights & Biases logging.
    logger = None
    if cfg.use_wandb and WandbLogger is not None:
        logger = WandbLogger(project=cfg.project)
    checkpoint_callback = ModelCheckpoint(filename="ckpt", save_last=True)
    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        callbacks=probes + [checkpoint_callback],
        precision="16-mixed" if torch.cuda.is_available() else 32,
        logger=logger,
        enable_checkpointing=True,
    )

    # Hand everything to the spt manager, which drives the fit loop.
    manager = spt.Manager(trainer=trainer, module=module, data=data)
    manager()
# Entry point: hydra parses CLI overrides and injects the composed config.
if __name__ == "__main__":
    main()