"""SimMIM: Simple Framework for Masked Image Modeling.
Predicts raw pixel values for masked patches via a 1-layer linear decoder.
Unlike MAE, SimMIM passes both visible and mask tokens through the encoder
and uses a trivial decoder, making it cleaner to integrate with any ViT.
References:
Xie et al. "SimMIM: A Simple Framework for Masked Image Modeling."
CVPR 2022. https://arxiv.org/abs/2111.09886
"""
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from stable_pretraining import Module
from stable_pretraining.backbone import patchify
@dataclass
class SimMIMOutput(ModelOutput):
    """Structured output of the :class:`SimMIM` SSL method.

    Every field defaults to ``None``; ``predictions`` and ``mask`` are only
    populated by the training-mode forward pass.
    """

    # Scalar loss: masked L1 reconstruction loss in training mode, a zero
    # scalar in eval mode.
    loss: Optional[torch.Tensor] = None
    # Per-sample feature vector returned for downstream use.
    embedding: Optional[torch.Tensor] = None
    # Per-patch pixel predictions from the linear decoder (training only).
    predictions: Optional[torch.Tensor] = None
    # Per-patch mask used for the loss, 1 = masked (training only).
    mask: Optional[torch.Tensor] = None
class SimMIM(Module):
    """SimMIM masked image modeling.

    Masked patch embeddings are replaced by a learnable mask token, the full
    token sequence is run through the encoder, and a single linear layer
    predicts the raw pixels of every patch. The loss is the L1 error averaged
    over masked patches only.

    :param encoder_name: timm model name (default ``"vit_small_patch16_224"``)
        or an already-constructed ``nn.Module``. The masked forward pass needs
        the encoder to expose ``patch_embed``, ``pos_embed``, ``pos_drop``,
        ``blocks`` and ``norm`` (and optionally ``cls_token``); the eval-mode
        pass needs ``forward_features``.
    :param patch_size: Patch size (must match the encoder's).
    :param mask_ratio: Fraction of patches to mask (default 0.6, paper used 0.6).
        Must lie in ``[0, 1]``.
    :param in_channels: Image channels (default 3).
    :param image_size: Input image size (default 224). Used to build the dummy
        input that probes the encoder's output width.
    :param pretrained: Load pretrained timm weights.
    :raises ValueError: If ``mask_ratio`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        patch_size: int = 16,
        mask_ratio: float = 0.6,
        in_channels: int = 3,
        image_size: int = 224,
        pretrained: bool = False,
    ):
        super().__init__()

        # Fail fast: a negative ratio would otherwise become a negative slice
        # bound in ``_random_mask`` and silently mask *most* patches.
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

        if isinstance(encoder_name, str):
            import timm

            self.encoder = timm.create_model(
                encoder_name, num_classes=0, pretrained=pretrained
            )
        else:
            self.encoder = encoder_name

        # Probe the encoder's output width with a dummy forward pass. The dummy
        # is created on the encoder's own device/dtype so that a user-supplied
        # encoder that already lives on an accelerator (or in half precision)
        # does not trigger a device/dtype mismatch.
        reference = next(self.encoder.parameters(), None)
        probe_kwargs = (
            {}
            if reference is None
            else {"device": reference.device, "dtype": reference.dtype}
        )
        with torch.no_grad():
            dummy = torch.zeros(
                1, in_channels, image_size, image_size, **probe_kwargs
            )
            embed_dim = self.encoder(dummy).shape[-1]

        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.in_channels = in_channels
        self.image_size = image_size

        # Learnable mask token that stands in for masked patch embeddings.
        # NOTE(review): this assumes the encoder's output width equals its
        # patch-embedding width — confirm for non-timm encoders.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)

        # 1-layer linear decoder: embed_dim -> per-patch pixels
        self.decoder = nn.Linear(embed_dim, patch_size * patch_size * in_channels)

    def _random_mask(self, B: int, N: int, device) -> torch.Tensor:
        """Sample a 0/1 mask per (batch, patch) with the configured mask ratio.

        :param B: Batch size.
        :param N: Number of patches per image.
        :param device: Device on which to create the mask.
        :return: ``[B, N]`` float tensor where 1 marks a masked patch; every
            row has the same number of masked entries.
        """
        n_mask = int(round(N * self.mask_ratio))
        # Guard against ``mask_ratio`` being mutated to an invalid value after
        # construction; this is a no-op for ratios inside [0, 1].
        n_mask = max(0, min(N, n_mask))

        # Ranking i.i.d. uniform noise yields an independent random permutation
        # per row; the first ``n_mask`` indices of each permutation are masked.
        noise = torch.rand(B, N, device=device)
        order = noise.argsort(dim=1)
        mask = torch.zeros(B, N, device=device)
        mask.scatter_(1, order[:, :n_mask], 1.0)
        return mask  # 1 = masked

    def _encode_with_mask(
        self, images: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Patch-embed, replace masked tokens, then run encoder blocks.

        Works with timm ViTs that expose ``patch_embed``, ``cls_token`` (optional),
        ``pos_embed``, ``pos_drop``, ``blocks``, ``norm``.

        :param images: ``[B, C, H, W]`` images.
        :param mask: ``[B, N]`` mask with 1 at masked patch positions.
        :return: ``[B, N, D]`` encoded patch tokens (CLS token removed).
        """
        vit = self.encoder
        x = vit.patch_embed(images)  # [B, N, D]

        # Cast the mask to the token dtype: a float32 mask would promote
        # half-precision tokens to float32 and break the (half-precision)
        # transformer blocks that follow.
        m = mask.unsqueeze(-1).to(dtype=x.dtype)  # [B, N, 1]
        x = x * (1 - m) + self.mask_token.expand_as(x) * m

        # Prepend the CLS token if the model uses one.
        cls_token = getattr(vit, "cls_token", None)
        if cls_token is not None:
            x = torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)

        # Positional embedding, then the transformer stack.
        x = x + vit.pos_embed
        x = vit.pos_drop(x)
        x = vit.blocks(x)
        x = vit.norm(x)

        # Drop CLS so the output lines up one-to-one with the patches.
        if cls_token is not None:
            x = x[:, 1:]
        return x

    def forward(self, images: torch.Tensor) -> SimMIMOutput:
        """Forward pass.

        In eval mode the images are encoded without masking and only ``loss``
        (a zero scalar) and ``embedding`` are returned. In training mode a
        random mask is sampled, masked patches are reconstructed, and the L1
        loss is averaged over masked patches.

        :param images: ``[B, C, H, W]`` images.
        :return: :class:`SimMIMOutput`.
        """
        if not self.training:
            # Plain encoding for downstream tasks
            features = self.encoder.forward_features(images)
            cls = features[:, 0] if features.ndim == 3 else features
            return SimMIMOutput(
                loss=torch.zeros((), device=images.device, dtype=images.dtype),
                embedding=cls,
            )

        B = images.shape[0]
        p = self.patch_size
        # Derive the patch grid from the actual input rather than the
        # configured ``image_size`` so the mask always matches the tokens.
        N = (images.shape[-2] // p) * (images.shape[-1] // p)

        mask = self._random_mask(B, N, device=images.device)
        encoded = self._encode_with_mask(images, mask)  # [B, N, D]
        pred_pixels = self.decoder(encoded)  # [B, N, P]

        target = patchify(images, patch_size=(self.in_channels, p, p))  # [B, N, C*p*p]
        loss_per = F.l1_loss(pred_pixels, target, reduction="none").mean(
            dim=-1
        )  # [B, N]
        # Average over masked patches only; the clamp avoids 0/0 when nothing
        # is masked.
        loss = (loss_per * mask).sum() / mask.sum().clamp(min=1.0)

        # Embedding for probes: mean of *all* patch tokens.
        # NOTE(review): eval mode returns the first token of
        # ``forward_features`` instead — confirm probes expect this difference.
        embedding = encoded.mean(dim=1)

        return SimMIMOutput(
            loss=loss,
            embedding=embedding.detach(),
            predictions=pred_pixels,
            mask=mask,
        )