OnlineProbe
- class stable_pretraining.callbacks.OnlineProbe(module: LightningModule, name: str, input: str, target: str, probe: Module, loss: callable = None, optimizer: str | dict | partial | Optimizer | None = None, scheduler: str | dict | partial | LRScheduler | None = None, accumulate_grad_batches: int = 1, gradient_clip_val: float = None, gradient_clip_algorithm: str = 'norm', metrics: dict | tuple | list | Metric | None = None, verbose: bool = None)
Bases: TrainableCallback
Online probe for evaluating learned representations during self-supervised training.
This callback implements the standard linear evaluation protocol by training a probe (typically a linear classifier) on top of frozen features from the main model. The probe is trained simultaneously with the main model but maintains its own optimizer, scheduler, and training loop. This allows monitoring representation quality throughout training without modifying the base model.
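The idea of training a probe on frozen features can be sketched in plain PyTorch. This is a minimal illustration, not the callback's actual implementation: the backbone, probe, tensor shapes, and the SGD optimizer are stand-ins chosen for the example. It shows that detaching the features before the probe's forward pass keeps the probe loss from producing gradients in the main model.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-ins for illustration only (not part of stable_pretraining).
backbone = nn.Linear(8, 4)   # plays the role of the main SSL model
probe = nn.Linear(4, 3)      # linear probe on top of the features
loss_fn = nn.CrossEntropyLoss()
probe_opt = torch.optim.SGD(probe.parameters(), lr=0.1)  # probe-only optimizer

x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

features = backbone(x)
# detach() cuts the autograd graph, so the probe loss cannot reach the backbone.
logits = probe(features.detach())
loss = loss_fn(logits, y)

probe_opt.zero_grad()
loss.backward()
probe_opt.step()

# Only the probe received gradients; the backbone parameters are untouched.
backbone_has_grad = any(p.grad is not None for p in backbone.parameters())
probe_has_grad = all(p.grad is not None for p in probe.parameters())
```

Because the probe has its own optimizer, the main model's training step is unaffected by whatever learning rate or schedule the probe uses.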
Key features:
- Automatic gradient detachment, which prevents probe gradients from affecting the main model.
- Independent optimizer and scheduler management.
- Support for gradient accumulation.
- Mixed precision training compatibility through automatic dtype conversion.
- Metric tracking and logging.
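As a rough illustration of gradient accumulation, the helper below lists the batch indices on which an optimizer step would happen. It assumes the step fires once every accumulate_grad_batches batches, counting from the first batch; the function name is hypothetical, and the callback's real scheduling (for example at epoch boundaries) may differ.

```python
def optimizer_step_batches(num_batches: int, accumulate_grad_batches: int = 1) -> list[int]:
    """Return the 0-based batch indices after which the optimizer would step.

    Assumption for this sketch: gradients are accumulated over consecutive
    groups of `accumulate_grad_batches` batches, and a trailing incomplete
    group does not trigger a step.
    """
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    return [
        i for i in range(num_batches)
        if (i + 1) % accumulate_grad_batches == 0
    ]


every_batch = optimizer_step_batches(3, 1)   # no accumulation
every_fourth = optimizer_step_batches(8, 4)  # accumulate over 4 batches
```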
- Parameters:
module – The spt.LightningModule to probe.
name – Unique identifier for this probe instance. Used for logging and for storing metrics and modules.
input – Key in the batch dict or outputs dict containing the input features to probe.
target – Key in the batch dict containing the ground truth target labels.
probe – The probe module to train. Can be an nn.Module instance, a callable that returns a module, or a Hydra config to instantiate.
loss – Loss function for probe training (e.g., nn.CrossEntropyLoss()).
optimizer – Optimizer configuration for the probe. Accepted forms: a string name ("AdamW", "SGD", "LARS"), a dict like {"type": "AdamW", "lr": 1e-3, ...}, a pre-configured functools.partial, an optimizer instance, or a callable. None (default) uses LARS(lr=0.1, clip_lr=True, eta=0.02, exclude_bias_n_norm=True, weight_decay=0), the standard choice for SSL linear probes.
scheduler – Learning rate scheduler configuration. Accepted forms: a string name ("CosineAnnealingLR", "StepLR"), a dict like {"type": "CosineAnnealingLR", "T_max": 1000, ...}, a partial, a scheduler instance, or a callable. None (default) uses ConstantLR(factor=1.0), i.e. a constant learning rate.
accumulate_grad_batches – Number of batches over which to accumulate gradients before each optimizer step. Default is 1 (no accumulation).
gradient_clip_val – Threshold for clipping probe gradients. None disables clipping. Default is None.
gradient_clip_algorithm – Clipping algorithm: "norm" (clip by L2 norm) or "value" (clip element-wise). Default is "norm".
metrics – Metrics to track during training/validation. Can be a dict, list, tuple, or single metric instance.
verbose – If True, log extra per-step detail. None inherits the global spt verbosity setting.
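The dict forms of the optimizer and scheduler arguments can be collected into a plain keyword dictionary, as in the sketch below. The probe name and the batch keys "embedding" and "label" are hypothetical placeholders, and the arguments that need live objects (module, probe, loss, metrics) are left out so the fragment stays self-contained.

```python
# Names documented in the OnlineProbe signature above.
DOCUMENTED_PARAMS = {
    "module", "name", "input", "target", "probe", "loss", "optimizer",
    "scheduler", "accumulate_grad_batches", "gradient_clip_val",
    "gradient_clip_algorithm", "metrics", "verbose",
}

probe_kwargs = {
    "name": "linear_probe",        # hypothetical probe identifier
    "input": "embedding",          # hypothetical key holding the features
    "target": "label",             # hypothetical key holding the labels
    "optimizer": {"type": "AdamW", "lr": 1e-3},
    "scheduler": {"type": "CosineAnnealingLR", "T_max": 1000},
    "accumulate_grad_batches": 1,  # step on every batch
    "gradient_clip_val": None,     # clipping disabled
    "gradient_clip_algorithm": "norm",
}

# Guard against typos: every key must be a documented parameter.
unknown = set(probe_kwargs) - DOCUMENTED_PARAMS
```

With stable_pretraining installed, these keywords would be passed alongside the remaining arguments, for example OnlineProbe(module=my_module, probe=my_probe, loss=my_loss, **probe_kwargs), where my_module, my_probe, and my_loss are objects you supply.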
Note
The probe module is stored in pl_module.callbacks_modules[name].
Metrics are stored in pl_module.callbacks_metrics[name].
Predictions are stored in the batch dict under the key '{name}_preds'.
The loss is logged as 'train/{name}_loss'.
Metrics are logged with the prefixes 'train/{name}_' and 'eval/{name}_'.
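The naming scheme in the note can be spelled out with a small helper. The function probe_log_keys is hypothetical and not part of the library; it also assumes that each metric's own name is appended directly after the documented prefix, which the note does not state explicitly.

```python
def probe_log_keys(name: str, metric_names: list[str]) -> dict:
    """Build the keys described in the note for a probe called `name`.

    Assumption for this sketch: a metric named "top1" is logged as
    prefix + "top1".
    """
    return {
        "preds": f"{name}_preds",        # key added to the batch dict
        "loss": f"train/{name}_loss",    # logged training loss
        "train_metrics": [f"train/{name}_{m}" for m in metric_names],
        "eval_metrics": [f"eval/{name}_{m}" for m in metric_names],
    }


keys = probe_log_keys("linear_probe", ["top1", "top5"])
```

Choosing a distinct name per probe keeps these keys from colliding when several probes are attached to the same module.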
- configure_model(pl_module: LightningModule) -> Module
Initialize the probe module from configuration.