Source code for stable_pretraining.methods.dino

"""DINO: Self-distillation with no labels.

Self-supervised learning by distilling a teacher (EMA of student) into a
student that processes a richer view of the data (multi-crop). The student
matches the teacher's softmaxed prototype assignments via cross-entropy.

References:
    Caron et al. "Emerging Properties in Self-Supervised Vision Transformers."
    ICCV 2021. https://arxiv.org/abs/2104.14294
"""

from dataclasses import dataclass
from typing import List, Optional, Sequence, Union

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import L2Norm, TeacherStudentWrapper
from stable_pretraining.losses import DINOv1Loss


@dataclass
class DINOOutput(ModelOutput):
    """Output from DINO forward pass.

    :ivar loss: DINO cross-entropy loss; a zero scalar when ``DINO.forward``
        is called with ``images`` (the evaluation path)
    :ivar embedding: Teacher CLS features ``[B, D]`` (``images`` path) or
        ``[n_global * B, D]`` (multi-crop training path, global views
        concatenated along the batch dimension)
    :ivar teacher_logits: Teacher prototype logits ``[n_global, B, K]``
        (``None`` on the ``images`` path)
    :ivar student_logits: Student prototype logits ``[n_global + n_local, B, K]``,
        global views first, then local views (``None`` on the ``images`` path)
    """

    # Every field may legitimately be None (e.g. the logits on the images
    # path), so all of them are annotated Optional rather than relying on an
    # implicit Optional from the ``= None`` default.
    loss: Optional[torch.Tensor] = None
    embedding: Optional[torch.Tensor] = None
    teacher_logits: Optional[torch.Tensor] = None
    student_logits: Optional[torch.Tensor] = None


def _build_dino_projector(
    in_dim: int, hidden_dim: int, bottleneck_dim: int, n_prototypes: int
) -> nn.Module:
    """Build the DINO projection head: 3-layer MLP + L2 norm + linear prototypes.

    Layer order (``nn.Sequential`` indices 0-6)::

        Linear(in_dim, hidden_dim) -> GELU
        -> Linear(hidden_dim, hidden_dim) -> GELU
        -> Linear(hidden_dim, bottleneck_dim)
        -> L2Norm
        -> Linear(bottleneck_dim, n_prototypes, bias=False)

    The bottleneck features pass through ``L2Norm`` (presumably normalizing
    along the last dimension) before the bias-free prototype layer.

    This is *not* identical to the reference DINO head. The reference
    implementation L2-normalizes the bottleneck *and also* wraps the prototype
    Linear in ``nn.utils.weight_norm``, with the weight norm frozen at 1 by
    default (``norm_last_layer=True``). Here the prototype layer is a plain
    ``nn.Linear``, so its weight rows are not norm-constrained and the logits
    are not pure cosine similarities.

    :param in_dim: Width of the backbone CLS feature fed to the head.
    :param hidden_dim: Width of the two hidden MLP layers.
    :param bottleneck_dim: Width of the bottleneck that is L2-normalized.
    :param n_prototypes: Number of prototypes (output logit dimension).
    :return: The projector as an ``nn.Sequential``.
    """
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, bottleneck_dim),
        L2Norm(),
        nn.Linear(bottleneck_dim, n_prototypes, bias=False),
    )


def _to_cls(features: torch.Tensor) -> torch.Tensor:
    """Collapse a backbone output into one ``[B, D]`` vector per sample.

    A rank-2 tensor is assumed to be already pooled (as timm models built with
    ``num_classes=0`` produce) and is returned unchanged. For a rank-3
    sequence ``[B, N, D]``, the first token along ``N`` is selected.

    :param features: Backbone output of rank 2 or 3.
    :return: The pooled ``[B, D]`` representation.
    :raises ValueError: For tensors of any other rank.
    """
    rank = features.ndim
    if rank not in (2, 3):
        raise ValueError(f"Unexpected backbone output shape {tuple(features.shape)}")
    # Rank 3: keep only the leading (CLS-position) token of the sequence.
    return features if rank == 2 else features[:, 0]


