# Source code for stable_pretraining.methods.vicregl

"""VICRegL: VICReg + local feature matching.

Adds a local term to VICReg. The global term is the standard VICReg
objective (variance / invariance / covariance) applied to projected global
embeddings (the CLS token, or the mean of all tokens when the encoder has
no CLS token). The local term projects every patch token, pairs each patch
of the first view with its most cosine-similar patch of the second view
(nearest-neighbour matching in feature space, within the same image), and
applies the same VICReg objective to the matched pairs. The two terms are
mixed as ``alpha * global + (1 - alpha) * local``.

Only feature-space matching is implemented here; matching by spatial
coordinates (e.g. L2 distance on coordinate maps) is not.

References:
    Bardes, Ponce, LeCun. "VICRegL: Self-Supervised Learning of Local
    Visual Features." NeurIPS 2022. https://arxiv.org/abs/2210.01571
"""

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.losses import VICRegLoss


@dataclass
class VICRegLOutput(ModelOutput):
    """Structured output of the :class:`VICRegL` SSL method.

    Every field defaults to ``None``. ``VICRegL.forward`` sets only ``loss``
    (a zero scalar) and ``embedding`` on its single-view path
    (``view2 is None``); the two-view path sets all four fields.
    """

    # Total loss: alpha * loss_global + (1 - alpha) * loss_local on the
    # two-view path, a zero scalar on the single-view path.
    loss: Optional[torch.Tensor] = None
    # Global (CLS / pooled) VICReg term, detached so it carries no gradient.
    loss_global: Optional[torch.Tensor] = None
    # Local (matched patch-token) VICReg term, detached so it carries no
    # gradient.
    loss_local: Optional[torch.Tensor] = None
    # Global encoder embeddings: both views concatenated along the batch
    # dimension on the two-view path, or view1's alone otherwise.
    embedding: Optional[torch.Tensor] = None


def _projector(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim, bias=False),
    )


