# Source code for stable_pretraining.methods.mocov2

"""MoCo v2: Momentum Contrast (CNN-style with queue).

Memory-bank contrastive: a momentum encoder produces "key" features that
are pushed onto a FIFO queue; the student "query" is contrasted (InfoNCE)
against the latest queue. Distinct from MoCo v3 (no queue, ViT-tuned).

References:
    Chen et al. "Improved Baselines with Momentum Contrastive Learning."
    arXiv 2020. https://arxiv.org/abs/2003.04297
"""

from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import TeacherStudentWrapper, from_timm
from stable_pretraining.callbacks.queues import UnsortedQueue


@dataclass
class MoCov2Output(ModelOutput):
    """Structured output of the :class:`MoCov2` SSL method.

    Every field defaults to ``None``. ``MoCov2.forward`` always fills
    ``loss`` and ``embedding``; ``queries`` and ``keys`` are only filled when
    two views are passed.
    """

    # InfoNCE (cross-entropy) loss; a scalar zero when only one view is given.
    loss: Optional[torch.Tensor] = None
    # Detached teacher-backbone features of ``view1`` (before the projector).
    embedding: Optional[torch.Tensor] = None
    # L2-normalized student projections of ``view1``.
    queries: Optional[torch.Tensor] = None
    # L2-normalized teacher projections of ``view2`` (computed without grad).
    keys: Optional[torch.Tensor] = None


