Source code for stable_pretraining.methods.maskfeat

"""MaskFeat: Masked feature prediction with HOG targets.

Replaces MAE's pixel reconstruction with the prediction of *Histograms of
Oriented Gradients* (HOG) at masked patch positions. HOG is a hand-crafted
feature with a strong inductive bias (scale + photometric invariance) that
keeps the encoder from learning low-level texture only.

References:
    Wei et al. "Masked Feature Prediction for Self-Supervised Visual
    Pre-Training." CVPR 2022. https://arxiv.org/abs/2112.09133
"""

from dataclasses import dataclass
from math import pi
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module


@dataclass
class MaskFeatOutput(ModelOutput):
    """Structured output of the :class:`MaskFeat` SSL method.

    Every field defaults to ``None``. As populated by ``MaskFeat.forward``:

    :param loss: Scalar tensor. MSE between predicted and target HOG
        descriptors at masked patches in training mode; a zero scalar in
        eval mode.
    :param embedding: Per-image representation. In eval mode this is the
        first token of the encoder output when a class token was detected,
        otherwise the mean over tokens. In training mode it is the detached
        mean over the encoded patch tokens.
    :param predictions: Output of the prediction head for every patch token
        (training mode only, otherwise ``None``).
    :param targets: Per-patch HOG descriptors computed from the input images
        (training mode only, otherwise ``None``).
    :param mask: Float mask with ``1.0`` at masked patch positions and
        ``0.0`` elsewhere (training mode only, otherwise ``None``).
    """

    loss: torch.Tensor = None
    embedding: torch.Tensor = None
    predictions: Optional[torch.Tensor] = None
    targets: Optional[torch.Tensor] = None
    mask: Optional[torch.Tensor] = None


def _hog_per_patch(
    images: torch.Tensor, patch_size: int, n_bins: int = 9
) -> torch.Tensor:
    """Compute a per-patch HOG descriptor of shape ``[B, N, C * n_bins]``.

    Each channel is filtered independently with the 1-D kernels
    ``[-1, 0, 1]`` (horizontal) and its transpose (vertical), using zero
    padding so the gradient maps keep the input resolution. Every pixel votes
    with its gradient magnitude into one of ``n_bins`` unsigned-orientation
    bins covering ``[0, pi)``. Votes are accumulated over non-overlapping
    ``patch_size x patch_size`` patches and each resulting histogram is
    standardised over its bin dimension.

    :param images: Tensor of shape ``[B, C, H, W]``.
    :param patch_size: Side length of the square patches; must be positive and
        divide both ``H`` and ``W``.
    :param n_bins: Number of orientation bins (default 9).
    :returns: Tensor of shape ``[B, N, C * n_bins]`` with
        ``N = (H // patch_size) * (W // patch_size)``. Patches are ordered
        row-major; within a patch the descriptor is channel-major.
    :raises ValueError: If ``images`` is not 4-D, ``patch_size`` is not
        positive or does not divide ``H`` and ``W``, or ``n_bins < 1``.
    """
    # Explicit exceptions instead of ``assert``: asserts are stripped under
    # ``python -O`` and the later ``view`` would fail with an opaque message.
    if images.dim() != 4:
        raise ValueError(
            f"images must have shape [B, C, H, W], got {tuple(images.shape)}"
        )
    B, C, H, W = images.shape
    p = patch_size
    if p <= 0 or H % p != 0 or W % p != 0:
        raise ValueError(
            f"patch_size={p} must be positive and divide the image size "
            f"({H}x{W})"
        )
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    # One identical [-1, 0, 1] filter per channel; ``groups=C`` keeps the
    # channels independent.
    taps = torch.tensor(
        [-1.0, 0.0, 1.0], device=images.device, dtype=images.dtype
    )
    kx = taps.view(1, 1, 1, 3).expand(C, 1, 1, 3)
    ky = taps.view(1, 1, 3, 1).expand(C, 1, 3, 1)
    gx = F.conv2d(images, kx, padding=(0, 1), groups=C)
    gy = F.conv2d(images, ky, padding=(1, 0), groups=C)

    # The small constant keeps the magnitude strictly positive.
    mag = torch.sqrt(gx * gx + gy * gy + 1e-6)
    ang = torch.atan2(gy, gx) % pi  # unsigned orientation in [0, pi)
    # The clamp catches an angle that rounds up to exactly ``pi``.
    bin_idx = torch.clamp(
        (ang * n_bins / pi).long(), max=n_bins - 1
    )  # [B, C, H, W]

    # Fold the spatial grid into patches:
    # [B, C, gh, p, gw, p] -> [B, C, gh, gw, p, p] -> [B, C, gh * gw, p * p].
    gh, gw = H // p, W // p
    n_patches = gh * gw

    def _to_patches(t: torch.Tensor) -> torch.Tensor:
        folded = t.view(B, C, gh, p, gw, p).permute(0, 1, 2, 4, 3, 5)
        return folded.contiguous().view(B, C, n_patches, p * p)

    mag_flat = _to_patches(mag)
    bin_flat = _to_patches(bin_idx)

    # Magnitude-weighted orientation histogram for every (channel, patch).
    hist = torch.zeros(
        B, C, n_patches, n_bins, device=images.device, dtype=mag_flat.dtype
    )
    hist.scatter_add_(-1, bin_flat, mag_flat)

    # Standardise every histogram over the bin dimension (zero mean, unit
    # variance up to the epsilon) so the regression target is well-scaled.
    mean = hist.mean(dim=-1, keepdim=True)
    var = hist.var(dim=-1, keepdim=True)
    hist = (hist - mean) / (var + 1e-6).sqrt()

    # [B, C, N, n_bins] -> [B, N, C * n_bins]
    return hist.permute(0, 2, 1, 3).reshape(B, n_patches, C * n_bins)


