OnlineKNN
- class stable_pretraining.callbacks.OnlineKNN(name: str, input: str, target: str, queue_length: int, metrics: Dict, input_dim: Tuple[int, ...] | List[int] | int | None = None, target_dim: int | None = None, num_classes: int | None = None, k: int = 5, temperature: float = 0.07, chunk_size: int = -1, distance_metric: Literal['euclidean', 'squared_euclidean', 'cosine', 'manhattan'] = 'euclidean', verbose: bool = None)
Bases: Callback
Weighted K-Nearest Neighbors online evaluator using queue discovery.
This callback implements a weighted KNN classifier that evaluates the quality of learned representations during training. It automatically discovers or creates OnlineQueue callbacks to maintain circular buffers of features and labels, then uses this cached data to compute KNN predictions during validation.
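The circular buffer behind each queue can be pictured with a small fixed-capacity FIFO. The sketch below is a hypothetical, pure-Python stand-in, not the OnlineQueue API: the class name RingBuffer and its methods append_batch and snapshot are invented for illustration, and the real queue presumably stores tensors rather than Python lists. It only shows the eviction semantics that queue_length controls.

```python
from collections import deque
from typing import Deque, List, Tuple


class RingBuffer:
    """Hypothetical fixed-size FIFO cache of (feature, label) pairs.

    Once `queue_length` items are stored, each new item evicts the oldest one.
    """

    def __init__(self, queue_length: int) -> None:
        if queue_length <= 0:
            raise ValueError("queue_length must be positive")
        # deque(maxlen=...) silently drops from the left when full.
        self._items: Deque[Tuple[List[float], int]] = deque(maxlen=queue_length)

    def append_batch(self, features: List[List[float]], labels: List[int]) -> None:
        # Add one training batch, sample by sample.
        for feat, label in zip(features, labels):
            self._items.append((feat, label))

    def snapshot(self) -> Tuple[List[List[float]], List[int]]:
        # Current contents, oldest first.
        feats = [f for f, _ in self._items]
        labels = [y for _, y in self._items]
        return feats, labels

    def __len__(self) -> int:
        return len(self._items)


buf = RingBuffer(queue_length=3)
buf.append_batch([[0.0], [1.0]], [0, 1])
buf.append_batch([[2.0], [3.0]], [2, 3])  # sample with label 0 is evicted
feats, labels = buf.snapshot()
print(labels)  # [1, 2, 3]
```

A larger queue_length keeps older samples around longer, which makes the neighbor pool more representative at the cost of memory.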
The KNN evaluation is performed by:
1. Finding the k nearest neighbors in the feature space.
2. Weighting neighbors by inverse distance with temperature scaling.
3. Using weighted voting to produce class predictions.
4. Computing the specified metrics on the predictions.
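The four steps above can be sketched in plain Python for a single query at a time. This is a minimal illustration under assumptions, not the callback's implementation: the weighting exp(-(d - d_min) / T) is one common temperature-scaled choice and may differ from the library's exact "inverse distance" formula, Euclidean distance is used throughout, and the helper names knn_class_scores and knn_predict are hypothetical.

```python
import math
from typing import List, Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def knn_class_scores(
    query: Sequence[float],
    bank_feats: List[Sequence[float]],
    bank_labels: List[int],
    num_classes: int,
    k: int = 5,
    temperature: float = 0.07,
) -> List[float]:
    """Normalized per-class vote weights for one query (bank must be non-empty)."""
    # Step 1: distances to every cached feature, then keep the k closest.
    dists = [euclidean(query, f) for f in bank_feats]
    k = min(k, len(dists))
    nearest = sorted(range(len(dists)), key=dists.__getitem__)[:k]
    # Step 2 (assumed form): weight = exp(-(d - d_min) / T). Subtracting d_min
    # avoids underflow and cancels after normalization; smaller T favors closer points.
    d_min = dists[nearest[0]]
    scores = [0.0] * num_classes
    for i in nearest:
        # Step 3: each neighbor adds its weight to its own label's score.
        scores[bank_labels[i]] += math.exp(-(dists[i] - d_min) / temperature)
    total = sum(scores)
    return [s / total for s in scores]


def knn_predict(queries, bank_feats, bank_labels, num_classes, k=5, temperature=0.07):
    """Predicted label (arg-max of the weighted votes) for each query."""
    preds = []
    for q in queries:
        scores = knn_class_scores(q, bank_feats, bank_labels, num_classes, k, temperature)
        preds.append(max(range(num_classes), key=scores.__getitem__))
    return preds


# Tiny two-class example: class 0 near the origin, class 1 near (1, 1).
bank_feats = [[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0], [0.0, 0.2]]
bank_labels = [0, 0, 1, 1, 0]
queries = [[0.05, 0.05], [1.0, 0.95]]
targets = [0, 1]

preds = knn_predict(queries, bank_feats, bank_labels, num_classes=2, k=3)
# Step 4: a metric on the predictions (plain accuracy here).
accuracy = sum(p == t for p, t in zip(preds, targets)) / len(targets)
print(preds, accuracy)  # [0, 1] 1.0
```

For the second query, one of its three nearest neighbors has label 0, but that neighbor is far enough away that its weight is negligible at T = 0.07, so the vote goes to class 1.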
Note
This callback auto-creates its own input and target OnlineQueue callbacks if none with matching keys are registered, so users typically only need to add OnlineKNN itself. Pass a manually registered OnlineQueue only to override the default queue length or to share a queue across multiple consumers.
- Parameters:
name – Unique identifier for this callback instance. Used for logging and storing metrics.
input – Key in batch dict containing input features to evaluate.
target – Key in batch dict containing ground truth target labels.
queue_length – Size of the circular buffer for caching features and labels. Larger values provide more representative samples but use more memory.
metrics – Dictionary of metrics to compute during validation. Keys are metric names, values are metric instances (e.g., torchmetrics.Accuracy).
input_dim – Expected dimensionality of input features. Can be int, tuple/list (will be flattened to product), or None to accept any dimension.
target_dim – Expected dimensionality of targets. None accepts any dimension.
num_classes – Total number of classes in the dataset. If None (default), the class count is inferred from the maximum label observed in the queue and the current batch. Always pass this explicitly when possible: inference can produce a count smaller than the true number of classes when the queue has not yet seen every class (early training, small queue, many classes). The prediction tensor is then narrower than the metric expects; for example, torchmetrics.MulticlassAccuracy(10) crashes if predictions have shape (B, 7) instead of (B, 10).
k – Number of nearest neighbors to consider for voting. Default is 5.
temperature – Temperature parameter for distance weighting. Lower values give more weight to closer neighbors. Default is 0.07.
chunk_size – Batch size for memory-efficient distance computation. Set to -1 to compute all distances at once. Default is -1.
distance_metric – Distance metric for finding nearest neighbors. Options are ‘euclidean’, ‘squared_euclidean’, ‘cosine’, ‘manhattan’. Default is ‘euclidean’.
verbose – If True, log extra per-step detail. None inherits the global spt verbosity setting.
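The num_classes pitfall can be reproduced without any model. In the sketch below, the rule "largest observed label + 1" is an assumed reading of "inferred from the maximum label observed", and the function names infer_num_classes and prediction_width are hypothetical; the point is only that an incomplete queue yields a score vector narrower than a 10-class metric expects.

```python
from typing import List, Optional


def infer_num_classes(queue_labels: List[int], batch_labels: List[int]) -> int:
    # Assumed inference rule: largest label seen so far, plus one.
    return max(queue_labels + batch_labels) + 1


def prediction_width(
    queue_labels: List[int], batch_labels: List[int], num_classes: Optional[int]
) -> int:
    # Width of the per-sample score vector, i.e. the C in a (B, C) prediction.
    if num_classes is not None:
        return num_classes
    return infer_num_classes(queue_labels, batch_labels)


# Early in training on a 10-class dataset: only labels 0..6 have been seen.
queue_labels = [0, 3, 6, 2, 5, 1]
batch_labels = [4, 6, 0]
inferred = prediction_width(queue_labels, batch_labels, num_classes=None)
explicit = prediction_width(queue_labels, batch_labels, num_classes=10)

expected = 10  # what a 10-class metric expects
for name, width in [("inferred", inferred), ("explicit", explicit)]:
    status = "ok" if width == expected else "shape mismatch"
    print(f"{name}: (B, {width}) -> {status}")
# inferred: (B, 7) -> shape mismatch
# explicit: (B, 10) -> ok
```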
- Raises:
ValueError – If k <= 0, temperature <= 0, or chunk_size is invalid.
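The documented ValueError conditions amount to a few argument checks. The helper below is hypothetical, and it assumes that a "valid" chunk_size means either -1 (no chunking) or a positive integer; the library's exact rule for chunk_size is not spelled out here.

```python
def validate_knn_args(k: int, temperature: float, chunk_size: int) -> None:
    """Hypothetical checks mirroring the documented ValueError conditions."""
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    # Assumption: -1 disables chunking; any other value must be a positive size.
    if chunk_size != -1 and chunk_size <= 0:
        raise ValueError(f"chunk_size must be -1 or positive, got {chunk_size}")


validate_knn_args(k=5, temperature=0.07, chunk_size=-1)   # defaults pass
validate_knn_args(k=5, temperature=0.07, chunk_size=256)  # explicit chunking passes
try:
    validate_knn_args(k=0, temperature=0.07, chunk_size=-1)
except ValueError as err:
    print(err)  # k must be positive, got 0
```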
Note
- The callback automatically handles distributed training by gathering data.
- Mixed precision is supported through automatic dtype conversion.
- Predictions are stored in the batch dict under the key ‘{name}_preds’.
- Metrics are logged with the prefix ‘eval/{name}_’.
- on_validation_batch_end(trainer: Trainer, pl_module: LightningModule, outputs: Dict, batch: Dict, batch_idx: int, dataloader_idx: int = 0) → None
Compute KNN predictions during validation.
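The distance step of this method is where chunk_size and distance_metric come into play. The sketch below is illustrative only: pairwise_distances and DISTANCES are hypothetical names, cosine distance is assumed to be 1 minus cosine similarity, and pure Python stands in for the tensor operations the callback presumably uses. It shows that chunking changes how the work is split, not the resulting distances.

```python
import math
from typing import Callable, Dict, List, Sequence

Vector = Sequence[float]


def _cosine_distance(a: Vector, b: Vector) -> float:
    # 1 - cosine similarity; the tiny floor guards against zero-norm vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / max(norm, 1e-12)


DISTANCES: Dict[str, Callable[[Vector, Vector], float]] = {
    "euclidean": lambda a, b: math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b))),
    "squared_euclidean": lambda a, b: sum((x - y) ** 2 for x, y in zip(a, b)),
    "cosine": _cosine_distance,
    "manhattan": lambda a, b: sum(abs(x - y) for x, y in zip(a, b)),
}


def pairwise_distances(
    queries: List[Vector], bank: List[Vector], metric: str = "euclidean", chunk_size: int = -1
) -> List[List[float]]:
    """Distance from every query to every bank vector, computed chunk by chunk."""
    dist = DISTANCES[metric]
    # chunk_size == -1 means a single chunk containing all queries.
    step = max(1, len(queries)) if chunk_size == -1 else chunk_size
    rows: List[List[float]] = []
    for start in range(0, len(queries), step):
        # Each chunk is handled independently; in a tensor implementation this
        # bounds the size of the intermediate (chunk x bank) distance matrix.
        for q in queries[start : start + step]:
            rows.append([dist(q, b) for b in bank])
    return rows


queries = [[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]
bank = [[0.0, 0.0], [1.0, 1.0]]
full = pairwise_distances(queries, bank, "manhattan", chunk_size=-1)
chunked = pairwise_distances(queries, bank, "manhattan", chunk_size=2)
print(full)  # [[1.0, 1.0], [2.0, 2.0], [7.0, 5.0]]
print(full == chunked)  # True
```

A smaller chunk_size trades speed for peak memory; the neighbor search that follows sees identical distances either way.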