class VICRegL(Module):
    """VICRegL: VICReg with an extra local-feature term.

    The global term applies :class:`VICRegLoss` to projected global
    embeddings (CLS token, or mean of patch tokens when there is no CLS
    token). The local term projects every patch token, matches each patch
    of ``view1`` to its most cosine-similar patch of ``view2`` (same image),
    and applies :class:`VICRegLoss` to the matched pairs. The final loss is
    ``alpha * global + (1 - alpha) * local``.

    :param encoder_name: timm model name (default ``"vit_small_patch16_224"``)
        or an already-built ``nn.Module``. The encoder must expose
        ``forward_features`` returning a token sequence ``[B, N, D]``.
    :param projector_dim: Hidden and output dim of both the global and the
        local projectors (default 2048).
    :param sim_coeff: Invariance weight (default 25.0).
    :param std_coeff: Variance weight (default 25.0).
    :param cov_coeff: Covariance weight (default 1.0).
    :param alpha: Mixing weight between global and local terms, in
        ``[0, 1]`` (default 0.75 means global gets 75%).
    :param image_size: Side of the dummy square image used to probe the
        encoder's output layout (default 224).
    :param pretrained: Load pretrained timm weights (only used when
        ``encoder_name`` is a string).
    :raises ValueError: If ``alpha`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        projector_dim: int = 2048,
        sim_coeff: float = 25.0,
        std_coeff: float = 25.0,
        cov_coeff: float = 1.0,
        alpha: float = 0.75,
        image_size: int = 224,
        pretrained: bool = False,
    ):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        if isinstance(encoder_name, str):
            import timm

            self.encoder = timm.create_model(
                encoder_name, num_classes=0, pretrained=pretrained
            )
        else:
            self.encoder = encoder_name

        seq = self._probe_features(image_size)
        self._has_cls = (
            getattr(self.encoder, "cls_token", None) is not None
            and seq.shape[1] > 1
        )

        # Leading non-spatial tokens (CLS and, for timm ViTs, any register
        # tokens) must not take part in the local patch matching. timm
        # exposes their count as ``num_prefix_tokens``; otherwise fall back
        # to "just the CLS token, if any".
        n_prefix = int(getattr(self.encoder, "num_prefix_tokens", 0))
        if self._has_cls:
            n_prefix = max(n_prefix, 1)
        if n_prefix >= seq.shape[1]:
            # Never strip every token.
            n_prefix = int(self._has_cls)
        self._num_prefix_tokens = n_prefix

        embed_dim = seq.shape[-1]
        self.embed_dim = embed_dim
        self.alpha = alpha
        self.projector_global = _projector(embed_dim, projector_dim, projector_dim)
        self.projector_local = _projector(embed_dim, projector_dim, projector_dim)
        self.vicreg_loss = VICRegLoss(
            sim_coeff=sim_coeff, std_coeff=std_coeff, cov_coeff=cov_coeff
        )

    @torch.no_grad()
    def _probe_features(self, image_size: int) -> torch.Tensor:
        """Run ``forward_features`` once on an all-zero image to read its layout.

        The probe runs in eval mode so that BatchNorm-style layers do not fold
        the dummy batch into their running statistics. Each submodule's
        ``training`` flag is restored afterwards. The dummy input is built on
        the encoder's device/dtype, so pre-moved or half-precision encoders
        also work.
        """
        first_param = next(self.encoder.parameters(), None)
        device, dtype = torch.device("cpu"), torch.float32
        if first_param is not None:
            device = first_param.device
            if first_param.is_floating_point():
                dtype = first_param.dtype

        modes = [m.training for m in self.encoder.modules()]
        self.encoder.eval()
        try:
            dummy = torch.zeros(
                1, 3, image_size, image_size, device=device, dtype=dtype
            )
            return self.encoder.forward_features(dummy)
        finally:
            for module, mode in zip(self.encoder.modules(), modes):
                module.training = mode

    def _split(self, features: torch.Tensor):
        """Return ``(global_embedding, patch_tokens)`` from ``forward_features``.

        With a CLS token, the global embedding is token 0. Otherwise it is
        the mean of the patch tokens. Leading prefix tokens (CLS / registers)
        are excluded from the returned patch tokens.
        """
        patches = features[:, self._num_prefix_tokens:]
        if self._has_cls:
            return features[:, 0], patches
        return patches.mean(dim=1), patches

    def _project_local(self, patches: torch.Tensor) -> torch.Tensor:
        """Apply ``projector_local`` to every token: ``[B, N, D] -> [B, N, P]``."""
        batch, tokens, dim = patches.shape
        # The projector's BatchNorm1d needs 2-D input, so flatten tokens
        # into the batch dimension and restore the layout afterwards.
        projected = self.projector_local(patches.reshape(batch * tokens, dim))
        return projected.view(batch, tokens, -1)

    @staticmethod
    def _match_nearest(anchor: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
        """For each anchor token, gather the most cosine-similar candidate token.

        Matching is done independently per image: ``anchor`` is
        ``[B, N1, P]`` and ``candidates`` is ``[B, N2, P]``. The result is
        ``[B, N1, P]``. Gradients flow into the gathered candidates. The
        index search itself is non-differentiable, so it runs under
        ``no_grad`` to avoid keeping the ``[B, N1, N2]`` similarity graph
        alive.
        """
        with torch.no_grad():
            sim = F.normalize(anchor, dim=-1) @ F.normalize(
                candidates, dim=-1
            ).transpose(1, 2)  # [B, N1, N2]
            nn_idx = sim.argmax(dim=-1)  # [B, N1]
        index = nn_idx.unsqueeze(-1).expand(-1, -1, candidates.shape[-1])
        return torch.gather(candidates, dim=1, index=index)

    def forward(
        self,
        view1: torch.Tensor,
        view2: Optional[torch.Tensor] = None,
    ) -> VICRegLOutput:
        """Compute the VICRegL loss for two views, or embed a single view.

        :param view1: First batch of images, fed to ``encoder.forward_features``.
        :param view2: Second view of the same images. When ``None``, only
            ``view1`` is encoded and a zero loss is returned. It may have a
            different token count than ``view1`` (e.g. a different
            resolution), provided the encoder accepts it.
        :return: :class:`VICRegLOutput` containing:

            - ``loss``: the alpha-mixed loss.
            - ``loss_global`` / ``loss_local``: the detached components.
            - ``embedding``: the global embeddings of both views
              concatenated along the batch dimension (``[2B, D]``), or of
              ``view1`` alone.
        """
        if view2 is None:
            cls, _ = self._split(self.encoder.forward_features(view1))
            return VICRegLOutput(
                loss=torch.zeros((), device=view1.device, dtype=view1.dtype),
                embedding=cls,
            )

        # Global features (CLS / pooled) and patch tokens.
        cls1, p1 = self._split(self.encoder.forward_features(view1))
        cls2, p2 = self._split(self.encoder.forward_features(view2))

        z1 = self.projector_global(cls1)
        z2 = self.projector_global(cls2)
        loss_global = self.vicreg_loss(z1, z2)

        # Local: project patch tokens of each view with its own shape, match
        # each view1 patch to its nearest view2 patch, and apply VICReg to
        # the matched pairs.
        # NOTE(review): matching is one-directional (view1 -> view2) and no
        # poor matches are filtered out; confirm this simplification is
        # intended.
        zl1 = self._project_local(p1)
        zl2 = self._project_local(p2)
        zl2_aligned = self._match_nearest(zl1, zl2)
        loss_local = self.vicreg_loss(zl1.flatten(0, 1), zl2_aligned.flatten(0, 1))

        loss = self.alpha * loss_global + (1 - self.alpha) * loss_local
        return VICRegLOutput(
            loss=loss,
            loss_global=loss_global.detach(),
            loss_local=loss_local.detach(),
            embedding=torch.cat([cls1, cls2], dim=0),
        )