SimCLR

class stable_pretraining.methods.SimCLR(encoder_name: str | Module = 'vit_small_patch16_224', projector_dims: Sequence[int] = (2048, 2048, 256), temperature: float = 0.5, low_resolution: bool = False, pretrained: bool = False)

Bases: Module

SimCLR: contrastive joint-embedding self-supervised learning.

Architecture:
  • Backbone: any feature extractor producing a flat [B, D] embedding (e.g. a timm ViT or ResNet with its classification head removed)

  • Projector: 2- or 3-layer MLP mapping features to the contrastive space

  • Loss: NT-Xent (normalised temperature-scaled cross entropy)
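The NT-Xent loss can be sketched in NumPy as below. This is an illustrative reference version of the standard formulation, not the library's internal code: the function name `nt_xent` is hypothetical, and averaging over all 2B anchors (each view acting as an anchor for its partner) is an assumption about how the loss is reduced.

```python
import numpy as np


def nt_xent(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """NT-Xent loss for two batches of projections, each of shape [B, D].

    Row i of z1 and row i of z2 form a positive pair; the remaining
    2B - 2 samples in the combined batch act as negatives.
    """
    z = np.concatenate([z1, z2], axis=0)               # [2B, D]
    z = z / np.linalg.norm(z, axis=1, keepdims=True)    # L2-normalise each row
    sim = z @ z.T / temperature                         # cosine similarity / tau
    np.fill_diagonal(sim, -np.inf)                      # a sample is never its own negative
    b = z1.shape[0]
    # index of the positive partner for every row: i <-> i + B
    pos = np.concatenate([np.arange(b, 2 * b), np.arange(0, b)])
    # numerically stable log-softmax over each row
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    # cross entropy against the positive, averaged over all 2B anchors (assumed reduction)
    return float(-log_prob[np.arange(2 * b), pos].mean())
```

For perfectly aligned, mutually orthogonal pairs the loss shrinks as the temperature falls, because a smaller temperature sharpens the softmax around the positive.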

Parameters:
  • encoder_name – timm model name (e.g. "vit_small_patch16_224", "resnet50") or a pre-instantiated nn.Module whose forward returns a [B, D] tensor.

  • projector_dims – Hidden and output dimensions of the MLP projector. (2048, 2048, 128) matches the original SimCLR ResNet-50 recipe. The projector's input dimension comes from the encoder (for ViT backbones, its embed_dim).

  • temperature – Temperature for NT-Xent (0.5 in the original SimCLR; 0.1 is common for larger or harder batches).

  • low_resolution – Adapt first conv for 32x32 inputs (CIFAR-style).

  • pretrained – Load pretrained timm weights for the encoder.
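The mapping from projector_dims to projector layers can be sketched as follows. This NumPy version is hypothetical (`build_projector` and `project` are illustrative names), and it assumes a plain Linear → ReLU layout with no activation after the last layer; the library's projector may differ, for example by adding normalisation layers.

```python
from typing import List, Optional, Sequence, Tuple

import numpy as np

Layer = Tuple[np.ndarray, np.ndarray]  # (weight [d_in, d_out], bias [d_out])


def build_projector(in_dim: int, projector_dims: Sequence[int],
                    rng: Optional[np.random.Generator] = None) -> List[Layer]:
    """One linear layer per entry: in_dim -> dims[0] -> ... -> dims[-1]."""
    rng = np.random.default_rng(0) if rng is None else rng
    dims = [in_dim, *projector_dims]
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        w = rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, d_out))
        layers.append((w, np.zeros(d_out)))
    return layers


def project(h: np.ndarray, layers: List[Layer]) -> np.ndarray:
    """Apply the MLP, with ReLU after every layer except the last (assumed layout)."""
    for i, (w, b) in enumerate(layers):
        h = h @ w + b
        if i < len(layers) - 1:
            h = np.maximum(h, 0.0)
    return h
```

With vit_small_patch16_224 (embed_dim 384) and the default (2048, 2048, 256), this gives three linear layers, 384 → 2048 → 2048 → 256. For resnet50 the encoder output is 2048-dimensional, so only the first layer's input size changes.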

Example:

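import torch
from stable_pretraining.methods import SimCLR

model = SimCLR(
    encoder_name="vit_small_patch16_224",
    projector_dims=(2048, 2048, 256),
    temperature=0.2,
)

# training: two augmented views of the same batch (random tensors as placeholders)
v1 = torch.randn(64, 3, 224, 224)
v2 = torch.randn(64, 3, 224, 224)
out = model(v1, v2)
out.loss.backward()

# eval: single view, no loss; disable autograd for inference
model.eval()
with torch.no_grad():
    out = model(v1)
features = out.embedding  # [64, embed_dim]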
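The example above uses random tensors as placeholder views. In SimCLR the two views are independent random augmentations of the same images; the original recipe combines random resized cropping, colour distortion and Gaussian blur. The NumPy sketch below is a deliberately simplified, hypothetical helper that is not part of stable_pretraining. `random_view` and `two_views` only implement a square random crop with nearest-neighbour resizing and a random horizontal flip.

```python
import numpy as np


def random_view(img: np.ndarray, out_size: int, rng: np.random.Generator,
                scale=(0.08, 1.0)) -> np.ndarray:
    """Simplified augmentation of one [C, H, W] image (hypothetical helper)."""
    c, h, w = img.shape
    area = rng.uniform(*scale) * h * w
    side = int(np.clip(np.sqrt(area), 1, min(h, w)))   # square crop for simplicity
    top = rng.integers(0, h - side + 1)
    left = rng.integers(0, w - side + 1)
    crop = img[:, top:top + side, left:left + side]
    idx = (np.arange(out_size) * side / out_size).astype(int)  # nearest-neighbour resize
    view = crop[:, idx][:, :, idx]                      # [C, out_size, out_size]
    if rng.random() < 0.5:
        view = view[:, :, ::-1]                         # horizontal flip
    return view


def two_views(batch: np.ndarray, out_size: int = 224, seed: int = 0):
    """Two independently augmented views of every image in a [B, C, H, W] batch."""
    rng = np.random.default_rng(seed)
    v1 = np.stack([random_view(x, out_size, rng) for x in batch])
    v2 = np.stack([random_view(x, out_size, rng) for x in batch])
    return v1, v2
```

`np.stack` returns contiguous copies, so the views can be passed to the model with `torch.from_numpy(v1).float()`.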

forward(view1: Tensor, view2: Tensor | None = None) → SimCLROutput

Forward pass.

Parameters:
  • view1 – First augmented view [B, C, H, W], or the single input view at evaluation time.

  • view2 – Second augmented view [B, C, H, W]. If None, only the backbone embedding is returned and no loss is computed (eval mode).

Returns:

SimCLROutput.
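The single-view and two-view behaviour of forward can be sketched as below. This models the contract described above; it is not the library's code. `OutputSketch` and `forward_sketch` are hypothetical names, and the encoder, projector and loss are passed in as plain callables. Returning the concatenated features of both views as `embedding` in the two-view case is an assumption, because this page does not specify it.

```python
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np

Array = np.ndarray


@dataclass
class OutputSketch:
    """Hypothetical stand-in for SimCLROutput (only the fields used in the class example)."""
    embedding: Array               # backbone features
    loss: Optional[float] = None   # None when a single view is passed


def forward_sketch(view1: Array, view2: Optional[Array],
                   encoder: Callable[[Array], Array],
                   projector: Callable[[Array], Array],
                   loss_fn: Callable[[Array, Array], float]) -> OutputSketch:
    """Single view -> embedding only; two views -> embedding plus contrastive loss."""
    h1 = encoder(view1)                           # [B, D]
    if view2 is None:
        return OutputSketch(embedding=h1)         # eval path: no projector, no loss
    h2 = encoder(view2)                           # same encoder for both views
    loss = loss_fn(projector(h1), projector(h2))  # loss uses projections, not features
    # assumption: both views' features are returned, stacked along the batch axis
    return OutputSketch(embedding=np.concatenate([h1, h2]), loss=loss)
```

Keeping the projector out of the single-view path reflects the usual SimCLR practice of using backbone features, not projections, for downstream tasks.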