Multi-layer Probe for Vision Models

Multi-layer Probe for Vision Models

Train probes attached to multiple layers of a frozen backbone to monitor representation quality across depth.

import hydra
import lightning as pl
import torch
import torchmetrics
from typing import Dict, List, Tuple
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger  # type: ignore
from omegaconf import DictConfig
from torch import nn
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoModelForZeroShotImageClassification,
    AutoProcessor,
)

import stable_pretraining as spt
from stable_pretraining.data import transforms


# -----------------------------
# Model registry
# -----------------------------


def _zoo_entry(processor_cls, model_cls, checkpoint, pooling, probe_skip=1):
    """Build one ``MODEL_ZOO`` record for a Hugging Face checkpoint.

    Every backbone registered below loads its processor and its weights from
    the same hub id, so ``checkpoint`` fills both ``processor_name`` and
    ``model_name``.
    """
    return {
        "processor_cls": processor_cls,
        "processor_name": checkpoint,
        "model_cls": model_cls,
        "model_name": checkpoint,
        "pooling": pooling,
        "probe_skip": probe_skip,
    }


# Supported backbones, keyed by the name handed to `load_backbone`.
#   pooling    -- "cls" keeps token 0, "mean" averages over tokens (see `build_module`)
#   probe_skip -- stride between probed transformer blocks (see `main`)
MODEL_ZOO = {
    "DINOv2": _zoo_entry(
        AutoImageProcessor, AutoModel, "facebook/dinov2-base", "cls"
    ),
    "DINOv3": _zoo_entry(
        AutoImageProcessor,
        AutoModel,
        "facebook/dinov3-vitb16-pretrain-lvd1689m",
        "cls",
    ),
    "MetaCLIP": _zoo_entry(
        AutoProcessor,
        AutoModelForZeroShotImageClassification,
        "facebook/metaclip-b16-400m",
        "mean",
    ),
    "IJEPA-1k": _zoo_entry(
        AutoImageProcessor, AutoModel, "facebook/ijepa_vith14_1k", "mean"
    ),
    "IJEPA-22k": _zoo_entry(
        AutoImageProcessor, AutoModel, "facebook/ijepa_vith14_22k", "mean"
    ),
}


# -----------------------------
# Utilities
# -----------------------------


def build_datasets() -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
    """Return the ``(train, validation)`` splits of ``clane9/imagenet-100``.

    Each split is wrapped in an ``spt.data.HFDataset`` whose only transform is
    ``transforms.RGB()``; no model-specific preprocessing is applied in this
    function.
    """
    hub_id = "clane9/imagenet-100"
    train_split, val_split = (
        spt.data.HFDataset(hub_id, split=split_name, transform=transforms.RGB())
        for split_name in ("train", "validation")
    )
    return train_split, val_split


def make_collate_fn(processor):
    """Create a ``collate_fn`` that batches samples through ``processor``.

    Args:
        processor: Callable invoked as
            ``processor(images=<list>, return_tensors="pt")``.

    Returns:
        A function mapping a list of samples (each indexable by ``"image"`` and
        ``"label"``) to ``{"images": <processor output>, "label": LongTensor}``.
    """

    def collate_fn(examples):
        raw_images = [sample["image"] for sample in examples]
        targets = torch.tensor(
            [sample["label"] for sample in examples], dtype=torch.long
        )
        # The processor sees the whole batch at once and returns PyTorch tensors.
        processed = processor(images=raw_images, return_tensors="pt")
        return {"images": processed, "label": targets}

    return collate_fn


