Source code for stable_pretraining.methods.vicreg

"""VICReg: Variance-Invariance-Covariance Regularization.

Self-supervised learning by enforcing three criteria on the projected
embeddings of two augmented views:
- **Invariance** to augmentations (MSE between views)
- **Variance** preservation (per-dimension std hinge loss)
- **Covariance** decorrelation (penalty on the off-diagonal entries of each
  view's covariance matrix)

References:
    Bardes, Ponce, LeCun. "VICReg: Variance-Invariance-Covariance
    Regularization for Self-Supervised Learning." ICLR 2022.
    https://arxiv.org/abs/2105.04906
"""

from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import from_timm
from stable_pretraining.losses import VICRegLoss


@dataclass
class VICRegOutput(ModelOutput):
    """Output container returned by ``VICReg.forward``.

    Every field defaults to ``None``; ``VICReg.forward`` populates them as
    described below. Shapes assume the backbone returns ``[B, D]`` features.

    :ivar loss: VICReg loss; a zero scalar when ``forward`` is called with a
        single view (no ``view2``).
    :ivar embedding: Backbone features: ``[B, D]`` for a single view, or both
        views concatenated along dim 0 (``[2B, D]``) when two views are given.
    :ivar projection: Projector outputs of both views concatenated along
        dim 0 (``[2B, P]``); ``None`` when only one view is given.
    """

    # Field order and the ``None`` defaults are kept as declared.
    loss: Optional[torch.Tensor] = None
    embedding: Optional[torch.Tensor] = None
    projection: Optional[torch.Tensor] = None


def _build_vicreg_projector(in_dim: int, hidden_dims: Sequence[int]) -> nn.Module:
    """Build the VICReg MLP projector.

    Every entry of ``hidden_dims`` except the last becomes a
    ``Linear -> BatchNorm1d -> ReLU`` stage; the last entry becomes a single
    bias-free ``Linear`` with no normalization or activation after it.
    The original VICReg recipe uses ``(8192, 8192, 8192)``.

    :param in_dim: Feature size fed into the first linear layer.
    :param hidden_dims: Output width of each linear layer, in order.
    :return: ``nn.Sequential`` holding the stacked layers.
    :raises ValueError: If ``hidden_dims`` is empty.
    """
    if len(hidden_dims) < 1:
        raise ValueError("hidden_dims must contain at least one entry")

    *stage_widths, out_dim = hidden_dims
    # Input width of stage k is the output width of stage k-1 (or in_dim).
    input_widths = [in_dim, *stage_widths]

    modules = []
    for fan_in, fan_out in zip(input_widths, stage_widths):
        modules += [
            nn.Linear(fan_in, fan_out),
            nn.BatchNorm1d(fan_out),
            nn.ReLU(inplace=True),
        ]
    # Output layer: plain linear map without bias, BN, or activation.
    modules.append(nn.Linear(input_widths[-1], out_dim, bias=False))
    return nn.Sequential(*modules)


class VICReg(Module):
    """VICReg: variance-invariance-covariance self-supervised learning.

    :param encoder_name: timm model name or pre-built ``nn.Module``.
    :param projector_dims: Hidden + output dims for the projector. Default
        ``(8192, 8192, 8192)`` matches the ResNet50 paper recipe.
    :param sim_coeff: Invariance term weight (default 25.0).
    :param std_coeff: Variance term weight (default 25.0).
    :param cov_coeff: Covariance term weight (default 1.0).
    :param low_resolution: Forwarded to ``from_timm``; adapt first conv for
        32x32 inputs (CIFAR-style).
    :param pretrained: Forwarded to ``from_timm``; load pretrained timm
        weights for the encoder.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        projector_dims: Sequence[int] = (8192, 8192, 8192),
        sim_coeff: float = 25.0,
        std_coeff: float = 25.0,
        cov_coeff: float = 1.0,
        low_resolution: bool = False,
        pretrained: bool = False,
    ):
        super().__init__()

        if isinstance(encoder_name, str):
            self.backbone = from_timm(
                encoder_name,
                num_classes=0,
                low_resolution=low_resolution,
                pretrained=pretrained,
            )
        else:
            # A pre-built encoder is used as-is.
            self.backbone = encoder_name

        self.embed_dim = self._infer_embed_dim(self.backbone)
        self.projector = _build_vicreg_projector(
            self.embed_dim, list(projector_dims)
        )
        self.vicreg_loss = VICRegLoss(
            sim_coeff=sim_coeff, std_coeff=std_coeff, cov_coeff=cov_coeff
        )

    @staticmethod
    def _infer_embed_dim(backbone: nn.Module) -> int:
        """Return the backbone's last output dimension using a dummy forward.

        The probe runs in eval mode so that it does not update BatchNorm
        running statistics (which would perturb pretrained weights) and does
        not fail on layers that need more than one sample per channel in
        training mode. Every module's original train/eval flag is restored
        afterwards.

        :param backbone: Encoder to probe.
        :return: Size of the last dimension of the backbone output.
        """
        # NOTE(review): the probe resolution is fixed at 3x224x224, also when
        # ``low_resolution=True`` -- confirm this suits every encoder in use.
        probe = torch.zeros(1, 3, 224, 224)

        # Match the backbone's device/dtype so encoders already moved to an
        # accelerator or cast to another float type can be probed.
        reference = next(backbone.parameters(), None)
        if reference is not None:
            probe = probe.to(reference.device)
            if reference.is_floating_point():
                probe = probe.to(reference.dtype)

        root_mode = backbone.training
        saved_modes = [(module, module.training) for module in backbone.modules()]
        backbone.eval()
        try:
            with torch.no_grad():
                return backbone(probe).shape[-1]
        finally:
            # ``train()`` first so modules overriding it observe the switch
            # back; then restore per-module flags in case some submodules
            # were deliberately left in a different mode.
            backbone.train(root_mode)
            for module, was_training in saved_modes:
                module.training = was_training

    def forward(
        self,
        view1: torch.Tensor,
        view2: Optional[torch.Tensor] = None,
    ) -> VICRegOutput:
        """Encode one view, or compute the VICReg loss on two views.

        The branch is selected only by whether ``view2`` is given, not by the
        module's train/eval flag.

        :param view1: Input batch for the backbone (first view).
        :param view2: Second view of the same batch. When ``None``, only
            backbone features of ``view1`` are computed.
        :return: ``VICRegOutput``. With a single view: zero scalar ``loss``
            on the features' device/dtype, ``embedding`` set to the backbone
            features, ``projection=None``. With two views: ``loss`` from
            ``VICRegLoss`` on the two projections, and ``embedding`` /
            ``projection`` holding both views concatenated along dim 0.
        """
        if view2 is None:
            embedding = self.backbone(view1)
            return VICRegOutput(
                loss=torch.zeros((), device=embedding.device, dtype=embedding.dtype),
                embedding=embedding,
                projection=None,
            )

        # Each view goes through the backbone and projector separately.
        h1 = self.backbone(view1)
        h2 = self.backbone(view2)
        z1 = self.projector(h1)
        z2 = self.projector(h2)
        loss = self.vicreg_loss(z1, z2)

        return VICRegOutput(
            loss=loss,
            embedding=torch.cat([h1, h2], dim=0),
            projection=torch.cat([z1, z2], dim=0),
        )