class DINO(Module):
    """DINO self-distillation with multi-crop and an EMA teacher.

    Architecture:

    - **Backbone** (student) wrapped in :class:`TeacherStudentWrapper`
      (teacher is an EMA copy).
    - **Projector** (student) wrapped in :class:`TeacherStudentWrapper`:
      3-layer MLP -> L2-norm -> linear prototypes (default 65536).
    - **Loss**: :class:`DINOv1Loss` with classical centering.

    The teacher only sees global crops; the student sees both global and local
    crops. The student logits (global views first, then local views) and the
    centered, temperature-sharpened teacher probabilities are passed to
    :class:`DINOv1Loss`. That loss is expected to average the cross-entropy
    over every student/teacher view pair, excluding same-view pairs.

    :param encoder_name: timm model name (default ``"vit_small_patch16_224"``)
        or a pre-built ``nn.Module``. For multi-crop, the backbone must accept
        variable input sizes; for timm ViTs, pass ``dynamic_img_size=True``
        via ``encoder_kwargs``. The feature width is inferred from one dummy
        ``1x3x224x224`` forward pass, run in eval mode on the backbone's
        device and dtype.
    :param projector_hidden_dim: Hidden dim of the 3-layer MLP (default 2048).
    :param projector_bottleneck_dim: Bottleneck dim before prototypes (default 256).
    :param n_prototypes: Number of prototypes / output dim (default 65536).
    :param temperature_student: Student softmax temperature (default 0.1).
    :param temperature_teacher_warmup: Teacher temp at start (default 0.04).
    :param temperature_teacher: Teacher temp after warmup (default 0.07).
    :param warmup_epochs_temperature_teacher: Linear warmup epochs (default 30).
    :param center_momentum: EMA momentum for the teacher centering (default 0.9).
    :param ema_decay_start: Initial backbone/projector EMA (default 0.996).
    :param ema_decay_end: Final EMA (default 1.0).
    :param encoder_kwargs: Extra kwargs forwarded to ``timm.create_model``.
    :param pretrained: Load pretrained timm weights for the encoder.

    Example::

        model = DINO("vit_small_patch16_224", encoder_kwargs={"dynamic_img_size": True})
        global_views = [torch.randn(8, 3, 224, 224), torch.randn(8, 3, 224, 224)]
        local_views = [torch.randn(8, 3, 96, 96) for _ in range(6)]
        out = model(global_views=global_views, local_views=local_views)
        out.loss.backward()
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        projector_hidden_dim: int = 2048,
        projector_bottleneck_dim: int = 256,
        n_prototypes: int = 65536,
        temperature_student: float = 0.1,
        temperature_teacher_warmup: float = 0.04,
        temperature_teacher: float = 0.07,
        warmup_epochs_temperature_teacher: int = 30,
        center_momentum: float = 0.9,
        ema_decay_start: float = 0.996,
        ema_decay_end: float = 1.0,
        encoder_kwargs: Optional[dict] = None,
        pretrained: bool = False,
    ):
        super().__init__()

        if isinstance(encoder_name, str):
            # Local import: timm is only needed when building from a name.
            import timm

            kw = dict(num_classes=0, pretrained=pretrained)
            kw.update(encoder_kwargs or {})
            base_backbone = timm.create_model(encoder_name, **kw)
        else:
            base_backbone = encoder_name

        embed_dim = self._infer_embed_dim(base_backbone)
        self.embed_dim = embed_dim
        self.n_prototypes = n_prototypes

        self.backbone = TeacherStudentWrapper(
            base_backbone,
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        self.projector = TeacherStudentWrapper(
            _build_dino_projector(
                embed_dim, projector_hidden_dim, projector_bottleneck_dim, n_prototypes
            ),
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        self.dino_loss = DINOv1Loss(
            temperature_student=temperature_student,
            center_momentum=center_momentum,
        )

        self.temperature_teacher_warmup = temperature_teacher_warmup
        self.temperature_teacher = temperature_teacher
        self.warmup_epochs_temperature_teacher = warmup_epochs_temperature_teacher

    @staticmethod
    def _infer_embed_dim(backbone: nn.Module) -> int:
        """Return the width of ``backbone``'s pooled output via a dummy forward.

        The probe runs with every submodule in eval mode, so a zero dummy
        batch does not update BatchNorm running statistics, and single-sample
        batch norm does not raise. Each submodule's original ``training`` flag
        is restored afterwards, including modules a caller had deliberately
        frozen in eval mode. The dummy input is created on the device and
        floating-point dtype of the backbone's first parameter, if it has
        one, so pre-built backbones on an accelerator or in reduced precision
        can be probed.
        """
        param = next(backbone.parameters(), None)
        device = param.device if param is not None else None
        dtype = (
            param.dtype if param is not None and param.is_floating_point() else None
        )

        saved_modes = [(module, module.training) for module in backbone.modules()]
        backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, 224, 224, device=device, dtype=dtype)
                return _to_cls(backbone(dummy)).shape[-1]
        finally:
            for module, mode in saved_modes:
                module.training = mode

    def _teacher_temperature(self) -> float:
        """Linearly warm up teacher temperature over the configured epochs.

        The result is ``temperature_teacher_warmup`` at epoch 0 and is
        interpolated linearly in ``epoch / warmup_epochs_temperature_teacher``.
        From epoch ``warmup_epochs_temperature_teacher`` onwards it is
        ``temperature_teacher``.
        """
        epoch = int(self.current_epoch) if hasattr(self, "current_epoch") else 0
        warmup = self.warmup_epochs_temperature_teacher
        if epoch >= warmup:
            return self.temperature_teacher
        progress = epoch / max(warmup, 1)
        return self.temperature_teacher_warmup + progress * (
            self.temperature_teacher - self.temperature_teacher_warmup
        )

    def forward(
        self,
        global_views: Optional[Sequence[torch.Tensor]] = None,
        local_views: Optional[Sequence[torch.Tensor]] = None,
        images: Optional[torch.Tensor] = None,
    ) -> DINOOutput:
        """Forward pass.

        :param global_views: ``n_global`` tensors ``[B, C, H, W]`` (e.g. two
            224x224 crops), given either as a sequence or as a stacked
            ``[n_global, B, C, H, W]`` tensor. Required unless ``images`` is
            given.
        :param local_views: ``n_local`` tensors ``[B, C, h, w]`` (e.g. six
            96x96 crops), given as a sequence or a stacked tensor. Optional.
        :param images: Single batch of images for evaluation. If supplied, it
            takes precedence and only the teacher CLS embedding is returned,
            with a zero loss.
        :return: :class:`DINOOutput`.
        :raises ValueError: If neither ``images`` nor a non-empty
            ``global_views`` is provided.
        """
        # Eval / single-image path: teacher embedding only, no loss.
        if images is not None:
            with torch.no_grad():
                features = self.backbone.forward_teacher(images)
                cls = _to_cls(features)
            return DINOOutput(
                loss=torch.zeros((), device=cls.device, dtype=cls.dtype),
                embedding=cls.detach(),
            )

        # Convert to lists *before* testing emptiness. ``not tensor`` and
        # ``tensor or []`` raise for multi-element tensors, so stacked view
        # tensors would otherwise crash here.
        global_views = list(global_views) if global_views is not None else []
        if not global_views:
            raise ValueError("DINO.forward needs global_views or images")
        local_views = list(local_views) if local_views is not None else []

        n_global = len(global_views)
        n_local = len(local_views)
        B = global_views[0].shape[0]

        # Teacher: only global views.
        global_imgs = torch.cat(global_views, dim=0)
        with torch.no_grad():
            t_features = self.backbone.forward_teacher(global_imgs)
            t_cls = _to_cls(t_features)
            t_logits = self.projector.forward_teacher(t_cls).view(n_global, B, -1)

        # Student: global views (same input as the teacher).
        s_features_g = self.backbone.forward_student(global_imgs)
        s_cls_g = _to_cls(s_features_g)
        s_logits_g = self.projector.forward_student(s_cls_g).view(n_global, B, -1)
        student_logits_list: List[torch.Tensor] = [s_logits_g]

        # Student: local views (smaller crops). These go in a separate
        # backbone call because their resolution differs from the global crops.
        if n_local > 0:
            local_imgs = torch.cat(local_views, dim=0)
            s_features_l = self.backbone.forward_student(local_imgs)
            s_cls_l = _to_cls(s_features_l)
            s_logits_l = self.projector.forward_student(s_cls_l).view(n_local, B, -1)
            student_logits_list.append(s_logits_l)

        # [n_global + n_local, B, K], global views first.
        student_logits = torch.cat(student_logits_list, dim=0)

        teacher_temp = self._teacher_temperature()
        teacher_probs = self.dino_loss.softmax_center_teacher(
            t_logits, teacher_temp=teacher_temp
        )
        loss = self.dino_loss(student_logits, teacher_probs)

        # Queue async center update for next iteration.
        self.dino_loss.update_center(t_logits)

        return DINOOutput(
            loss=loss,
            embedding=t_cls.detach(),
            teacher_logits=t_logits.detach(),
            student_logits=student_logits,
        )