# Source code for stable_pretraining.methods.data2vec

"""data2vec: Predicting contextualised representations.

Self-supervised learning by having a student predict the EMA-teacher's
contextualised representation (top-K block average of patch tokens) at
masked positions. No augmentations, no negatives, modality-agnostic.

References:
    Baevski et al. "data2vec: A General Framework for Self-supervised
    Learning in Speech, Vision and Language." ICML 2022.
    https://arxiv.org/abs/2202.03555
"""

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import TeacherStudentWrapper


@dataclass
class Data2VecOutput(ModelOutput):
    """Structured output of the :class:`Data2Vec` SSL method.

    Returned by :meth:`Data2Vec.forward`. In evaluation mode only ``loss``
    and ``embedding`` are populated; ``predictions``, ``target`` and ``mask``
    keep their ``None`` default.
    """

    # Scalar loss. Training: Smooth-L1 regression loss averaged over masked
    # positions. Evaluation: a zero scalar on the input's device and dtype.
    loss: torch.Tensor = None
    # Detached features for downstream use. Training: mean over the student's
    # patch tokens. Evaluation: the teacher forward output (its first token
    # when that output is 3-D).
    embedding: torch.Tensor = None
    # Output of the regression head applied to the student patch tokens.
    predictions: Optional[torch.Tensor] = None
    # Detached teacher regression target (layer-normalised block average).
    target: Optional[torch.Tensor] = None
    # Per-patch mask used for the student input; 1.0 marks a masked position.
    mask: Optional[torch.Tensor] = None


class _BlockHook:
    """Record the forward output of each block in a module list.

    A forward hook is attached to every block at construction time. Each
    time a hooked block runs, its output is appended to :attr:`outputs` in
    call order. Call :meth:`remove` to detach the hooks again.
    """

    def __init__(self, blocks: nn.ModuleList):
        # Captured outputs, in the order the hooked blocks were executed.
        self.outputs: list = []
        handles = []
        for block in blocks:
            handles.append(block.register_forward_hook(self._hook))
        # Kept so the hooks can be detached later.
        self._handles = handles

    def _hook(self, module, inputs, output):
        """Store ``output``, or its first element when it is not a tensor."""
        if not isinstance(output, torch.Tensor):
            # Non-tensor outputs (e.g. tuples) carry the tensor of interest
            # in position 0.
            output = output[0]
        self.outputs.append(output)

    def reset(self):
        """Discard everything captured so far by starting a fresh list."""
        self.outputs = []

    def remove(self):
        """Detach every hook registered by this instance."""
        for handle in self._handles:
            handle.remove()


