Source code for stable_pretraining.methods.cmae

"""CMAE: Contrastive Masked AutoEncoders.

Combines MAE-style pixel reconstruction with a SimSiam-/BYOL-style
contrastive loss between two views. The student encodes a masked view; an
EMA target encoder encodes a different (un-masked) view; the loss is
``MAE_recon + lambda * contrastive``.

References:
    Huang et al. "Contrastive Masked Autoencoders are Stronger Vision
    Learners." TPAMI 2023. https://arxiv.org/abs/2207.13532
"""

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import TeacherStudentWrapper, patchify


@dataclass
class CMAEOutput(ModelOutput):
    """Structured output returned by :meth:`CMAE.forward`.

    When ``CMAE.forward`` is called without a second view (embedding-only
    path), only ``loss`` (a zero scalar) and ``embedding`` are filled in;
    ``loss_recon`` and ``loss_contrast`` are left as ``None``.
    """

    # Total objective: reconstruction + contrast_weight * contrastive term.
    loss: Optional[torch.Tensor] = None
    # Masked-patch pixel MSE (detached, for logging).
    loss_recon: Optional[torch.Tensor] = None
    # Negative cosine similarity between student prediction and teacher projection (detached).
    loss_contrast: Optional[torch.Tensor] = None
    # Detached teacher embedding of the view the teacher encoded.
    embedding: Optional[torch.Tensor] = None


