Source code for stable_pretraining.methods.simsiam

"""SimSiam: Simple Siamese self-supervised learning.

Two-view siamese network with a stop-gradient on the target branch — no
momentum encoder, no negatives, no large batches. The student has a
backbone, projector, and predictor; the loss is the negative cosine
similarity between the predictor output and a stop-gradient projection
of the other view.

References:
    Chen, He. "Exploring Simple Siamese Representation Learning." CVPR 2021.
    https://arxiv.org/abs/2011.10566
"""

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import from_timm


@dataclass
class SimSiamOutput(ModelOutput):
    """Structured output of the :class:`SimSiam` SSL method."""

    loss: torch.Tensor = None
    embedding: torch.Tensor = None
    prediction: Optional[torch.Tensor] = None
    target: Optional[torch.Tensor] = None


def _projector(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    """SimSiam projector: 3-layer MLP with BN, no final ReLU."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim, bias=False),
        nn.BatchNorm1d(out_dim, affine=False),
    )


def _predictor(dim: int, hidden_dim: int) -> nn.Module:
    """SimSiam predictor: 2-layer MLP with bottleneck."""
    return nn.Sequential(
        nn.Linear(dim, hidden_dim, bias=False),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, dim),
    )


def _neg_cos_sim(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Symmetric negative cosine similarity loss with stop-gradient on z."""
    z = z.detach()
    p = F.normalize(p, dim=-1)
    z = F.normalize(z, dim=-1)
    return -(p * z).sum(dim=-1).mean()


[docs] class SimSiam(Module): """SimSiam: simple siamese SSL with stop-gradient. :param encoder_name: timm model name or pre-built ``nn.Module``. :param projector_dim: Projector hidden + output dim (default 2048). :param predictor_hidden_dim: Predictor bottleneck dim (default 512). :param low_resolution: Adapt first conv for 32x32. :param pretrained: Load pretrained timm weights. """ def __init__( self, encoder_name: Union[str, nn.Module] = "vit_small_patch16_224", projector_dim: int = 2048, predictor_hidden_dim: int = 512, low_resolution: bool = False, pretrained: bool = False, ): super().__init__() if isinstance(encoder_name, str): self.backbone = from_timm( encoder_name, num_classes=0, low_resolution=low_resolution, pretrained=pretrained, ) else: self.backbone = encoder_name with torch.no_grad(): embed_dim = self.backbone(torch.zeros(1, 3, 224, 224)).shape[-1] self.embed_dim = embed_dim self.projector = _projector(embed_dim, projector_dim, projector_dim) self.predictor = _predictor(projector_dim, predictor_hidden_dim)
[docs] def forward( self, view1: torch.Tensor, view2: Optional[torch.Tensor] = None, ) -> SimSiamOutput: if view2 is None: embedding = self.backbone(view1) return SimSiamOutput( loss=torch.zeros((), device=embedding.device, dtype=embedding.dtype), embedding=embedding, prediction=None, target=None, ) h1 = self.backbone(view1) h2 = self.backbone(view2) z1 = self.projector(h1) z2 = self.projector(h2) p1 = self.predictor(z1) p2 = self.predictor(z2) loss = (_neg_cos_sim(p1, z2) + _neg_cos_sim(p2, z1)) / 2 return SimSiamOutput( loss=loss, embedding=torch.cat([h1, h2], dim=0), prediction=torch.cat([p1, p2], dim=0), target=torch.cat([z1, z2], dim=0).detach(), )