def build_dataloaders(
    train_dataset,
    val_dataset,
    processor,
    batch_size: int = 128,
    num_workers: int = 6,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Wrap the two datasets in ``DataLoader``s that collate via ``processor``.

    Args:
        train_dataset: Dataset used for training batches.
        val_dataset: Dataset used for validation batches.
        processor: Passed to ``make_collate_fn`` to build each loader's collate.
        batch_size: Batch size for both loaders.
        num_workers: Worker count for both loaders.

    Returns:
        ``(train_loader, val_loader)``. The training loader draws indices from
        ``spt.data.sampler.RepeatedRandomSampler`` and drops the last incomplete
        batch; the validation loader is given neither a sampler nor ``drop_last``.
    """
    loader_cls = torch.utils.data.DataLoader
    # Settings common to both loaders.
    common = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
    }

    train_loader = loader_cls(
        dataset=train_dataset,
        sampler=spt.data.sampler.RepeatedRandomSampler(train_dataset),
        drop_last=True,
        collate_fn=make_collate_fn(processor),
        **common,
    )
    val_loader = loader_cls(
        dataset=val_dataset,
        collate_fn=make_collate_fn(processor),
        **common,
    )
    return train_loader, val_loader


def load_backbone(model_name: str):
    """Load a pretrained backbone and its processor from ``MODEL_ZOO``.

    Args:
        model_name: Key into ``MODEL_ZOO``. Names containing ``"CLIP"`` are
            treated as dual-tower checkpoints: only their ``vision_model`` is
            returned and sizes are read from ``config.vision_config``.

    Returns:
        Tuple ``(model, processor, emb_dim, num_hidden_layers, pooling,
        probe_skip)`` where ``emb_dim`` is the config's ``hidden_size`` and
        ``probe_skip`` defaults to 1 when the registry entry omits it.

    Raises:
        KeyError: If ``model_name`` is not registered; the message lists the
            available names.
    """
    try:
        spec = MODEL_ZOO[model_name]
    except KeyError:
        available = ", ".join(sorted(MODEL_ZOO))
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None

    processor = spec["processor_cls"].from_pretrained(spec["processor_name"])  # type: ignore
    model = spec["model_cls"].from_pretrained(
        spec["model_name"], output_hidden_states=True
    )  # type: ignore

    # Decide once whether this is a CLIP-style checkpoint; the answer selects both
    # the config to read and whether to unwrap the vision tower below.
    is_clip = "CLIP" in model_name
    config = model.config.vision_config if is_clip else model.config
    emb_dim = config.hidden_size
    num_hidden_layers = config.num_hidden_layers
    pooling = spec["pooling"]
    probe_skip = spec.get("probe_skip", 1)

    if is_clip:
        # Only the image encoder is probed.
        model = model.vision_model

    return model, processor, emb_dim, num_hidden_layers, pooling, probe_skip


# -----------------------------
# Lightning-compatible `spt.Module`
# -----------------------------


def build_module(
    model, processor, transformer_block_indices: List[int], pooling: str
) -> spt.Module:
    """Wrap ``model`` in an ``spt.Module`` that emits one pooled embedding per block.

    Args:
        model: Backbone called as ``model(**inputs, output_hidden_states=True)``;
            its output must expose ``"hidden_states"``.
        processor: Stored on the module via the ``processor`` keyword.
        transformer_block_indices: Zero-based block indices to expose. Block ``i``
            is read from ``hidden_states[1 + i]`` because entry 0 holds the
            embedding output.
        pooling: ``"cls"`` keeps token 0 of each sequence, ``"mean"`` averages
            over the token axis.

    Returns:
        An ``spt.Module`` whose forward returns
        ``{"embedding_layer_<i>": tensor}`` for every requested block.

    Raises:
        ValueError: If ``pooling`` is neither ``"cls"`` nor ``"mean"``.
    """
    # Fail at build time rather than on the first training batch.
    if pooling not in ("cls", "mean"):
        raise ValueError(f"Unknown pooling type: {pooling}")

    # Used only if the module exposes no `device` attribute.
    fallback_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define the forward used by `spt.Module`
    def forward(self, batch: Dict, stage: str):  # noqa: ARG001 (stage provided by spt)
        # Follow the device the module actually lives on, so inputs and weights
        # stay together even when that is not the default CUDA device.
        device = getattr(self, "device", fallback_device)
        images = {
            k: v.to(device=device, non_blocking=True)
            for k, v in batch["images"].items()
        }

        # Every returned embedding is detached, so no autograd graph is needed
        # for the backbone pass.
        with torch.no_grad():
            outputs = self.model(**images, output_hidden_states=True)
        hiddens = outputs["hidden_states"]  # tuple: [embeddings, block1, block2, ...]

        out: Dict[str, torch.Tensor] = {}
        # Pool tokens per requested block -> (B, D)
        for i in transformer_block_indices:
            x = hiddens[1 + i]
            x = x[:, 0] if pooling == "cls" else x.mean(dim=1)
            out[f"embedding_layer_{i}"] = x.detach()
        return out

    # NOTE(review): the backbone is passed as-is and is not switched to eval mode
    # here -- confirm whether `spt.backbone.EvalOnly` should wrap it.
    module = spt.Module(
        model=model,
        forward=forward,
        processor=processor,
        optim=None,  # probes have their own optimizers
    )
    return module


# -----------------------------
# Probes
# -----------------------------


def build_probes(
    module, emb_dim: int, num_classes: int, transformer_block_indices: List[int]
):
    """Create one ``OnlineProbe`` callback per probed transformer block.

    Args:
        module: The ``spt.Module`` each probe is attached to.
        emb_dim: Feature width fed to every probe head.
        num_classes: Number of target classes (head output size and metric size).
        transformer_block_indices: Blocks to probe; probe ``i`` reads
            ``embedding_layer_<i>`` and is named ``linear_probe_block_<i>``.

    Returns:
        List of probes, in the order of ``transformer_block_indices``.
    """
    accuracy_cls = torchmetrics.classification.MulticlassAccuracy

    def _probe_for(block_idx: int):
        # BatchNorm followed by a single linear layer over the pooled embedding.
        head = nn.Sequential(
            nn.BatchNorm1d(emb_dim),
            nn.Linear(emb_dim, num_classes),
        )
        return spt.callbacks.OnlineProbe(
            module,
            target="label",
            name=f"linear_probe_block_{block_idx}",
            input=f"embedding_layer_{block_idx}",
            probe=head,
            loss=nn.CrossEntropyLoss(),
            metrics={
                "top1": accuracy_cls(num_classes),
                "top5": accuracy_cls(num_classes, top_k=5),
            },
            optimizer={"type": "SGD", "lr": 1e-3},
            scheduler={"type": "CosineAnnealingLR", "T_max": 100},
        )

    return [_probe_for(block_idx) for block_idx in transformer_block_indices]


# -----------------------------
# Main
# -----------------------------


@hydra.main(config_path="config_examples", config_name="multi_probe")
def main(cfg: DictConfig):
    """Train per-block linear probes on top of the backbone named by ``cfg.model``.

    Reads ``seed``, ``model``, ``batch_size``, ``num_workers``, ``use_wandb``,
    ``project`` (only when W&B logging is enabled) and ``epochs`` from ``cfg``.
    """
    pl.seed_everything(cfg.seed, workers=True)

    # Backbone & module
    backbone, processor, emb_dim, depth, pooling, stride = load_backbone(cfg.model)
    # Probe every `stride`-th transformer block, starting from block 0
    block_ids = list(range(0, depth, stride))
    module = build_module(backbone, processor, block_ids, pooling)

    # Data
    train_ds, val_ds = build_datasets()
    train_loader, val_loader = build_dataloaders(
        train_ds,
        val_ds,
        processor,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
    )
    data = spt.data.DataModule(train=train_loader, val=val_loader)

    # Probes (the dataset built above has 100 classes)
    probes = build_probes(
        module,
        emb_dim=emb_dim,
        num_classes=100,
        transformer_block_indices=block_ids,
    )

    # Trainer: mixed precision only when a GPU is available
    precision = "16-mixed" if torch.cuda.is_available() else 32
    wandb_enabled = cfg.use_wandb and WandbLogger is not None
    logger = WandbLogger(project=cfg.project) if wandb_enabled else None
    checkpoint_callback = ModelCheckpoint(filename="ckpt", save_last=True)

    trainer = pl.Trainer(
        max_epochs=cfg.epochs,
        callbacks=[*probes, checkpoint_callback],
        precision=precision,
        logger=logger,
        enable_checkpointing=True,
    )

    # Run
    run = spt.Manager(trainer=trainer, module=module, data=data)
    run()


if __name__ == "__main__":
    main()

Gallery generated by Sphinx-Gallery