Source code for stable_pretraining.methods.msn

"""MSN: Masked Siamese Networks.

DINO-style prototype matching where the student sees a *randomly-masked*
view of the image and the teacher sees the unmasked view. Uses
Sinkhorn-Knopp on the teacher's soft assignments to enforce a balanced
prototype distribution; also adds a mean-entropy regulariser to prevent
trivial solutions.

References:
    Assran et al. "Masked Siamese Networks for Label-Efficient Learning."
    ECCV 2022. https://arxiv.org/abs/2204.07141
"""

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import L2Norm, TeacherStudentWrapper
from stable_pretraining.losses.utils import sinkhorn_knopp


@dataclass
class MSNOutput(ModelOutput):
    """Structured output of the :class:`MSN` SSL method."""

    loss: torch.Tensor = None
    embedding: torch.Tensor = None
    student_logits: Optional[torch.Tensor] = None
    teacher_logits: Optional[torch.Tensor] = None


def _msn_head(
    in_dim: int, hidden_dim: int, bottleneck_dim: int, n_prototypes: int
) -> nn.Module:
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, bottleneck_dim),
        L2Norm(),
        nn.Linear(bottleneck_dim, n_prototypes, bias=False),
    )


def _to_cls(features: torch.Tensor) -> torch.Tensor:
    if features.ndim == 2:
        return features
    if features.ndim == 3:
        return features[:, 0]
    raise ValueError(f"Unexpected backbone output shape {tuple(features.shape)}")


[docs] class MSN(Module): """MSN: masked siamese DINO-style SSL. :param encoder_name: timm ViT name (default ``"vit_small_patch16_224"``). :param projector_hidden_dim: Hidden dim (default 2048). :param projector_bottleneck_dim: Bottleneck dim (default 256). :param n_prototypes: Prototype count (default 1024). :param mask_ratio: Patch mask ratio for the *student* (default 0.6). :param temperature_student: Student softmax temperature (default 0.1). :param temperature_teacher: Teacher softmax temperature (default 0.025). :param me_max_weight: Mean-entropy maximisation weight (default 1.0). :param ema_decay_start: Initial backbone/head EMA (default 0.996). :param ema_decay_end: Final EMA (default 1.0). :param image_size: Input size (default 224). :param pretrained: Load pretrained timm weights. """ def __init__( self, encoder_name: Union[str, nn.Module] = "vit_small_patch16_224", projector_hidden_dim: int = 2048, projector_bottleneck_dim: int = 256, n_prototypes: int = 1024, mask_ratio: float = 0.6, temperature_student: float = 0.1, temperature_teacher: float = 0.025, me_max_weight: float = 1.0, ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, image_size: int = 224, pretrained: bool = False, ): super().__init__() if isinstance(encoder_name, str): import timm base = timm.create_model(encoder_name, num_classes=0, pretrained=pretrained) else: base = encoder_name with torch.no_grad(): seq = base.forward_features(torch.zeros(1, 3, image_size, image_size)) self._has_cls = ( hasattr(base, "cls_token") and base.cls_token is not None and seq.shape[1] > 1 ) embed_dim = seq.shape[-1] self.embed_dim = embed_dim self.mask_ratio = mask_ratio self.temperature_student = temperature_student self.temperature_teacher = temperature_teacher self.me_max_weight = me_max_weight self.image_size = image_size self.backbone = TeacherStudentWrapper( base, warm_init=True, base_ema_coefficient=ema_decay_start, final_ema_coefficient=ema_decay_end, ) self.projector = TeacherStudentWrapper( _msn_head( embed_dim, projector_hidden_dim, projector_bottleneck_dim, n_prototypes ), warm_init=True, base_ema_coefficient=ema_decay_start, final_ema_coefficient=ema_decay_end, ) self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) nn.init.trunc_normal_(self.mask_token, std=0.02) def _encode(self, vit, images, mask=None): x = vit.patch_embed(images) if mask is not None: m = mask.unsqueeze(-1) x = x * (1 - m) + self.mask_token.expand_as(x) * m if self._has_cls: cls = vit.cls_token.expand(x.shape[0], -1, -1) x = torch.cat([cls, x], dim=1) x = x + vit.pos_embed x = vit.pos_drop(x) x = vit.blocks(x) return vit.norm(x) def _random_mask(self, B, N, device): n_mask = int(round(N * self.mask_ratio)) noise = torch.rand(B, N, device=device) order = noise.argsort(dim=1) mask = torch.zeros(B, N, device=device) mask.scatter_(1, order[:, :n_mask], 1.0) return mask
[docs] def forward( self, view1: Optional[torch.Tensor] = None, view2: Optional[torch.Tensor] = None, images: Optional[torch.Tensor] = None, ) -> MSNOutput: # Eval / single-image if images is not None or (view1 is not None and view2 is None): single = images if images is not None else view1 with torch.no_grad(): feats = self.backbone.forward_teacher(single) cls = _to_cls(feats) return MSNOutput( loss=torch.zeros((), device=cls.device, dtype=cls.dtype), embedding=cls.detach(), ) if view1 is None or view2 is None: raise ValueError("MSN.forward needs view1 and view2 (or images for eval)") # Teacher: unmasked view2 with torch.no_grad(): t_feats = self._encode(self.backbone.teacher, view2, mask=None) t_cls = _to_cls(t_feats) t_logits = self.projector.forward_teacher(t_cls) # Student: masked view1 B = view1.shape[0] with torch.no_grad(): n_patches = self.backbone.student.patch_embed(view1[:1]).shape[1] mask = self._random_mask(B, n_patches, device=view1.device) s_feats = self._encode(self.backbone.student, view1, mask=mask) s_cls = _to_cls(s_feats) s_logits = self.projector.forward_student(s_cls) # Sinkhorn-Knopp on the teacher targets to enforce a balanced # prototype distribution; cross-entropy student vs teacher. teacher_probs = sinkhorn_knopp( teacher_output=t_logits, teacher_temp=self.temperature_teacher, num_samples=t_logits.shape[0], ) student_log_probs = F.log_softmax(s_logits / self.temperature_student, dim=-1) ce = -(teacher_probs * student_log_probs).sum(dim=-1).mean() # Mean-entropy maximisation: prevent collapse to a single prototype # by maximising the entropy of the *average* student distribution. mean_p = F.softmax(s_logits / self.temperature_student, dim=-1).mean(dim=0) me_max = (mean_p * mean_p.clamp(min=1e-8).log()).sum() loss = ce + self.me_max_weight * me_max return MSNOutput( loss=loss, embedding=t_cls.detach(), student_logits=s_logits, teacher_logits=t_logits.detach(), )