# Source code for stable_pretraining.methods.dinov2

"""DINOv2: scaled DINO + iBOT with Sinkhorn-Knopp normalisation.

DINOv2 builds on DINO and iBOT by:
- Replacing classical centering with Sinkhorn-Knopp on both CLS and patch
  prototype distributions.
- Using KoLeo regularisation (optional, omitted here for simplicity) to
  spread out features.
- Larger schedules and registers (also omitted; this class has no
  ``encoder_kwargs`` option, and register tokens are not handled by the
  CLS/patch token split used for the patch loss).

This implementation follows the iBOT architecture (EMA teacher/student ViT
with separate CLS and patch prototype heads) and computes Sinkhorn-Knopp
teacher targets and losses via :class:`DINOv2Loss`.

References:
    Oquab et al. "DINOv2: Learning Robust Visual Features without
    Supervision." TMLR 2024. https://arxiv.org/abs/2304.07193
"""

from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import L2Norm, TeacherStudentWrapper
from stable_pretraining.losses import DINOv2Loss


@dataclass
class DINOv2Output(ModelOutput):
    """Result container returned by :meth:`DINOv2.forward`.

    Every field defaults to ``None``; the evaluation path (``images=...``)
    only fills ``loss`` and ``embedding``.
    """

    # Total objective: CLS loss + patch_loss_weight * patch loss (zero on eval).
    loss: Optional[torch.Tensor] = None
    # Detached CLS (DINO) component of the loss.
    loss_cls: Optional[torch.Tensor] = None
    # Detached patch (iBOT) component of the loss.
    loss_patch: Optional[torch.Tensor] = None
    # Detached teacher CLS (or pooled) features.
    embedding: Optional[torch.Tensor] = None


def _ibot_head(
    in_dim: int, hidden_dim: int, bottleneck_dim: int, n_prototypes: int
) -> nn.Module:
    """Build the projection head used for both the CLS and patch branches.

    Two hidden ``Linear -> GELU`` stages, a linear map to ``bottleneck_dim``,
    an ``L2Norm`` layer, and finally a bias-free linear layer producing
    ``n_prototypes`` logits.
    """
    widths = (in_dim, hidden_dim, hidden_dim)
    layers = []
    # Hidden stages: in_dim -> hidden_dim -> hidden_dim, each followed by GELU.
    for width_in, width_out in zip(widths, widths[1:]):
        layers += [nn.Linear(width_in, width_out), nn.GELU()]
    layers.append(nn.Linear(hidden_dim, bottleneck_dim))
    layers.append(L2Norm())
    layers.append(nn.Linear(bottleneck_dim, n_prototypes, bias=False))
    return nn.Sequential(*layers)


def _split_cls_patches(features: torch.Tensor, has_cls: bool):
    if features.ndim == 2:
        # Eval path: timm ViT pooled output is already CLS-like.
        return features, None
    if has_cls:
        return features[:, 0], features[:, 1:]
    return features.mean(dim=1), features


