ImageRetrieval
- class stable_pretraining.callbacks.ImageRetrieval(pl_module, name: str, input: str, query_col: str, retrieval_col: str | List[str], metrics, features_dim: tuple[int] | list[int] | int | None = None, normalizer: str = None)
Bases: Callback
Image Retrieval evaluator for self-supervised learning.
- The implementation follows:
- Parameters:
pl_module – The spt.LightningModule to evaluate against.
name – Unique identifier, used as the key in callbacks_modules and callbacks_metrics. Creating two instances with the same name raises an error.
input – Key in batch that holds the per-sample embeddings.
query_col – Boolean column on the validation dataset marking which rows are queries.
retrieval_col – A single column name or a list of column names. Each value in the column(s) is a list of gallery indices that count as relevant for that query.
metrics – torchmetrics.retrieval.RetrievalMetric instances keyed by display name.
features_dim – Output dimension of the embedding. If None (default), the dimension is inferred from the first validation batch. If provided, it must match what the model emits; a mismatch raises an error.
normalizer – "batch_norm", "layer_norm", or None (identity), applied to the embeddings before L2 normalization. With "batch_norm" or "layer_norm", features_dim must be set explicitly, because the normalizer module has to be built at __init__ time.
- on_validation_epoch_end(trainer: Trainer, pl_module: LightningModule) → None
Called when the validation epoch ends.
- on_validation_epoch_start(trainer: Trainer, pl_module: LightningModule) → None
Eagerly allocates the embeds buffer if features_dim was given. Otherwise this step is skipped, and allocation happens lazily on the first batch inside the wrapped validation_step.
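The eager/lazy allocation split, together with the features_dim mismatch check, can be sketched as follows. EmbeddingBuffer, on_epoch_start, and add_batch are hypothetical names, and the buffer is a plain list of rows rather than whatever tensor the callback actually allocates. In the real callback these roles belong to on_validation_epoch_start and the wrapped validation_step.

```python
from typing import List, Optional, Sequence


class EmbeddingBuffer:
    """Hypothetical stand-in for the callback's embeds buffer (not the library's class)."""

    def __init__(self, num_samples: int, features_dim: Optional[int] = None):
        self.num_samples = num_samples
        self.features_dim = features_dim  # None means "infer from the first batch"
        self.dim: Optional[int] = None  # dimension in use for the current epoch
        self.rows: Optional[List[List[float]]] = None

    def on_epoch_start(self) -> None:
        # Drop the previous epoch's buffer, then allocate eagerly if the dim is known.
        self.rows, self.dim = None, None
        if self.features_dim is not None:
            self._allocate(self.features_dim)

    def add_batch(self, indices: Sequence[int], embeds: Sequence[Sequence[float]]) -> None:
        batch_dim = len(embeds[0])
        if self.rows is None:
            # Lazy path: no buffer yet, so size it from features_dim or this first batch.
            self._allocate(self.features_dim if self.features_dim is not None else batch_dim)
        if batch_dim != self.dim:
            raise ValueError(f"expected features_dim={self.dim}, got {batch_dim}")
        for idx, emb in zip(indices, embeds):
            self.rows[idx] = list(emb)

    def _allocate(self, dim: int) -> None:
        self.dim = dim
        self.rows = [[0.0] * dim for _ in range(self.num_samples)]
```

Passing features_dim lets the buffer exist before any batch arrives and turns a wrong dimension into an immediate error; leaving it as None trades that for automatic inference from the first batch.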