def _projector(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    """Build the 2-layer MLP projection head: Linear -> ReLU -> Linear.

    :param in_dim: Size of the incoming feature vector.
    :param hidden_dim: Width of the hidden layer.
    :param out_dim: Size of the projected output.
    :return: An ``nn.Sequential`` holding the three layers in order.
    """
    # Layers are created in forward order so parameter initialization
    # consumes the RNG in the same sequence as the layers appear.
    fc_in = nn.Linear(in_dim, hidden_dim)
    activation = nn.ReLU(inplace=True)
    fc_out = nn.Linear(hidden_dim, out_dim)
    return nn.Sequential(fc_in, activation, fc_out)


class MoCov2(Module):
    """MoCo v2 with a fixed-size FIFO queue of momentum-encoder keys.

    A student ("query") backbone and projector are trained with an InfoNCE
    loss against keys produced by an EMA teacher ("momentum") backbone and
    projector. Keys are pushed onto a queue whose contents serve as negatives.

    :param encoder_name: timm model name or pre-built ``nn.Module``.
    :param projector_dims: ``(hidden, output)`` for the 2-layer head (default
        ``(2048, 128)``; matches MoCo v2's ResNet50 recipe).
    :param queue_length: FIFO key queue size (default 65536).
    :param temperature: InfoNCE temperature (default 0.2). Must be positive.
    :param ema_decay_start: Initial momentum (default 0.999, paper).
    :param ema_decay_end: Final momentum (default 0.999).
    :param low_resolution: Adapt first conv for low-res input.
    :param pretrained: Load pretrained timm weights.
    :raises ValueError: If ``projector_dims`` does not hold exactly two
        entries, or if ``temperature`` is not strictly positive.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        projector_dims: Sequence[int] = (2048, 128),
        queue_length: int = 65536,
        temperature: float = 0.2,
        ema_decay_start: float = 0.999,
        ema_decay_end: float = 0.999,
        low_resolution: bool = False,
        pretrained: bool = False,
    ):
        super().__init__()

        # Fail early with a clear message instead of an opaque unpacking
        # error (or silent inf/NaN logits for a non-positive temperature).
        if len(projector_dims) != 2:
            raise ValueError(
                "projector_dims must be (hidden, output), "
                f"got {len(projector_dims)} entries"
            )
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        proj_hidden, proj_out = projector_dims

        if isinstance(encoder_name, str):
            base = from_timm(
                encoder_name,
                num_classes=0,
                low_resolution=low_resolution,
                pretrained=pretrained,
            )
        else:
            base = encoder_name

        embed_dim = self._probe_embed_dim(base)
        self.embed_dim = embed_dim
        self.temperature = temperature

        # Student/teacher pairs; the teacher starts as a copy of the student
        # (warm_init=True) and is driven by the EMA coefficients given here.
        self.backbone = TeacherStudentWrapper(
            base,
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        self.projector = TeacherStudentWrapper(
            _projector(embed_dim, proj_hidden, proj_out),
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        # Keys are always stored in float32, whatever the compute dtype is.
        self.queue = UnsortedQueue(
            max_length=queue_length, shape=(proj_out,), dtype=torch.float32
        )

    @staticmethod
    def _probe_embed_dim(base: nn.Module) -> int:
        """Return the size of the last output dimension of ``base``.

        Runs one dummy ``(1, 3, 224, 224)`` batch through ``base`` in eval
        mode, so normalization layers neither update their running statistics
        nor fail on a batch of one. Each submodule's train/eval flag is
        restored afterwards. The dummy input follows the device (and
        floating-point dtype) of the first parameter of ``base``.
        """
        tensor_kwargs = {}
        reference = next(base.parameters(), None)
        if reference is not None:
            tensor_kwargs["device"] = reference.device
            if reference.is_floating_point():
                tensor_kwargs["dtype"] = reference.dtype

        # Remember per-module flags: ``base.train(flag)`` would overwrite
        # submodules that were deliberately left in a different mode.
        previous_modes = {module: module.training for module in base.modules()}
        base.eval()
        try:
            with torch.no_grad():
                features = base(torch.zeros(1, 3, 224, 224, **tensor_kwargs))
        finally:
            for module, was_training in previous_modes.items():
                module.training = was_training
        return features.shape[-1]

    def forward(
        self,
        view1: torch.Tensor,
        view2: Optional[torch.Tensor] = None,
    ) -> MoCov2Output:
        """Compute the MoCo v2 loss for two views, or embed a single view.

        :param view1: Batch fed to the student (query) branch; also the input
            for the returned ``embedding``.
        :param view2: Batch fed to the teacher (key) branch. When ``None``,
            no loss is computed and only the teacher embedding of ``view1``
            is returned together with a zero scalar loss.
        :return: :class:`MoCov2Output` with ``loss`` and ``embedding``; in
            the two-view case also ``queries`` and ``keys``.
        """
        if view2 is None:
            with torch.no_grad():
                emb = self.backbone.forward_teacher(view1)
            return MoCov2Output(
                loss=torch.zeros((), device=emb.device, dtype=emb.dtype),
                embedding=emb.detach(),
            )

        # Query (student) on view1
        q = self.projector.forward_student(self.backbone.forward_student(view1))
        q = F.normalize(q, dim=-1)

        # Key (teacher / momentum) on view2 — no grad
        with torch.no_grad():
            k = self.projector.forward_teacher(self.backbone.forward_teacher(view2))
            k = F.normalize(k, dim=-1)

        # NOTE(review): the negatives are whatever ``append`` returns. If that
        # already contains the keys just appended, each positive key is also
        # present among its own negatives — confirm against UnsortedQueue.
        queue_keys = self.queue.append(k.detach().to(torch.float32)).to(q.dtype)

        # InfoNCE: positive = matched key, negatives = queue
        logits_pos = (q * k).sum(dim=-1, keepdim=True)  # [B, 1]
        logits_neg = q @ queue_keys.T  # [B, K]
        logits = torch.cat([logits_pos, logits_neg], dim=1) / self.temperature
        # The positive logit sits in column 0 for every row.
        targets = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
        loss = F.cross_entropy(logits, targets)

        # Extra teacher pass on view1 so the returned embedding comes from the
        # same branch and view as in the single-view path.
        with torch.no_grad():
            embedding = self.backbone.forward_teacher(view1).detach()

        return MoCov2Output(
            loss=loss,
            embedding=embedding,
            queries=q,
            keys=k,
        )