SimCLR
- class stable_pretraining.methods.SimCLR(encoder_name: str | Module = 'vit_small_patch16_224', projector_dims: Sequence[int] = (2048, 2048, 256), temperature: float = 0.5, low_resolution: bool = False, pretrained: bool = False)
Bases: Module
SimCLR: contrastive joint-embedding self-supervised learning.
- Architecture:
Backbone: any feature extractor producing a flat [B, D] embedding (timm ViT/ResNet with the head removed)
Projector: 2- or 3-layer MLP mapping features to the contrastive space
Loss: NT-Xent (normalised temperature-scaled cross entropy)
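The NT-Xent loss named above can be sketched in plain Python to make the arithmetic explicit. This is a dependency-free illustration, not the library's implementation: the function names `l2_normalize` and `nt_xent` are hypothetical, and a practical version would be vectorised in PyTorch. For each of the 2N embeddings, the other view of the same sample is the positive and the remaining 2N - 2 embeddings are negatives.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length (assumes the vector is non-zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def nt_xent(z1, z2, temperature=0.5):
    """Mean NT-Xent loss over the 2N anchors formed by N pairs (z1[i], z2[i]).

    Hypothetical reference sketch: cosine similarity divided by the
    temperature, then a softmax cross entropy against the positive.
    """
    n = len(z1)
    z = [l2_normalize(v) for v in list(z1) + list(z2)]  # 2N unit vectors
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the other view of the same sample
        # Similarities to every embedding except the anchor itself.
        logits = [
            sum(a * b for a, b in zip(z[i], z[k])) / temperature
            for k in range(2 * n)
            if k != i
        ]
        pos_logit = sum(a * b for a, b in zip(z[i], z[pos])) / temperature
        # Numerically stable log-sum-exp for the softmax denominator.
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_denom - pos_logit
    return total / (2 * n)


# Two samples with orthogonal embeddings; the second call swaps the pairing.
aligned = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
mismatched = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned, mismatched)
```

Correctly paired views give the lower loss, and lowering the temperature sharpens the softmax so that well-aligned pairs are penalised less.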
- Parameters:
encoder_name – timm model name (e.g. "vit_small_patch16_224", "resnet50") or a pre-instantiated nn.Module whose forward returns a [B, D] tensor.
projector_dims – Hidden + output dimensions of the MLP projector. (2048, 2048, 128) matches the original SimCLR ResNet50 recipe; for ViT backbones the input is taken from the encoder embed_dim.
temperature – Temperature for NT-Xent (0.5 in original SimCLR; 0.1 is common for harder/larger batches).
low_resolution – Adapt first conv for 32x32 inputs (CIFAR-style).
pretrained – Load pretrained timm weights for the encoder.
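To illustrate how projector_dims relates to the encoder width, the sketch below lists the linear-layer sizes a projector would have for a given embedding dimension. The helper `projector_layer_shapes` is hypothetical and not part of the library. It tracks only linear layer sizes, since the source does not say whether normalisation or activation layers sit between them. The widths used are the standard ones for these backbones: 384 for ViT-S/16 and 2048 for ResNet-50 after pooling.

```python
from typing import List, Sequence, Tuple


def projector_layer_shapes(
    embed_dim: int, projector_dims: Sequence[int]
) -> List[Tuple[int, int]]:
    """Hypothetical helper: (in_features, out_features) per linear layer."""
    dims = [embed_dim, *projector_dims]
    return list(zip(dims[:-1], dims[1:]))


# ViT-S/16 with the class default projector_dims.
vit_shapes = projector_layer_shapes(384, (2048, 2048, 256))
# ResNet-50 with the original SimCLR recipe.
resnet_shapes = projector_layer_shapes(2048, (2048, 2048, 128))

print(vit_shapes)     # [(384, 2048), (2048, 2048), (2048, 256)]
print(resnet_shapes)  # [(2048, 2048), (2048, 2048), (2048, 128)]
```

The last entry of projector_dims sets the dimension of the contrastive space, and the first layer's input size follows from the encoder rather than from this argument.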
Example:
    import torch
    from stable_pretraining.methods import SimCLR

    model = SimCLR(
        encoder_name="vit_small_patch16_224",
        projector_dims=(2048, 2048, 256),
        temperature=0.2,
    )
    v1 = torch.randn(64, 3, 224, 224)
    v2 = torch.randn(64, 3, 224, 224)

    # train: two views in, loss out
    out = model(v1, v2)
    out.loss.backward()

    # eval: single view, no loss
    model.eval()
    out = model(v1)
    features = out.embedding  # [64, embed_dim]