def _projector(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    """Build the contrastive projection MLP.

    Layout: ``Linear(in_dim -> hidden_dim) -> BatchNorm1d -> ReLU ->
    Linear(hidden_dim -> out_dim)``.
    """
    hidden_stack = (
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
    )
    head = nn.Linear(hidden_dim, out_dim)
    return nn.Sequential(*hidden_stack, head)


def _predictor(in_dim: int, hidden_dim: int) -> nn.Module:
    """Build the student-side predictor MLP.

    Maps ``in_dim -> hidden_dim -> in_dim`` with BatchNorm1d + ReLU in
    between, so its output lives in the same space as its input.
    """
    first = nn.Linear(in_dim, hidden_dim)
    norm = nn.BatchNorm1d(hidden_dim)
    act = nn.ReLU(inplace=True)
    last = nn.Linear(hidden_dim, in_dim)
    return nn.Sequential(first, norm, act, last)


class CMAE(Module):
    """CMAE: MAE pixel loss + EMA contrastive loss.

    The student encoder sees ``view1`` with a random subset of its patch
    tokens replaced by a learnable mask token. Its patch outputs reconstruct
    the per-patch-normalised pixels of the masked patches. Its [CLS]/pooled
    output goes through a projector and predictor to match the EMA teacher's
    projection of the unmasked ``view2`` (negative cosine similarity).

    :param encoder_name: timm ViT name (default ``"vit_small_patch16_224"``),
        or an already-built ``nn.Module`` exposing ``forward_features`` and the
        timm ViT attributes used by :meth:`_encode` (``patch_embed``,
        ``pos_embed``, ``pos_drop``, ``blocks``, ``norm``; optionally
        ``cls_token`` and ``norm_pre``).
    :param patch_size: Patch size used for the pixel target and the
        reconstruction head (default 16). It should equal the encoder's patch
        size so that target patches line up with the encoder's patch tokens.
    :param mask_ratio: Fraction of patches masked for the student, in
        ``[0, 1]`` (default 0.75, as in MAE).
    :param projector_dim: Output dim of the contrastive projector (default
        256). The projector's hidden width is the encoder width; the
        predictor's hidden width is ``2 * projector_dim``.
    :param contrast_weight: Weight on the contrastive term (default 1.0).
    :param ema_decay_start: Initial EMA coefficient (default 0.99).
    :param ema_decay_end: Final EMA coefficient (default 1.0).
    :param image_size: Input size (default 224); used for the dummy
        shape-probe pass.
    :param in_channels: Channels (default 3).
    :param pretrained: Load pretrained timm weights (only used when
        ``encoder_name`` is a string).
    :raises ValueError: If ``mask_ratio`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        patch_size: int = 16,
        mask_ratio: float = 0.75,
        projector_dim: int = 256,
        contrast_weight: float = 1.0,
        ema_decay_start: float = 0.99,
        ema_decay_end: float = 1.0,
        image_size: int = 224,
        in_channels: int = 3,
        pretrained: bool = False,
    ):
        super().__init__()
        if not 0.0 <= mask_ratio <= 1.0:
            # A negative ratio would otherwise silently mask N-k patches via
            # negative slicing in _random_mask.
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
        if isinstance(encoder_name, str):
            import timm

            base = timm.create_model(encoder_name, num_classes=0, pretrained=pretrained)
        else:
            base = encoder_name

        seq = self._probe_features(base, in_channels, image_size)
        # A class token is assumed only if the module has one and the probe
        # returned more than one position along dim 1.
        self._has_cls = (
            hasattr(base, "cls_token")
            and base.cls_token is not None
            and seq.shape[1] > 1
        )
        embed_dim = seq.shape[-1]
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.contrast_weight = contrast_weight
        self.image_size = image_size
        self.in_channels = in_channels

        self.backbone = TeacherStudentWrapper(
            base,
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        self.projector = TeacherStudentWrapper(
            _projector(embed_dim, embed_dim, projector_dim),
            warm_init=True,
            base_ema_coefficient=ema_decay_start,
            final_ema_coefficient=ema_decay_end,
        )
        # Predictor exists only on the student side (no EMA copy).
        self.predictor = _predictor(projector_dim, projector_dim * 2)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        self.recon_head = nn.Linear(embed_dim, in_channels * patch_size * patch_size)

    @staticmethod
    def _probe_features(base: nn.Module, in_channels: int, image_size: int) -> torch.Tensor:
        """Run one dummy ``forward_features`` pass to discover the output layout.

        The all-zeros input is created on the encoder's own device and (floating)
        dtype, so probing a pre-built GPU / half-precision module does not fail.
        The encoder is put in eval mode for the pass, so normalisation running
        statistics are not updated with the dummy batch. Every submodule's
        train/eval flag is restored afterwards.
        """
        param = next(base.parameters(), None)
        device = None if param is None else param.device
        dtype = param.dtype if param is not None and param.is_floating_point() else None
        dummy = torch.zeros(
            1, in_channels, image_size, image_size, device=device, dtype=dtype
        )
        modes = [(module, module.training) for module in base.modules()]
        base.eval()
        try:
            with torch.no_grad():
                return base.forward_features(dummy)
        finally:
            for module, mode in modes:
                module.training = mode

    def _encode(self, vit, images, mask=None):
        """Encode ``images`` with ``vit``'s sub-modules, optionally masking patches.

        Follows the token path patch_embed -> (masking) -> [CLS] prepend ->
        + pos_embed -> pos_drop -> norm_pre (if present) -> blocks -> norm.
        Masked patch tokens are replaced by ``self.mask_token`` before the
        positional embeddings are added.

        :param vit: Encoder module (student or teacher copy).
        :param images: Image batch (presumably (B, C, H, W) — TODO confirm for
            custom encoders).
        :param mask: Optional (B, N) tensor, 1 at masked patches and 0 elsewhere.
        :return: Token sequence after ``vit.norm``.
        """
        x = vit.patch_embed(images)
        if mask is not None:
            # Cast to the token dtype: a float32 mask would otherwise promote
            # half-precision tokens to float32 and clash with the encoder weights.
            m = mask.unsqueeze(-1).to(dtype=x.dtype)
            token = self.mask_token.to(dtype=x.dtype).expand_as(x)
            x = x * (1 - m) + token * m
        if self._has_cls:
            cls = vit.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls, x], dim=1)
        x = x + vit.pos_embed
        x = vit.pos_drop(x)
        # Apply the pre-block norm when the encoder has one, so this path agrees
        # with the encoder's own forward (used by the embedding-only branch).
        norm_pre = getattr(vit, "norm_pre", None)
        if norm_pre is not None:
            x = norm_pre(x)
        x = vit.blocks(x)
        return vit.norm(x)

    def _split(self, features):
        """Split encoder output into ``(global_embedding, patch_tokens)``.

        2-D (already pooled) features are returned as ``(features, None)``.
        With a class token: token 0 and the remaining tokens. Otherwise: the
        mean over tokens and all tokens.
        """
        if features.ndim == 2:
            return features, None
        if self._has_cls:
            return features[:, 0], features[:, 1:]
        return features.mean(dim=1), features

    def _random_mask(self, B, N, device):
        """Sample an independent random mask per sample.

        Exactly ``int(round(N * mask_ratio))`` positions per row are set.

        :return: Float tensor of shape (B, N) with 1.0 at masked positions.
        """
        n_mask = int(round(N * self.mask_ratio))
        noise = torch.rand(B, N, device=device)
        order = noise.argsort(dim=1)
        mask = torch.zeros(B, N, device=device)
        mask.scatter_(1, order[:, :n_mask], 1.0)
        return mask

    def forward(
        self,
        view1: torch.Tensor,
        view2: Optional[torch.Tensor] = None,
    ) -> CMAEOutput:
        """Compute the CMAE objective, or embed a single view.

        :param view1: Image batch. When ``view2`` is given, a random subset of
            its patches is masked for the student.
        :param view2: Second view, encoded unmasked by the EMA teacher. If
            ``None``, ``view1`` is only embedded by the teacher and a zero loss
            is returned.
        :return: :class:`CMAEOutput` with the total loss, its detached parts,
            and the detached teacher embedding.
        """
        if view2 is None:
            with torch.no_grad():
                feats = self.backbone.forward_teacher(view1)
                cls, _ = self._split(feats)
            return CMAEOutput(
                loss=torch.zeros((), device=cls.device, dtype=cls.dtype),
                embedding=cls.detach(),
            )

        # Student: masked view1 -> reconstruct + contrast.
        B = view1.shape[0]
        with torch.no_grad():
            # One-image probe to learn the number of patch tokens.
            n_patches = self.backbone.student.patch_embed(view1[:1]).shape[1]
        mask = self._random_mask(B, n_patches, device=view1.device)
        s_feats = self._encode(self.backbone.student, view1, mask=mask)
        s_cls, s_patches = self._split(s_feats)
        zs = self.projector.forward_student(s_cls)
        ps = self.predictor(zs)

        # Reconstruction (only on masked positions, using normalised pixel target).
        target = patchify(view1, (self.in_channels, self.patch_size, self.patch_size))
        m = target.mean(dim=-1, keepdim=True)
        v = target.var(dim=-1, keepdim=True)
        target = (target - m) / (v + 1e-6).sqrt()
        recon = self.recon_head(s_patches)
        loss_per = F.mse_loss(recon, target, reduction="none").mean(dim=-1)
        loss_recon = (loss_per * mask).sum() / mask.sum().clamp(min=1.0)

        # Teacher: unmasked view2 -> contrastive target.
        with torch.no_grad():
            t_feats = self._encode(self.backbone.teacher, view2, mask=None)
            t_cls, _ = self._split(t_feats)
            zt = self.projector.forward_teacher(t_cls)

        # Negative cosine similarity (BYOL-style); gradients flow only through ps.
        ps_n = F.normalize(ps, dim=-1)
        zt_n = F.normalize(zt, dim=-1)
        loss_contrast = -(ps_n * zt_n).sum(dim=-1).mean()

        loss = loss_recon + self.contrast_weight * loss_contrast
        return CMAEOutput(
            loss=loss,
            loss_recon=loss_recon.detach(),
            loss_contrast=loss_contrast.detach(),
            embedding=t_cls.detach(),
        )