Source code for stable_pretraining.callbacks.image_retrieval

import types
from typing import List, Optional, Union

import numpy as np
import torch
from lightning.pytorch import Callback, LightningModule, Trainer
from loguru import logger as logging
from torchmetrics.retrieval.base import RetrievalMetric

from .registry import log_dict as _spt_log_dict
from .utils import format_metrics_as_dict


def wrap_validation_step(fn, input, name, callback):
    """Wrap ``validation_step`` to populate the callback's embedding buffer.

    The wrapped function captures the ``ImageRetrieval`` instance via the
    closure so embeddings live on ``callback.embeds`` instead of mutating
    ``pl_module.embeds`` — keeping multiple ImageRetrieval instances isolated.

    Args:
        fn: The original ``validation_step``; it is called as
            ``fn(batch, batch_idx)`` and its return value is used as the
            batch dict from then on.
        input: Key of the returned batch holding the per-sample embeddings.
        name: Key into ``self.callbacks_modules`` under which the
            ``"normalizer"`` module is stored.
        callback: Object exposing ``embeds`` (tensor or ``None``) and
            ``features_dim`` (int or ``None``); both may be written here.

    Returns:
        A function ``ffn(self, batch, batch_idx)`` meant to be bound to the
        module (e.g. with ``types.MethodType``). It returns whatever ``fn``
        returned.

    Raises:
        ValueError: From the returned function, when ``callback.features_dim``
            is set and differs from the dimension of the embeddings actually
            produced — whether or not the buffer was already allocated.
    """

    def ffn(
        self,
        batch,
        batch_idx,
        fn=fn,
        name=name,
        input=input,
        callback=callback,
    ):
        batch = fn(batch, batch_idx)

        with torch.no_grad():
            norm = self.callbacks_modules[name]["normalizer"](batch[input])
            norm = torch.nn.functional.normalize(norm, dim=1, p=2)

        # Gather before branching on rank so that every rank takes part in
        # the gather, even though only local rank 0 stores the result.
        idx = self.all_gather(batch["sample_idx"])
        norm = self.all_gather(norm)

        if self.local_rank == 0:
            emitted_dim = int(norm.shape[-1])

            # Validate on every call, not only on lazy allocation: when the
            # buffer was allocated ahead of time from a user-supplied
            # ``features_dim`` the lazy branch below is skipped, and a wrong
            # value would otherwise surface as an opaque indexing error.
            if (
                callback.features_dim is not None
                and callback.features_dim != emitted_dim
            ):
                raise ValueError(
                    f"ImageRetrieval[{name}]: features_dim="
                    f"{callback.features_dim} but first batch produced "
                    f"embeddings of dim {emitted_dim}. Either pass the "
                    f"correct features_dim or omit it for inference."
                )

            # Lazy allocation: infer ``features_dim`` from the first batch's
            # normalized embeddings if the user didn't specify it upfront.
            if callback.embeds is None:
                dataset_size = len(self.trainer.datamodule.val.dataset)
                callback.embeds = torch.zeros(
                    (dataset_size, emitted_dim), device=norm.device
                )
                callback.features_dim = emitted_dim

            # The buffer uses the default float dtype while embeddings may be
            # half / bfloat16 / double; indexed assignment needs equal dtypes.
            callback.embeds[idx] = norm.to(callback.embeds.dtype)

        return batch

    return ffn


