ImageRetrieval

class stable_pretraining.callbacks.ImageRetrieval(pl_module, name: str, input: str, query_col: str, retrieval_col: str | List[str], metrics, features_dim: tuple[int] | list[int] | int | None = None, normalizer: str = None)

Bases: Callback

Image Retrieval evaluator for self-supervised learning.

The implementation follows:
  1. facebookresearch/dino
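A DINO-style retrieval evaluation ranks gallery items by the cosine similarity of L2-normalized embeddings and scores each query's ranking. The simplified NumPy sketch below illustrates that idea; it is not the callback's code. It assumes that every non-query row belongs to the gallery and that each relevance list holds positions within that gallery subset. It also computes mean average precision by hand in place of the `torchmetrics` metric instances the callback receives. All helper names are hypothetical.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm so dot products become cosine similarities."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, eps)


def average_precision(scores: np.ndarray, relevant: set) -> float:
    """AP of one query: mean of precision@k over the ranks k where relevant items appear."""
    order = np.argsort(-scores)  # gallery positions, most similar first
    hits, precisions = 0, []
    for rank, idx in enumerate(order, start=1):
        if int(idx) in relevant:
            hits += 1
            precisions.append(hits / rank)
    return float(np.mean(precisions)) if precisions else 0.0


def retrieval_map(embeds, is_query, relevant_lists) -> float:
    """Mean AP over queries.

    Assumption: rows with is_query == False form the gallery, and
    relevant_lists[i] holds gallery positions relevant to the i-th query.
    """
    embeds = l2_normalize(np.asarray(embeds, dtype=np.float64))
    is_query = np.asarray(is_query, dtype=bool)
    queries, gallery = embeds[is_query], embeds[~is_query]
    sims = queries @ gallery.T  # (num_queries, num_gallery) cosine similarities
    aps = [average_precision(s, set(rel)) for s, rel in zip(sims, relevant_lists)]
    return float(np.mean(aps))
```

For example, with two queries and two gallery rows, `retrieval_map(embeds, [True, True, False, False], [[0], [1]])` returns 1.0 when each query's nearest gallery row is its relevant one.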

Parameters:
  • pl_module – The spt.LightningModule to evaluate against.

  • name – Unique identifier (used as the key in callbacks_modules and callbacks_metrics). Registering two instances with the same name raises an error.

  • input – Key in the batch that holds the per-sample embeddings.

  • query_col – Boolean column in the validation dataset that marks which rows are queries.

  • retrieval_col – Single column name or list of column names. Each value is a list of gallery indices that count as relevant for that query.

  • metrics – torchmetrics.retrieval.RetrievalMetric instances keyed by display name.

  • features_dim – Output dimension of the embedding. If None (default), the dimension is inferred from the first validation batch. If provided, it must match the dimension the model emits; a mismatch raises an error.

  • normalizer – "batch_norm", "layer_norm", or None (identity); applied to embeddings before L2 normalization. When using "batch_norm" or "layer_norm", features_dim must be set explicitly because the normalizer module is built at __init__ time.
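The ordering described for normalizer (optional normalization, then L2 normalization), and the reason features_dim must be known up front, can be illustrated with a small NumPy sketch. `LayerNormSketch` and `prepare_embeddings` are hypothetical. `LayerNormSketch` stands in for a learnable layer such as `torch.nn.LayerNorm(features_dim)`, whose affine parameters are sized by the feature dimension when the module is constructed. `torch.nn.BatchNorm1d(features_dim)` has the same requirement.

```python
import numpy as np


class LayerNormSketch:
    """Hypothetical stand-in for a learnable layer norm over the last axis."""

    def __init__(self, features_dim: int, eps: float = 1e-5):
        # The affine parameters are sized by features_dim, so the dimension
        # must be known when the module is built, before any batch is seen.
        self.weight = np.ones(features_dim)
        self.bias = np.zeros(features_dim)
        self.eps = eps

    def __call__(self, x: np.ndarray) -> np.ndarray:
        if x.shape[-1] != self.weight.shape[0]:
            raise ValueError(
                f"expected {self.weight.shape[0]} features, got {x.shape[-1]}"
            )
        mean = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(var + self.eps) * self.weight + self.bias


def prepare_embeddings(x: np.ndarray, normalizer=None) -> np.ndarray:
    """Apply the optional normalizer (identity when None), then L2-normalize rows."""
    if normalizer is not None:
        x = normalizer(x)
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12)
```

Because the parameters exist before the first batch, a mismatched embedding width surfaces as an error instead of silently building a layer of the wrong size.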

on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) → None

Called when the validation epoch ends.

on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) → None

Eagerly allocate the embeds buffer if features_dim was given.

Otherwise this step is skipped, and allocation happens lazily on the first batch inside the wrapped validation_step.
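The eager and lazy allocation paths can be contrasted in a short sketch. `EmbedBuffer` and its methods are hypothetical and only mirror the behavior described above. When features_dim is given, the buffer is allocated at epoch start. Otherwise the first stored batch fixes the dimension, and later batches must match it.

```python
from typing import Optional

import numpy as np


class EmbedBuffer:
    """Hypothetical buffer that stores one embedding per validation sample."""

    def __init__(self, num_samples: int, features_dim: Optional[int] = None):
        self.num_samples = num_samples
        self.features_dim = features_dim
        self.embeds = None

    def on_epoch_start(self) -> None:
        # Eager path: the dimension is known, so allocate before any batch arrives.
        if self.features_dim is not None:
            self.embeds = np.zeros((self.num_samples, self.features_dim))

    def store(self, indices: np.ndarray, batch_embeds: np.ndarray) -> None:
        dim = batch_embeds.shape[1]
        if self.embeds is None:
            # Lazy path: infer the dimension from the first batch.
            self.features_dim = dim
            self.embeds = np.zeros((self.num_samples, dim))
        elif dim != self.embeds.shape[1]:
            raise ValueError(f"expected dim {self.embeds.shape[1]}, got {dim}")
        # Write each sample's embedding into its row of the buffer.
        self.embeds[indices] = batch_embeds
```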

property state_key: str

Unique identifier for this callback’s state during checkpointing.
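In Lightning, a callback's state is saved in the checkpoint under `checkpoint["callbacks"][state_key]`, so each stateful instance needs a distinct key. The sketch below uses a hypothetical class and an assumed key format (class name plus the user-given name) to show why the unique name matters. Instances with different names get separate entries, while identical keys would overwrite each other.

```python
class RetrievalCallbackSketch:
    """Hypothetical callback whose checkpoint key includes its unique name."""

    def __init__(self, name: str):
        self.name = name

    @property
    def state_key(self) -> str:
        # Assumed format for illustration: class name plus the user-given name.
        return f"{type(self).__qualname__}[name={self.name}]"

    def state_dict(self) -> dict:
        return {"name": self.name}


# Two differently named instances occupy separate checkpoint entries.
callbacks = [RetrievalCallbackSketch("oxford"), RetrievalCallbackSketch("paris")]
checkpoint = {"callbacks": {cb.state_key: cb.state_dict() for cb in callbacks}}
```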