Source code for stable_pretraining.methods.simclr

"""SimCLR: Simple Contrastive Learning of Representations.

Self-supervised learning via maximizing agreement between two augmented views
of the same image using NT-Xent contrastive loss.

References:
    Chen et al. "A Simple Framework for Contrastive Learning of Visual
    Representations." ICML 2020. https://arxiv.org/abs/2002.05709

Example::

    from stable_pretraining.methods import SimCLR
    import lightning as pl

    model = SimCLR(encoder_name="vit_small_patch16_224")

    trainer = pl.Trainer(max_epochs=300)
    trainer.fit(model, dataloader)
"""

from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import BatchNorm1dNoBias, from_timm
from stable_pretraining.losses import NTXEntLoss


@dataclass
class SimCLROutput(ModelOutput):
    """Output from SimCLR forward pass.

    :ivar loss: NT-Xent contrastive loss (0 in eval mode)
    :ivar embedding: Backbone features [B, D] in eval mode, [2B, D] in train mode
    :ivar projection: Projector outputs [2B, P] (None in eval mode)
    """

    loss: torch.Tensor = None
    embedding: torch.Tensor = None
    projection: Optional[torch.Tensor] = None


def _build_projector(
    in_dim: int,
    hidden_dims: Sequence[int],
    final_bn_no_bias: bool = True,
) -> nn.Module:
    """Standard SimCLR projector: Linear -> BN -> ReLU -> ... -> Linear -> BN(no bias).

    :param in_dim: Backbone output dimension
    :param hidden_dims: Sequence of hidden + output dimensions, e.g. (2048, 2048, 128)
    :param final_bn_no_bias: Use ``BatchNorm1dNoBias`` on the final layer (SimCLR original)
    """
    if len(hidden_dims) < 1:
        raise ValueError("hidden_dims must contain at least one entry (the output dim)")
    layers = []
    prev = in_dim
    for i, dim in enumerate(hidden_dims):
        is_last = i == len(hidden_dims) - 1
        layers.append(nn.Linear(prev, dim, bias=False))
        if is_last:
            layers.append(
                BatchNorm1dNoBias(dim) if final_bn_no_bias else nn.BatchNorm1d(dim)
            )
        else:
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
        prev = dim
    return nn.Sequential(*layers)


[docs] class SimCLR(Module): """SimCLR: contrastive joint-embedding self-supervised learning. Architecture: - **Backbone**: any feature extractor producing a flat [B, D] embedding (timm ViT/ResNet with the head removed) - **Projector**: 2- or 3-layer MLP mapping features to the contrastive space - **Loss**: NT-Xent (normalised temperature-scaled cross entropy) :param encoder_name: timm model name (e.g. ``"vit_small_patch16_224"``, ``"resnet50"``) or a pre-instantiated ``nn.Module`` whose ``forward`` returns a ``[B, D]`` tensor. :param projector_dims: Hidden + output dimensions of the MLP projector. ``(2048, 2048, 128)`` matches the original SimCLR ResNet50 recipe; for ViT backbones the input is taken from the encoder embed_dim. :param temperature: Temperature for NT-Xent (0.5 in original SimCLR; 0.1 is common for harder/larger batches). :param low_resolution: Adapt first conv for 32x32 inputs (CIFAR-style). :param pretrained: Load pretrained timm weights for the encoder. Example:: model = SimCLR( encoder_name="vit_small_patch16_224", projector_dims=(2048, 2048, 256), temperature=0.2, ) v1 = torch.randn(64, 3, 224, 224) v2 = torch.randn(64, 3, 224, 224) out = model(v1, v2) out.loss.backward() # eval: single view, no loss model.eval() out = model(v1) features = out.embedding # [64, embed_dim] """ def __init__( self, encoder_name: Union[str, nn.Module] = "vit_small_patch16_224", projector_dims: Sequence[int] = (2048, 2048, 256), temperature: float = 0.5, low_resolution: bool = False, pretrained: bool = False, ): super().__init__() if isinstance(encoder_name, str): self.backbone = from_timm( encoder_name, num_classes=0, low_resolution=low_resolution, pretrained=pretrained, ) else: self.backbone = encoder_name # Detect embedding dimension by running a tiny dummy input with torch.no_grad(): dummy = torch.zeros(1, 3, 224, 224) embed_dim = self.backbone(dummy).shape[-1] self.embed_dim = embed_dim self.projector = _build_projector(embed_dim, list(projector_dims)) self.simclr_loss = NTXEntLoss(temperature=temperature)
[docs] def forward( self, view1: torch.Tensor, view2: Optional[torch.Tensor] = None, ) -> SimCLROutput: """Forward pass. :param view1: First augmented view [B, C, H, W] (or single view at eval). :param view2: Second augmented view [B, C, H, W]. If ``None``, returns only the backbone embedding (eval mode). :return: :class:`SimCLROutput`. """ if view2 is None: embedding = self.backbone(view1) return SimCLROutput( loss=torch.zeros((), device=embedding.device, dtype=embedding.dtype), embedding=embedding, projection=None, ) h1 = self.backbone(view1) h2 = self.backbone(view2) z1 = self.projector(h1) z2 = self.projector(h2) loss = self.simclr_loss(z1, z2) return SimCLROutput( loss=loss, embedding=torch.cat([h1, h2], dim=0), projection=torch.cat([z1, z2], dim=0), )