class DINOv2(Module):
    """DINOv2: DINO + iBOT with Sinkhorn-Knopp on CLS and patch prototypes.

    :param encoder_name: timm ViT name (default ``"vit_small_patch16_224"``),
        or an already-built module exposing the timm ViT pieces used here
        (``forward_features``, ``patch_embed``, ``_pos_embed``, ``blocks``,
        ``norm``).
    :param projector_hidden_dim: Hidden dim for both heads (default 2048).
    :param projector_bottleneck_dim: Bottleneck dim (default 256).
    :param n_cls_prototypes: CLS prototypes (default 65536).
    :param n_patch_prototypes: Patch prototypes (default 8192).
    :param mask_ratio: Fraction of global-view patches masked for the student
        (default 0.3). If no patch ends up masked (e.g. ``0.0``), the patch
        loss is a zero that still keeps the patch head in the autograd graph.
    :param patch_loss_weight: Weight on the patch loss (default 1.0).
    :param temperature_student: Student softmax temperature (default 0.1).
    :param temperature_teacher: Teacher temperature passed to Sinkhorn-Knopp
        (default 0.07).
    :param ema_decay_start: Initial backbone/head EMA (default 0.996).
    :param ema_decay_end: Final EMA (default 1.0).
    :param image_size: Size of the dummy input used to probe the encoder's
        token layout and width (default 224).
    :param pretrained: Load pretrained timm weights (only used when
        ``encoder_name`` is a string).
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        projector_hidden_dim: int = 2048,
        projector_bottleneck_dim: int = 256,
        n_cls_prototypes: int = 65536,
        n_patch_prototypes: int = 8192,
        mask_ratio: float = 0.3,
        patch_loss_weight: float = 1.0,
        temperature_student: float = 0.1,
        temperature_teacher: float = 0.07,
        ema_decay_start: float = 0.996,
        ema_decay_end: float = 1.0,
        image_size: int = 224,
        pretrained: bool = False,
    ):
        super().__init__()
        if isinstance(encoder_name, str):
            import timm

            base = timm.create_model(
                encoder_name,
                num_classes=0,
                pretrained=pretrained,
                dynamic_img_size=True,  # support multi-crop (224 + 96 etc.)
            )
        else:
            base = encoder_name

        # Probe the encoder once to find its token layout and feature width.
        with torch.no_grad():
            seq = base.forward_features(torch.zeros(1, 3, image_size, image_size))
        self._has_cls = (
            hasattr(base, "cls_token")
            and base.cls_token is not None
            and seq.shape[1] > 1
        )
        embed_dim = seq.shape[-1]
        self.embed_dim = embed_dim
        self.mask_ratio = mask_ratio
        self.patch_loss_weight = patch_loss_weight
        self.temperature_teacher = temperature_teacher
        self.image_size = image_size

        # NOTE: construction order below is kept stable so that RNG-driven
        # initialisation matches across versions for a given seed.
        self.backbone = TeacherStudentWrapper(
            base,
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        self.cls_head = TeacherStudentWrapper(
            _ibot_head(
                embed_dim,
                projector_hidden_dim,
                projector_bottleneck_dim,
                n_cls_prototypes,
            ),
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        self.patch_head = TeacherStudentWrapper(
            _ibot_head(
                embed_dim,
                projector_hidden_dim,
                projector_bottleneck_dim,
                n_patch_prototypes,
            ),
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        # Learned embedding substituted for masked patches in the student.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        self.dinov2_loss = DINOv2Loss(student_temp=temperature_student)

    def _encode(self, vit, images, mask=None):
        """Run a timm-style ViT and return the normalised token sequence.

        :param vit: Encoder exposing ``patch_embed``, ``_pos_embed``,
            ``blocks``, ``norm`` (and optionally ``patch_drop``/``norm_pre``).
        :param images: Input batch for ``vit.patch_embed``.
        :param mask: Optional ``[B, N]`` tensor; positions where it is nonzero
            have their patch embedding replaced by ``self.mask_token``.
        :return: ``vit.norm`` applied to the block outputs.
        """
        x = vit.patch_embed(images)
        if mask is not None:
            # With ``dynamic_img_size=True`` patch_embed returns 4D
            # [B, H', W', D] and ``_pos_embed`` *requires* 4D, so flatten only
            # for the substitution and restore the original shape afterwards.
            shape = x.shape
            tokens = x.reshape(shape[0], -1, shape[-1])
            is_masked = mask.bool().unsqueeze(-1)
            # Cast the mask token to the tokens' dtype so a bf16/fp16 model is
            # not silently promoted to float32 (which would then mismatch its
            # low-precision weights in the transformer blocks).
            tokens = torch.where(
                is_masked, self.mask_token.to(tokens.dtype), tokens
            )
            x = tokens.reshape(shape)
        x = vit._pos_embed(x)
        if hasattr(vit, "patch_drop"):
            x = vit.patch_drop(x)
        if hasattr(vit, "norm_pre"):
            x = vit.norm_pre(x)
        x = vit.blocks(x)
        return vit.norm(x)

    def _random_mask(self, B, N, device):
        """Return a float ``[B, N]`` 0/1 mask.

        Each row has exactly ``round(N * mask_ratio)`` ones, at positions
        chosen by sorting uniform noise independently per row.
        """
        n_mask = int(round(N * self.mask_ratio))
        noise = torch.rand(B, N, device=device)
        order = noise.argsort(dim=1)
        mask = torch.zeros(B, N, device=device)
        mask.scatter_(1, order[:, :n_mask], 1.0)
        return mask

    def forward(
        self,
        global_views: Optional[Sequence[torch.Tensor]] = None,
        local_views: Optional[Sequence[torch.Tensor]] = None,
        images: Optional[torch.Tensor] = None,
    ) -> DINOv2Output:
        """Compute the DINOv2 objective, or teacher embeddings for evaluation.

        :param global_views: Global crops, each a batch of the same size.
            They feed the teacher (unmasked) and the student (patch-masked).
        :param local_views: Optional local crops (student only). They
            contribute to the CLS loss only.
        :param images: If given, skip training and return the teacher's CLS
            (or pooled) features with a zero loss.
        :raises ValueError: If neither ``global_views`` nor ``images`` is given.
        """
        if images is not None:
            with torch.no_grad():
                feats = self.backbone.forward_teacher(images)
                cls, _ = _split_cls_patches(feats, self._has_cls)
            return DINOv2Output(
                loss=torch.zeros((), device=cls.device, dtype=cls.dtype),
                embedding=cls.detach(),
            )

        if not global_views:
            raise ValueError("DINOv2.forward needs global_views or images")
        global_views = list(global_views)
        local_views = list(local_views or [])
        n_global = len(global_views)
        n_local = len(local_views)
        global_imgs = torch.cat(global_views, dim=0)
        B = global_views[0].shape[0]

        with torch.no_grad():
            pe = self.backbone.student.patch_embed(global_imgs[:1])
        # ``dynamic_img_size=True`` returns [B, H', W', D]; flat returns [B, N, D].
        n_patches = pe.shape[1] * pe.shape[2] if pe.ndim == 4 else pe.shape[1]
        mask = self._random_mask(
            global_imgs.shape[0], n_patches, device=global_imgs.device
        )
        # Row order b * N + n matches ``tokens.reshape(-1, D)`` below.
        mask_flat = mask.bool().view(-1)

        # Teacher: globals only, unmasked. Provides Sinkhorn targets for
        # both CLS and (masked) patches. Only masked positions are supervised,
        # so the (row-wise) patch head is applied to those tokens alone.
        with torch.no_grad():
            t_feats = self._encode(self.backbone.teacher, global_imgs, mask=None)
            t_cls, t_patches = _split_cls_patches(t_feats, self._has_cls)
            t_cls_logits = self.cls_head.forward_teacher(t_cls).view(n_global, B, -1)
            t_patch_flat = self.patch_head.forward_teacher(
                t_patches.reshape(-1, t_patches.shape[-1])[mask_flat]
            )

        # Student: globals (with patch mask) -> CLS logits + masked-patch logits.
        s_feats_g = self._encode(self.backbone.student, global_imgs, mask=mask)
        s_cls_g, s_patches_g = _split_cls_patches(s_feats_g, self._has_cls)
        s_cls_logits_g = self.cls_head.forward_student(s_cls_g).view(n_global, B, -1)
        s_patch_flat = self.patch_head.forward_student(
            s_patches_g.reshape(-1, s_patches_g.shape[-1])[mask_flat]
        )

        # Student: locals (no mask). Locals contribute *only* to the CLS
        # loss; no patch loss is computed for them here.
        if n_local > 0:
            local_imgs = torch.cat(local_views, dim=0)
            s_feats_l = self._encode(self.backbone.student, local_imgs, mask=None)
            s_cls_l, _ = _split_cls_patches(s_feats_l, self._has_cls)
            s_cls_logits_l = self.cls_head.forward_student(s_cls_l).view(n_local, B, -1)
            s_cls_logits = torch.cat([s_cls_logits_g, s_cls_logits_l], dim=0)
        else:
            s_cls_logits = s_cls_logits_g

        # Sinkhorn-Knopp on CLS targets.
        n_cls = t_cls_logits.numel() // t_cls_logits.shape[-1]
        teacher_cls_probs = self.dinov2_loss.dino_loss.sinkhorn_knopp_teacher(
            t_cls_logits, teacher_temp=self.temperature_teacher, num_samples=n_cls
        )
        loss_cls = self.dinov2_loss.dino_loss(s_cls_logits, teacher_cls_probs)

        if s_patch_flat.shape[0] == 0:
            # No masked patches (mask_ratio == 0, or N * mask_ratio rounds to
            # 0): do not call Sinkhorn-Knopp / the iBOT loss on zero samples
            # (presumably a division by zero / mean over an empty set -> NaN).
            # ``sum() * 0`` keeps the patch head in the autograd graph.
            loss_patch = s_patch_flat.sum() * 0.0
        else:
            teacher_patch_probs = self.dinov2_loss.ibot_loss.sinkhorn_knopp_teacher(
                t_patch_flat,
                teacher_temp=self.temperature_teacher,
                num_samples=t_patch_flat.shape[0],
            )
            loss_patch = self.dinov2_loss.ibot_loss(s_patch_flat, teacher_patch_probs)

        loss = loss_cls + self.patch_loss_weight * loss_patch
        return DINOv2Output(
            loss=loss,
            loss_cls=loss_cls.detach(),
            loss_patch=loss_patch.detach(),
            # NOTE(review): ``t_cls`` stacks all global views along dim 0
            # (``n_global * B`` rows), unlike the eval path's ``B`` rows —
            # confirm downstream consumers expect this.
            embedding=t_cls.detach(),
        )