class Data2Vec(Module):
    """data2vec for vision: predict EMA-teacher block-averaged features.

    A student ViT encodes the image with a random subset of patch tokens
    replaced by a learnable mask token. A linear head regresses the student
    tokens onto the layer-normalised average of the teacher's last
    ``top_k_blocks`` block outputs, and a Smooth-L1 loss is averaged over
    the masked positions only.

    :param encoder_name: timm ViT name (default ``"vit_small_patch16_224"``)
        or an already-built ``nn.Module``. The encoder must expose
        ``patch_embed``, ``pos_embed``, ``pos_drop``, ``blocks`` and ``norm``
        (``cls_token`` is optional).
    :param top_k_blocks: Number of top transformer blocks averaged on the
        teacher side to form the prediction target (default 6). Must be at
        least 1.
    :param mask_ratio: Fraction of patch tokens masked on the student input
        (default 0.6). Must lie in ``[0, 1]``. Masked tokens are replaced by
        a learnable token before the encoder.
    :param ema_decay_start: Initial teacher EMA (default 0.999).
    :param ema_decay_end: Final teacher EMA (default 0.9999).
    :param image_size: Input size (default 224).
    :param pretrained: Load pretrained timm weights.
    :raises ValueError: If ``top_k_blocks`` is smaller than 1 or
        ``mask_ratio`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        top_k_blocks: int = 6,
        mask_ratio: float = 0.6,
        ema_decay_start: float = 0.999,
        ema_decay_end: float = 0.9999,
        image_size: int = 224,
        pretrained: bool = False,
    ):
        super().__init__()

        # Fail fast on values that would otherwise misbehave silently:
        # ``outputs[-0:]`` selects *every* block, and a negative ratio turns
        # the mask-count slice into "all but the last k" tokens.
        if top_k_blocks < 1:
            raise ValueError(f"top_k_blocks must be >= 1, got {top_k_blocks}")
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

        if isinstance(encoder_name, str):
            # Imported lazily so timm is only needed when a name is given.
            import timm

            base = timm.create_model(
                encoder_name, num_classes=0, pretrained=pretrained
            )
        else:
            base = encoder_name

        # Infer the feature width from a dummy forward pass.
        with torch.no_grad():
            embed_dim = base(torch.zeros(1, 3, image_size, image_size)).shape[-1]

        self.embed_dim = embed_dim
        self.top_k_blocks = top_k_blocks
        self.mask_ratio = mask_ratio
        self.image_size = image_size

        self.encoder = TeacherStudentWrapper(
            base,
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )

        # Learnable token substituted for masked patch embeddings.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        # Linear regression head that maps encoded features to the target dim
        self.regressor = nn.Linear(embed_dim, embed_dim)

    @staticmethod
    def _has_cls_token(vit: nn.Module) -> bool:
        """Return ``True`` when ``vit`` has a non-``None`` ``cls_token``."""
        return getattr(vit, "cls_token", None) is not None

    def _patches_for_vit(self, vit: nn.Module, images: torch.Tensor) -> torch.Tensor:
        """Return the patch embeddings ``vit.patch_embed(images)``."""
        return vit.patch_embed(images)

    def _encode_with_mask(
        self,
        vit: nn.Module,
        images: torch.Tensor,
        mask: Optional[torch.Tensor],
        patches: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode ``images`` with ``vit`` and return its patch tokens.

        :param vit: Encoder to run.
        :param images: Input images, embedded with ``vit.patch_embed`` unless
            ``patches`` is supplied.
        :param mask: Per-patch mask where 1 marks a token to be replaced by
            the mask token, or ``None`` to encode the input unmasked.
        :param patches: Optional pre-computed output of
            :meth:`_patches_for_vit` for ``images``; lets a caller that
            already embedded the image avoid doing it a second time.
        :return: Normalised token sequence with the CLS token removed when
            the encoder has one.
        """
        x = self._patches_for_vit(vit, images) if patches is None else patches

        if mask is not None:
            # Cast so the blend keeps the token dtype (the mask is built in
            # float32) and so boolean masks are accepted as well.
            m = mask.unsqueeze(-1).to(x.dtype)
            x = x * (1 - m) + self.mask_token.expand_as(x) * m

        has_cls = self._has_cls_token(vit)
        if has_cls:
            cls = vit.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls, x], dim=1)

        x = x + vit.pos_embed
        x = vit.pos_drop(x)
        x = vit.blocks(x)
        x = vit.norm(x)

        # Only patch tokens are returned; drop the prepended CLS token.
        return x[:, 1:] if has_cls else x

    def _teacher_target(self, images: torch.Tensor) -> torch.Tensor:
        """Return the EMA teacher's averaged last-K-block patch features.

        Runs the unmasked image through the EMA teacher and averages the last
        K block outputs at patch positions. We install + remove the hook
        inside this method so we never capture the student's forward (which
        would corrupt the target).

        :param images: Input images.
        :return: Layer-normalised mean of the last ``top_k_blocks`` block
            outputs, CLS token excluded. Computed without gradient tracking.
        """
        vit = self.encoder.teacher
        with torch.no_grad():
            hook = _BlockHook(vit.blocks)
            try:
                _ = self._encode_with_mask(vit, images, mask=None)
                # ``top_k_blocks`` is validated to be >= 1 at construction,
                # so this negative slice never degenerates to ``[-0:]``.
                blocks = list(hook.outputs[-self.top_k_blocks :])
            finally:
                # Always detach, even if the teacher forward raised.
                hook.remove()

            # Drop CLS column for each block output if present.
            cls_offset = 1 if self._has_cls_token(vit) else 0
            cleaned = [b[:, cls_offset:] for b in blocks]
            target = torch.stack(cleaned, dim=0).mean(dim=0)

            # Per-token layer norm (paper: stability of the regression target).
            target = F.layer_norm(target, [target.shape[-1]])
        return target

    def _random_mask(self, B: int, N: int, device) -> torch.Tensor:
        """Draw an independent random patch mask for every sample.

        :param B: Batch size.
        :param N: Number of patch tokens per sample.
        :param device: Device on which the mask is created.
        :return: Float tensor of shape ``[B, N]`` holding 1.0 at
            ``int(round(N * mask_ratio))`` random positions per row and 0.0
            elsewhere.
        """
        n_mask = int(round(N * self.mask_ratio))
        # Ranking uniform noise yields a uniformly random permutation per row.
        noise = torch.rand(B, N, device=device)
        order = noise.argsort(dim=1)
        mask = torch.zeros(B, N, device=device)
        mask.scatter_(1, order[:, :n_mask], 1.0)
        return mask

    def forward(self, images: torch.Tensor) -> Data2VecOutput:
        """Forward pass.

        In evaluation mode the images go through the teacher only and the
        loss is a zero scalar. In training mode the student encodes a
        randomly masked input and is regressed onto the teacher target.

        :param images: ``[B, C, H, W]``.
        :return: :class:`Data2VecOutput`. ``predictions``, ``target`` and
            ``mask`` are only set in training mode.
        """
        if not self.training:
            with torch.no_grad():
                feats = self.encoder.forward_teacher(images)
            # A 3-D output is a token sequence: keep its first token.
            cls = feats[:, 0] if feats.ndim == 3 else feats
            return Data2VecOutput(
                loss=torch.zeros((), device=images.device, dtype=images.dtype),
                embedding=cls.detach(),
            )

        vit = self.encoder.student

        # Embed the patches once; the result both sizes the mask and feeds
        # the encoder, so the patch embedding is not recomputed.
        patches = self._patches_for_vit(vit, images)

        # Build random mask in the patch grid
        mask = self._random_mask(
            images.shape[0], patches.shape[1], device=images.device
        )

        # Student encodes the masked image
        student_tokens = self._encode_with_mask(vit, images, mask, patches=patches)
        predictions = self.regressor(student_tokens)

        # Teacher provides the target representation (no mask)
        target = self._teacher_target(images)

        # Smooth-L1 loss on masked positions only
        diff = F.smooth_l1_loss(
            predictions, target, beta=2.0, reduction="none"
        ).mean(dim=-1)
        # The clamp guards against division by zero when nothing is masked.
        loss = (diff * mask).sum() / mask.sum().clamp(min=1.0)

        # Embedding for online probes: mean of all student patch tokens.
        embedding = student_tokens.mean(dim=1)

        return Data2VecOutput(
            loss=loss,
            embedding=embedding.detach(),
            predictions=predictions,
            target=target.detach(),
            mask=mask,
        )