Source code for stable_pretraining.utils.mae_loss

import torch
from math import prod
import torch.nn as nn
import torch.nn.functional as F
from typing import Literal, Optional, Callable


[docs]
def patchify(x, patch_size):
    """Convert tensor to patches along the last len(patch_size) dimensions.

    Splits the last k spatial dimensions into non-overlapping patches and
    flattens them into a sequence of patch tokens. This is the standard
    patchification used in Vision Transformers (ViT), MAE, etc.

    :param x: Input tensor of shape (..., S_0, S_1, ..., S_{k-1}) where the
        last k dimensions are spatial and will be patchified. Leading
        dimensions are preserved (e.g., batch, channels).
    :param patch_size: Tuple/list of k patch sizes (p_0, p_1, ..., p_{k-1}).
        Each spatial dim S_i must be divisible by p_i.
    :return: Patches of shape (..., T, P) where:
        - T = prod(S_i // p_i) is the number of patches
        - P = prod(p_i) is the number of elements per patch

    Examples::

        >>> import torch

        # =================================================================
        # 2D Images: (N, C, H, W) -> (N, C, num_patches, patch_elements)
        # =================================================================
        >>> images = torch.randn(8, 3, 224, 224)
        >>> patches = patchify(images, patch_size=(16, 16))
        >>> patches.shape
        torch.Size([8, 3, 196, 256])  # 196 = 14*14 patches, 256 = 16*16 elements

        # Non-square patches
        >>> patches = patchify(images, patch_size=(14, 16))
        >>> patches.shape
        torch.Size([8, 3, 224, 224])  # 16*14=224 patches, 14*16=224 elements

        # =================================================================
        # 3D Volumes: (N, C, D, H, W) -> (N, C, num_patches, patch_elements)
        # =================================================================
        >>> volumes = torch.randn(4, 1, 64, 128, 128)
        >>> patches = patchify(volumes, patch_size=(8, 16, 16))
        >>> patches.shape
        torch.Size([4, 1, 512, 2048])  # 8*8*8=512 patches, 8*16*16=2048 elements

        # =================================================================
        # 1D Signals: (N, C, L) -> (N, C, num_patches, patch_elements)
        # =================================================================
        >>> signals = torch.randn(16, 2, 1024)
        >>> patches = patchify(signals, patch_size=(64,))
        >>> patches.shape
        torch.Size([16, 2, 16, 64])  # 16 patches of 64 elements each

        # =================================================================
        # Flexible batch dimensions
        # =================================================================
        # No batch dims: (H, W) -> (T, P)
        >>> image = torch.randn(224, 224)
        >>> patches = patchify(image, patch_size=(16, 16))
        >>> patches.shape
        torch.Size([196, 256])

        # Multiple batch dims: (B1, B2, C, H, W) -> (B1, B2, C, T, P)
        >>> x = torch.randn(2, 4, 3, 224, 224)
        >>> patches = patchify(x, patch_size=(16, 16))
        >>> patches.shape
        torch.Size([2, 4, 3, 196, 256])

        # =================================================================
        # Typical ViT usage (channels folded into patches)
        # =================================================================
        >>> images = torch.randn(8, 3, 224, 224)
        >>> # Move channels last, then patchify the (H, W, C) dims together
        >>> x = images.permute(0, 2, 3, 1)  # (8, 224, 224, 3)
        >>> patches = patchify(x, patch_size=(16, 16, 3))  # (8, 196, 768)
        >>> patches.shape  # 768 = 16 * 16 * 3
        torch.Size([8, 196, 768])

    See Also:
        :func:`unpatchify`: Inverse operation to reconstruct the original tensor.
    """
    patch_size = tuple(patch_size)
    k = len(patch_size)
    batch_shape = x.shape[:-k]
    spatial_shape = x.shape[-k:]

    # Validate divisibility
    for i, (s, p) in enumerate(zip(spatial_shape, patch_size)):
        if s % p != 0:
            raise ValueError(
                f"Spatial dim {i} (size {s}) must be divisible by patch_size[{i}]={p}"
            )

    # Compute grid size (number of patches per spatial dim)
    grid_size = tuple(s // p for s, p in zip(spatial_shape, patch_size))

    # (..., S_0, S_1, ...) -> (..., n_0, p_0, n_1, p_1, ...)
    interleaved = sum(zip(grid_size, patch_size), ())
    x = x.reshape(*batch_shape, *interleaved)

    # (..., n_0, p_0, n_1, p_1, ...) -> (..., n_0, n_1, ..., p_0, p_1, ...)
    b = len(batch_shape)
    perm = (*range(b), *range(b, b + 2 * k, 2), *range(b + 1, b + 2 * k, 2))
    x = x.permute(perm)

    # (..., n_0, n_1, ..., p_0, p_1, ...) -> (..., T, P)
    return x.reshape(*batch_shape, prod(grid_size), prod(patch_size))
[docs]
def unpatchify(patches, patch_size, grid_size=None):
    """Reconstruct tensor from patches (inverse of patchify).

    Reverses the patchification process, reconstructing the original spatial
    dimensions from a sequence of flattened patches.

    :param patches: Patch tensor of shape (..., T, P) where:
        - T is the number of patches
        - P is the number of elements per patch (must equal prod(patch_size))
    :param patch_size: Tuple/list of k patch sizes (p_0, p_1, ..., p_{k-1}).
    :param grid_size: Tuple/list of k grid sizes (n_0, n_1, ..., n_{k-1}) where
        n_i is the number of patches along spatial dimension i. If None,
        assumes a uniform grid (T must be a perfect k-th power).
    :return: Reconstructed tensor of shape (..., S_0, S_1, ..., S_{k-1})
        where S_i = n_i * p_i.

    Examples::

        >>> import torch

        # =================================================================
        # 2D Images: Roundtrip
        # =================================================================
        >>> images = torch.randn(8, 3, 224, 224)
        >>> patches = patchify(images, patch_size=(16, 16))
        >>> reconstructed = unpatchify(patches, patch_size=(16, 16))
        >>> torch.allclose(images, reconstructed)
        True

        # =================================================================
        # 3D Volumes: Roundtrip
        # =================================================================
        >>> volumes = torch.randn(4, 1, 64, 128, 128)
        >>> patches = patchify(volumes, patch_size=(8, 16, 16))
        >>> reconstructed = unpatchify(patches, patch_size=(8, 16, 16))
        >>> torch.allclose(volumes, reconstructed)
        True

        # =================================================================
        # 1D Signals: Roundtrip
        # =================================================================
        >>> signals = torch.randn(16, 2, 1024)
        >>> patches = patchify(signals, patch_size=(64,))
        >>> reconstructed = unpatchify(patches, patch_size=(64,))
        >>> torch.allclose(signals, reconstructed)
        True

        # =================================================================
        # Non-square grid (must specify grid_size)
        # =================================================================
        >>> images = torch.randn(8, 3, 224, 256)  # Non-square image
        >>> patches = patchify(images, patch_size=(16, 16))
        >>> patches.shape
        torch.Size([8, 3, 224, 256])  # 14*16=224 patches, 256 elements each
        >>> reconstructed = unpatchify(patches, patch_size=(16, 16), grid_size=(14, 16))
        >>> torch.allclose(images, reconstructed)
        True

        # =================================================================
        # MAE-style reconstruction (predict pixels from patch embeddings)
        # =================================================================
        >>> # Decoder outputs: (N, num_patches, patch_pixels)
        >>> predictions = torch.randn(8, 196, 768)  # 768 = 16*16*3
        >>> # Channels were folded into the patches, so undo with a 3D patch
        >>> images = unpatchify(predictions, patch_size=(16, 16, 3), grid_size=(14, 14, 1))
        >>> images.shape  # (N, H, W, C)
        torch.Size([8, 224, 224, 3])
        >>> images = images.permute(0, 3, 1, 2)  # (8, 3, 224, 224)

        # =================================================================
        # Explicit grid_size for non-uniform grids
        # =================================================================
        >>> patches = torch.randn(4, 168, 256)  # 168 = 12 * 14 patches
        >>> images = unpatchify(patches, patch_size=(16, 16), grid_size=(12, 14))
        >>> images.shape
        torch.Size([4, 192, 224])  # 12*16=192, 14*16=224

        # =================================================================
        # Error case: Cannot infer non-uniform grid
        # =================================================================
        >>> patches = torch.randn(4, 168, 256)  # 168 is not a perfect square
        >>> unpatchify(patches, patch_size=(16, 16))  # Raises ValueError
        ValueError: Cannot infer grid: T=168 is not a perfect 2-th power

    See Also:
        :func:`patchify`: Forward operation to convert tensors to patches.
    """
    patch_size = tuple(patch_size)
    k = len(patch_size)
    batch_shape = patches.shape[:-2]
    T, patch_elements = patches.shape[-2:]

    if patch_elements != prod(patch_size):
        raise ValueError(
            f"patches last dim {patch_elements} != prod(patch_size)={prod(patch_size)}"
        )

    # Infer or validate grid_size
    if grid_size is None:
        n = round(T ** (1.0 / k))
        if n**k != T:
            raise ValueError(
                f"Cannot infer grid: T={T} is not a perfect {k}-th power. "
                f"Please provide grid_size explicitly."
            )
        grid_size = (n,) * k
    else:
        grid_size = tuple(grid_size)
        if len(grid_size) != k:
            raise ValueError(
                f"grid_size has {len(grid_size)} dims but patch_size has {k} dims"
            )
        if prod(grid_size) != T:
            raise ValueError(f"prod(grid_size)={prod(grid_size)} != num_patches T={T}")

    # (..., T, P) -> (..., n_0, n_1, ..., p_0, p_1, ...)
    x = patches.reshape(*batch_shape, *grid_size, *patch_size)

    # (..., n_0, n_1, ..., p_0, p_1, ...) -> (..., n_0, p_0, n_1, p_1, ...)
    b = len(batch_shape)
    perm = (*range(b), *sum(zip(range(b, b + k), range(b + k, b + 2 * k)), ()))
    x = x.permute(perm)

    # (..., n_0, p_0, n_1, p_1, ...) -> (..., S_0, S_1, ...)
    spatial_shape = tuple(n * p for n, p in zip(grid_size, patch_size))
    return x.reshape(*batch_shape, *spatial_shape)
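
# Illustrative sketch (not part of the module): combining the channels-folded
# patchify example with unpatchify gives an exact roundtrip; names below are
# local to the sketch:
#
#     import torch
#     imgs = torch.randn(8, 3, 224, 224)
#     tokens = patchify(imgs.permute(0, 2, 3, 1), patch_size=(16, 16, 3))       # (8, 196, 768)
#     back = unpatchify(tokens, patch_size=(16, 16, 3), grid_size=(14, 14, 1))  # (8, 224, 224, 3)
#     assert torch.allclose(imgs, back.permute(0, 3, 1, 2))
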
[docs]
class MAELoss(nn.Module):
    """Modular MAE reconstruction loss with configurable loss functions.

    Supports MSE, cosine similarity, smooth L1, and custom loss functions, with
    optional per-patch normalization of the reconstruction targets.

    :param patch_size: Size of each square patch (default: 16)
    :param loss_type: Loss function type - 'mse', 'cosine', 'smooth_l1', or 'custom' (default: 'mse')
    :param mask_only: If True, compute loss only on masked patches (default: True)
    :param patch_normalize: If True, normalize each target patch to zero mean/unit var (default: True)
    :param reduction: How to reduce patch losses - 'mean' or 'sum' (default: 'mean')

    Examples::

        >>> loss_fn = MAELoss(patch_size=16, loss_type='mse')
        >>> loss = loss_fn(pred, imgs, mask)

        >>> # Cosine similarity loss
        >>> loss_fn = MAELoss(patch_size=16, loss_type='cosine')
        >>> loss = loss_fn(pred, imgs, mask)

        >>> # Custom loss function
        >>> loss_fn = MAELoss(patch_size=16, loss_type='custom')
        >>> loss_fn.register_custom_loss(lambda p, t: (p - t).abs().mean(dim=-1))
        >>> loss = loss_fn(pred, imgs, mask)
    """

    LOSS_TYPES = Literal["mse", "cosine", "smooth_l1", "custom"]

    def __init__(
        self,
        patch_size: int = 16,
        loss_type: LOSS_TYPES = "mse",
        mask_only: bool = True,
        patch_normalize: bool = True,
        reduction: Literal["mean", "sum"] = "mean",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.loss_type = loss_type
        self.mask_only = mask_only
        self.patch_normalize = patch_normalize
        self.reduction = reduction
        self._custom_loss_fn: Optional[Callable] = None
[docs]
    def register_custom_loss(
        self, fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    ):
        """Register a custom loss function.

        :param fn: Callable taking (pred, target) both of shape (N, T, P) and
            returning per-patch losses of shape (N, T).
        """
        self._custom_loss_fn = fn
    def _compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute per-patch loss based on loss_type.

        :param pred: Predictions, shape (N, T, P)
        :param target: Targets, shape (N, T, P)
        :return: Per-patch losses, shape (N, T)
        """
        if self.loss_type == "mse":
            return (pred - target).pow(2).mean(dim=-1)
        elif self.loss_type == "cosine":
            # Cosine similarity: 1 = identical, -1 = opposite
            # Loss: 1 - similarity (so 0 = perfect, 2 = worst)
            similarity = F.cosine_similarity(pred, target, dim=-1)
            return 1 - similarity
        elif self.loss_type == "smooth_l1":
            # Huber loss, less sensitive to outliers than MSE
            return F.smooth_l1_loss(pred, target, reduction="none").mean(dim=-1)
        elif self.loss_type == "custom":
            if self._custom_loss_fn is None:
                raise ValueError(
                    "loss_type='custom' but no custom loss registered. "
                    "Call register_custom_loss() first."
                )
            return self._custom_loss_fn(pred, target)
        else:
            raise ValueError(f"Unknown loss_type: {self.loss_type}")

    def _validate_inputs(
        self, pred: torch.Tensor, imgs: torch.Tensor, mask: torch.Tensor
    ):
        """Validate input tensors for correctness."""
        p = self.patch_size

        # NaN/Inf checks
        assert not torch.isnan(imgs).any(), "imgs contains NaN values"
        assert not torch.isinf(imgs).any(), "imgs contains Inf values"
        assert not torch.isnan(pred).any(), "pred contains NaN values"
        assert not torch.isinf(pred).any(), "pred contains Inf values"

        # Shape checks
        assert imgs.ndim == 4, f"imgs must be 4D (N, C, H, W), got {imgs.shape}"
        N, C, H, W = imgs.shape
        assert H % p == 0, f"Height {H} must be divisible by patch_size {p}"
        assert W % p == 0, f"Width {W} must be divisible by patch_size {p}"

        T_expected = (H // p) * (W // p)
        pixels_per_patch = p * p * C

        assert pred.ndim == 3, f"pred must be 3D (N, T, D), got {pred.shape}"
        assert pred.shape == (N, T_expected, pixels_per_patch), (
            f"pred shape {pred.shape} != expected ({N}, {T_expected}, {pixels_per_patch})"
        )

        assert mask.ndim == 2, f"mask must be 2D (N, T), got {mask.shape}"
        assert mask.shape == (N, T_expected), (
            f"mask shape {mask.shape} != expected ({N}, {T_expected})"
        )

        if self.mask_only:
            assert mask.sum() > 0, "mask has no masked patches"

        # Device/dtype consistency
        assert pred.device == imgs.device and mask.device == imgs.device
        assert pred.dtype == imgs.dtype
[docs]
    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Convert images to patches.

        :param imgs: Images of shape (N, C, H, W)
        :return: Patches of shape (N, T, P) where T = num_patches, P = pixels_per_patch
        """
        p = self.patch_size
        N, C, H, W = imgs.shape

        # (N, C, H, W) -> (N, C, H//p, W//p, p, p)
        x = imgs.unfold(2, p, p).unfold(3, p, p)

        # (N, C, nH, nW, p, p) -> (N, nH, nW, p, p, C) -> (N, T, P)
        x = x.permute(0, 2, 3, 4, 5, 1).reshape(N, -1, p * p * C)
        return x
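
    # Illustrative sketch (not part of the class): this method lays out each
    # patch as (p_h, p_w, C), matching the module-level patchify applied to a
    # channels-last view, so decoder targets and predictions share one pixel
    # ordering. Hypothetical check:
    #
    #     imgs = torch.randn(2, 3, 32, 32)
    #     loss_fn = MAELoss(patch_size=16)
    #     a = loss_fn.patchify(imgs)                                      # (2, 4, 768)
    #     b = patchify(imgs.permute(0, 2, 3, 1), patch_size=(16, 16, 3))  # same layout
    #     assert torch.equal(a, b)
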
[docs]
    def forward(
        self,
        pred: torch.Tensor,
        imgs: torch.Tensor,
        mask: torch.Tensor,
        debug: bool = False,
    ) -> torch.Tensor:
        """Compute MAE reconstruction loss.

        :param pred: Decoder predictions, shape (N, T, patch_size² × C)
        :param imgs: Original images, shape (N, C, H, W)
        :param mask: Binary mask, shape (N, T), 1 = masked (compute loss)
        :param debug: If True, print debug statistics
        :return: Scalar loss value
        """
        self._validate_inputs(pred, imgs, mask)

        # Patchify target images
        target = self.patchify(imgs)

        if debug:
            self._print_debug(pred, target, mask)

        # Per-patch normalization (optional)
        if self.patch_normalize:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1e-6).sqrt()

        # Compute per-patch loss
        loss = self._compute_loss(pred, target)  # (N, T)

        # Apply mask and reduce
        if self.mask_only:
            if self.reduction == "mean":
                loss = (loss * mask).sum() / mask.sum()
            else:
                loss = (loss * mask).sum()
        else:
            if self.reduction == "mean":
                loss = loss.mean()
            else:
                loss = loss.sum()

        assert not torch.isnan(loss), "Loss is NaN"
        return loss
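
    # Illustrative usage sketch (not part of the class); shapes follow the
    # docstring above and values are random placeholders:
    #
    #     loss_fn = MAELoss(patch_size=16, loss_type="mse", mask_only=True)
    #     imgs = torch.randn(8, 3, 224, 224)          # (N, C, H, W)
    #     pred = torch.randn(8, 196, 16 * 16 * 3)     # (N, T, p*p*C), e.g. decoder output
    #     mask = (torch.rand(8, 196) < 0.75).float()  # 1 = masked patch, ~75% masking
    #     loss = loss_fn(pred, imgs, mask)            # scalar tensor
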
    def _print_debug(
        self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
    ):
        """Print debug statistics."""
        print("=" * 60)
        print(f"MAE Loss Debug | loss_type={self.loss_type}")
        print("=" * 60)
        print(
            f"pred: shape={tuple(pred.shape)}, "
            f"min={pred.min():.4f}, max={pred.max():.4f}, "
            f"mean={pred.mean():.4f}, std={pred.std():.4f}"
        )
        print(
            f"target: shape={tuple(target.shape)}, "
            f"min={target.min():.4f}, max={target.max():.4f}, "
            f"mean={target.mean():.4f}, std={target.std():.4f}"
        )
        print(
            f"mask: {mask.sum().item()}/{mask.numel()} patches masked "
            f"({mask.float().mean().item() * 100:.1f}%)"
        )
        print("=" * 60)