class MaskFeat(Module):
    """MaskFeat: predict per-patch HOG at masked positions.

    The encoder is expected to expose the timm ViT attributes used here:
    ``forward_features``, ``patch_embed``, ``cls_token``, ``pos_embed``,
    ``pos_drop``, ``blocks`` and ``norm``.

    :param encoder_name: timm ViT name (default ``"vit_small_patch16_224"``)
        or an already constructed encoder module.
    :param patch_size: Patch size (default 16, must match encoder).
    :param mask_ratio: Fraction of patches masked (default 0.4).
    :param n_hog_bins: HOG orientation bins (default 9).
    :param image_size: Input size (default 224).
    :param in_channels: Image channels (default 3).
    :param pretrained: Load pretrained timm weights.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        patch_size: int = 16,
        mask_ratio: float = 0.4,
        n_hog_bins: int = 9,
        image_size: int = 224,
        in_channels: int = 3,
        pretrained: bool = False,
    ):
        super().__init__()
        if isinstance(encoder_name, str):
            import timm

            self.encoder = timm.create_model(
                encoder_name, num_classes=0, pretrained=pretrained
            )
        else:
            self.encoder = encoder_name

        # Probe the encoder once to discover the token layout and width.
        # The dummy input follows the encoder's own device/dtype so that a
        # pre-built encoder already moved to GPU or cast to half precision
        # does not fail with a device/dtype mismatch.
        ref_param = next(self.encoder.parameters(), None)
        probe_kwargs = {}
        if ref_param is not None:
            probe_kwargs = {"device": ref_param.device, "dtype": ref_param.dtype}
        with torch.no_grad():
            seq = self.encoder.forward_features(
                torch.zeros(
                    1, in_channels, image_size, image_size, **probe_kwargs
                )
            )
        self._has_cls = (
            hasattr(self.encoder, "cls_token")
            and self.encoder.cls_token is not None
            and seq.shape[1] > 1
        )
        embed_dim = seq.shape[-1]
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.n_hog_bins = n_hog_bins
        self.image_size = image_size
        self.in_channels = in_channels
        # Learnable embedding substituted for masked patch tokens.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.mask_token, std=0.02)
        # Linear head regressing the per-patch HOG descriptor.
        self.head = nn.Linear(embed_dim, in_channels * n_hog_bins)

    def _random_mask(self, B, N, device):
        """Return a ``[B, N]`` float mask with ``1.0`` at masked positions.

        Each sample masks the same number of patches, chosen uniformly at
        random. The count is ``round(N * mask_ratio)`` clamped to ``[1, N]``:
        with zero masked patches the training loss would be the mean of an
        empty tensor, i.e. NaN.
        """
        n_mask = min(N, max(1, int(round(N * self.mask_ratio))))
        noise = torch.rand(B, N, device=device)
        order = noise.argsort(dim=1)
        mask = torch.zeros(B, N, device=device)
        mask.scatter_(1, order[:, :n_mask], 1.0)
        return mask

    def _encode(self, images, mask):
        """Encode ``images`` with masked patch tokens swapped for the mask token.

        :param images: Input batch passed to ``encoder.patch_embed``.
        :param mask: ``[B, N]`` mask, ``1.0`` where the patch is masked.
        :returns: Encoded patch tokens (the class token, if any, is dropped).
        """
        vit = self.encoder
        x = vit.patch_embed(images)
        m = mask.unsqueeze(-1)
        # Blend: keep visible tokens, substitute the mask token elsewhere.
        x = x * (1 - m) + self.mask_token.expand_as(x) * m
        if self._has_cls:
            cls = vit.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls, x], dim=1)
        # NOTE(review): assumes ``pos_embed`` has one entry per token in ``x``
        # (including the class token when present) - confirm for the chosen
        # encoder.
        x = x + vit.pos_embed
        x = vit.pos_drop(x)
        x = vit.blocks(x)
        x = vit.norm(x)
        if self._has_cls:
            x = x[:, 1:]
        return x

    def forward(self, images: torch.Tensor) -> MaskFeatOutput:
        """Run MaskFeat on a batch of images.

        In eval mode the encoder is applied without masking and only
        ``loss`` (a zero scalar) and ``embedding`` are populated. In training
        mode a random subset of patches is masked, HOG descriptors are
        predicted for every patch, and the MSE loss is taken over masked
        patches only.

        :param images: Input batch; HOG targets require shape
            ``[B, C, H, W]`` with ``H`` and ``W`` divisible by ``patch_size``.
        :returns: A :class:`MaskFeatOutput`.
        """
        B = images.shape[0]
        if not self.training:
            features = self.encoder.forward_features(images)
            # NOTE(review): eval returns the first (class) token when one was
            # detected, while training returns the mean of patch tokens -
            # confirm this asymmetry is intended for downstream probes.
            cls = features[:, 0] if self._has_cls else features.mean(dim=1)
            return MaskFeatOutput(
                loss=torch.zeros((), device=images.device, dtype=images.dtype),
                embedding=cls,
            )

        # Targets are fixed functions of the input; no gradient is needed.
        with torch.no_grad():
            target = _hog_per_patch(
                images, patch_size=self.patch_size, n_bins=self.n_hog_bins
            )
        N = target.shape[1]
        mask = self._random_mask(B, N, device=images.device)
        encoded = self._encode(images, mask)
        pred = self.head(encoded)

        # Only masked positions contribute to the loss.
        flat_mask = mask.bool().view(-1)
        flat_pred = pred.reshape(-1, pred.shape[-1])[flat_mask]
        flat_target = target.reshape(-1, target.shape[-1])[flat_mask]
        loss = F.mse_loss(flat_pred, flat_target)

        # Mean of all patch tokens (unmasked + masked) - using only unmasked
        # tokens biases the probe toward unmasked content, which damages it.
        embedding = encoded.mean(dim=1)
        return MaskFeatOutput(
            loss=loss,
            embedding=embedding.detach(),
            predictions=pred,
            targets=target,
            mask=mask,
        )