class ImageRetrieval(Callback):
    """Image Retrieval evaluator for self-supervised learning.

    The implementation follows:
      1. https://github.com/facebookresearch/dino/blob/main/eval_image_retrieval.py

    Args:
        pl_module: The ``spt.LightningModule`` to evaluate against.
        name: Unique identifier (used as key in ``callbacks_modules`` and
            ``callbacks_metrics``). Two instances with the same ``name`` raise.
        input: Key in ``batch`` containing per-sample embeddings.
        query_col: Boolean column on the val dataset marking query rows.
        retrieval_col: Single column name or list — each value is a list of
            gallery indices that count as relevant for that query.
        metrics: ``torchmetrics.retrieval.RetrievalMetric`` instances keyed by
            display name.
        features_dim: Output dimension of the embedding. If ``None`` (default)
            the dimension is inferred from the first validation batch. If
            provided, it must match what the model emits — a mismatch raises.
            A list/tuple is flattened to the product of its entries.
        normalizer: ``"batch_norm"``, ``"layer_norm"``, or ``None`` (identity)
            applied to embeddings before L2-normalization. When using
            ``"batch_norm"`` or ``"layer_norm"``, ``features_dim`` must be set
            explicitly because the normalizer module needs to be built at
            ``__init__`` time.

    Raises:
        ValueError: If ``name`` is already registered on ``pl_module``, if
            ``normalizer`` is not a supported value, if a norm layer is
            requested without ``features_dim``, or if any entry of ``metrics``
            is not a ``RetrievalMetric``.
    """

    NAME = "ImageRetrieval"

    def __init__(
        self,
        pl_module,
        name: str,
        input: str,
        query_col: str,
        retrieval_col: str | List[str],
        metrics,
        features_dim: Optional[Union[tuple[int], list[int], int]] = None,
        normalizer: Optional[str] = None,
    ) -> None:
        super().__init__()
        logging.info(f"Setting up callback ({self.NAME})")
        logging.info(f"\t- {input=}")
        logging.info(f"\t- {query_col=}")
        logging.info("\t- caching modules into `callbacks_modules`")

        # ---- validate everything before touching ``pl_module`` so a bad
        # ---- argument never leaves a half-registered callback behind.
        if name in pl_module.callbacks_modules:
            raise ValueError(f"{name=} already used in callbacks")

        # isinstance (rather than an exact type match) also accepts tuple
        # subclasses such as ``torch.Size``.
        if isinstance(features_dim, (list, tuple)):
            features_dim = int(np.prod(features_dim))

        if normalizer is not None and normalizer not in ["batch_norm", "layer_norm"]:
            raise ValueError(
                "`normalizer` has to be one of `batch_norm` or `layer_norm`"
            )

        # batch_norm / layer_norm need a known input dim at construction time;
        # identity is dim-agnostic and lets us defer.
        if normalizer in ("batch_norm", "layer_norm") and features_dim is None:
            raise ValueError(
                f"normalizer={normalizer!r} requires features_dim to be "
                "specified explicitly; only normalizer=None supports "
                "inference at first batch."
            )

        for k, metric in metrics.items():
            if not isinstance(metric, RetrievalMetric):
                raise ValueError(
                    f"Only `RetrievalMetric` is supported for {self.NAME} callback, but got {metric} for {k}"
                )

        # ---- build and register the normalizer module
        if normalizer == "batch_norm":
            normalizer = torch.nn.BatchNorm1d(features_dim, affine=False)
        elif normalizer == "layer_norm":
            normalizer = torch.nn.LayerNorm(
                features_dim, elementwise_affine=False, bias=False
            )
        else:
            normalizer = torch.nn.Identity()

        pl_module.callbacks_modules[name] = torch.nn.ModuleDict(
            {
                "normalizer": normalizer,
            }
        )
        logging.info(
            f"`callbacks_modules` now contains ({list(pl_module.callbacks_modules.keys())})"
        )

        if not isinstance(retrieval_col, list):
            retrieval_col = [retrieval_col]

        logging.info("\t- caching metrics into `callbacks_metrics`")
        pl_module.callbacks_metrics[name] = format_metrics_as_dict(metrics)

        self.name = name
        self.features_dim = features_dim
        self.query_col = query_col
        self.retrieval_col = retrieval_col
        # Embedding buffer lives on the callback instance (not on pl_module)
        # so multiple ImageRetrieval callbacks don't clobber each other's state.
        self.embeds: Optional[torch.Tensor] = None

        logging.info("\t- wrapping the `validation_step`")
        fn = wrap_validation_step(pl_module.validation_step, input, name, callback=self)
        pl_module.validation_step = types.MethodType(fn, pl_module)

    @property
    def state_key(self) -> str:
        """Unique identifier for this callback's state during checkpointing."""
        return f"ImageRetrieval[name={self.name}]"

    def on_validation_epoch_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Eagerly allocate the embeds buffer if ``features_dim`` was given.

        Skipped otherwise — allocation happens lazily on the first batch
        inside the wrapped ``validation_step``.
        """
        if pl_module.local_rank == 0 and self.features_dim is not None:
            val_dataset = pl_module.trainer.datamodule.val.dataset
            dataset_size = len(val_dataset)
            self.embeds = torch.zeros(
                (dataset_size, self.features_dim), device=pl_module.device
            )

    def on_validation_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Score the collected embeddings on local rank 0, then synchronise.

        The rank-0 work is delegated to a helper so that its early exits
        (no embeddings, size mismatch, no queries) cannot skip the barrier
        below: every rank must reach it, otherwise the remaining ranks would
        wait on a rank that already left.
        """
        if pl_module.local_rank == 0:
            self._compute_and_log(pl_module)

        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        # Drop the buffer so the next validation epoch starts from scratch.
        self.embeds = None

    def _compute_and_log(self, pl_module: LightningModule) -> None:
        """Compute every retrieval metric from ``self.embeds`` and log it."""
        logging.info(f"Computing results for {self.name} callback")
        val_dataset = pl_module.trainer.datamodule.val.dataset.dataset

        if self.embeds is None:
            logging.warning(
                f"{self.name}: no embeddings collected (zero validation "
                "batches?). Skipping evaluation."
            )
            return

        if len(self.embeds) != len(val_dataset):
            logging.warning(
                f"Expected {len(val_dataset)} embeddings, but got "
                f"{len(self.embeds)}. Skipping evaluation."
            )
            return

        # Force a flat boolean mask: with a 0/1 integer column, indexing
        # would pick rows 0 and 1 repeatedly and ``~`` would be a bitwise
        # NOT rather than a logical one.
        is_query = (
            torch.tensor(val_dataset[self.query_col], device=pl_module.device)
            .reshape(-1)
            .to(torch.bool)
        )
        query_idx = torch.nonzero(is_query).reshape(-1)

        query = self.embeds[is_query]
        gallery = self.embeds[~is_query]
        num_queries, num_gallery = len(query), len(gallery)

        if num_queries == 0 or num_gallery == 0:
            logging.warning(
                f"{self.name}: found {num_queries} query rows and "
                f"{num_gallery} gallery rows. Skipping evaluation."
            )
            return

        # Similarity of every query against every gallery row; the embeddings
        # were L2-normalised when they were stored.
        score = query @ gallery.t()

        # Pre-extract the retrieval-target columns once. Per-row
        # ``val_dataset[q_idx][col]`` triggers the dataset's transform
        # pipeline — re-decoding the image for every query just to read an
        # index list. Column-wise access returns the whole column in one
        # fetch instead.
        ret_cache = {col: val_dataset[col] for col in self.retrieval_col}

        targets = torch.zeros(
            (num_queries, num_gallery), dtype=torch.bool, device=pl_module.device
        )
        for pos, row in enumerate(query_idx.tolist()):
            target = targets[pos]  # view: writes land in ``targets``
            for col in self.retrieval_col:
                ret_idx = ret_cache[col][row]
                if ret_idx:
                    target[ret_idx] = True

        # Flatten to the (preds, target, indexes) triple the metrics take:
        # one entry per (query, gallery) pair, grouped by the query's
        # dataset index.
        preds = score.reshape(-1)
        targets = targets.reshape(-1)
        indexes = query_idx.repeat_interleave(num_gallery)

        logs = {}
        for k, metric in pl_module.callbacks_metrics[self.name]["_val"].items():
            res = metric(preds, targets, indexes=indexes)
            logs[f"eval/{self.name}_{k}"] = res.item() * 100

        _spt_log_dict(logs, on_epoch=True, rank_zero_only=True)

        logging.info(f"Finished computing results for {self.name} callback")