RankMe

class stable_pretraining.callbacks.RankMe(name: str, target: str, queue_length: int, target_shape: int | Iterable[int], verbose: bool = None)

Bases: Callback

Monitors the RankMe score (effective rank) of feature embeddings, finding or creating the queue callback that caches the target features.

RankMe measures the effective rank of feature representations by computing the exponential of the entropy of normalized singular values. This metric helps detect dimensional collapse in self-supervised learning.
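
As a rough illustration of the definition above, the sketch below computes the exponential of the Shannon entropy of a list of singular values after normalizing them to sum to one. It is not the library's implementation: the function name `rankme` and the small `eps` term (added to avoid `log(0)`) are assumptions made for this example, and the singular values are assumed to have been obtained already, for example from an SVD of the embedding matrix.

```python
import math


def rankme(singular_values, eps=1e-7):
    """Effective rank: exp of the entropy of normalized singular values.

    Hypothetical helper for illustration; `eps` is an assumed stabilizer.
    """
    total = sum(singular_values)
    # Normalize to a distribution; eps keeps log() defined for zero entries.
    probs = [s / total + eps for s in singular_values]
    entropy = -sum(p * math.log(p) for p in probs)
    return math.exp(entropy)


# Equal singular values: the score is close to the number of dimensions.
full = rankme([1.0, 1.0, 1.0, 1.0])
# A single non-zero singular value: the score is close to 1 (collapse).
collapsed = rankme([1.0, 0.0, 0.0, 0.0])
```

With four equal singular values the score is approximately 4, while a spectrum dominated by a single value gives a score near 1, which is the signature of dimensional collapse.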

Parameters:
  • name – Unique name for this callback instance. Used for logging and metric keys.

  • target – Key in the batch dict containing the feature embeddings to monitor.

  • queue_length – Size of the circular buffer for caching embeddings across validation batches. Larger values give a more representative estimate.

  • target_shape – Shape of the target embeddings: either a single int (e.g., 768) or a sequence whose product is used (e.g., (16, 48), whose product is 768).

  • verbose – If True, log entropy, top singular value, and condition number in addition to the RankMe score. None inherits the global spt verbosity setting.
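
To make the `queue_length` and `target_shape` descriptions more concrete, here is a minimal sketch using only the standard library. It is an assumption about the general idea, not the callback's actual code: `flat_dim` is a hypothetical helper showing how an int or a sequence reduces to a single feature dimension, and a `collections.deque` with `maxlen` stands in for a fixed-size circular buffer.

```python
import math
from collections import deque


def flat_dim(target_shape):
    """Return one feature dimension for an int or a sequence of ints.

    Hypothetical helper: a sequence is reduced to the product of its entries.
    """
    if isinstance(target_shape, int):
        return target_shape
    return math.prod(target_shape)


# A bounded deque behaves like a circular buffer: once it is full,
# appending a new item silently drops the oldest one.
queue = deque(maxlen=4)  # stands in for queue_length=4
for batch_id in range(6):
    queue.append(batch_id)
```

After the loop the buffer holds only the four most recent items, which is why a larger `queue_length` lets the metric be computed over more cached embeddings.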

on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: dict, batch: dict, batch_idx: int, dataloader_idx: int = 0) → None

Compute the RankMe metric on the first validation batch only.

setup(trainer: Trainer, pl_module: LightningModule, stage: str) → None

Find or create the queue callback for target features.

property state_key: str

Unique identifier for this callback’s state during checkpointing.