"""MAE: Masked Autoencoders Are Scalable Vision Learners.
Self-supervised learning via reconstructing masked patches from visible patches.
References:
He et al. "Masked Autoencoders Are Scalable Vision Learners." CVPR 2022.
https://arxiv.org/abs/2111.06377
Example::
from stable_pretraining.backbone import MAE
import lightning as pl
# Create model
model = MAE("vit_base_patch16_224", mask_ratio=0.75)
# Training
model.train()
output = model(images)
output.loss.backward()
# Get encoder for downstream
encoder = model.encoder
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from typing import Union
from transformers.utils import ModelOutput
from stable_pretraining.backbone import MAEDecoder, MaskedEncoder, PatchMasking
from stable_pretraining.utils import MAELoss
from stable_pretraining import Module
@dataclass
class MAEOutput(ModelOutput):
    """Output from MAE forward pass.

    :ivar loss: Reconstruction loss (MSE on masked patches)
    :ivar predictions: Reconstructed patches [B, N, patch_dim]
    :ivar mask: Binary mask where 1=masked, 0=visible [B, N]
    :ivar num_masked: Number of masked patches
    :ivar num_visible: Number of visible patches
    """

    # All fields default to None (ModelOutput convention), so the
    # annotations are Optional to match the actual defaults.
    loss: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
    mask: Optional[torch.Tensor] = None
    num_masked: Optional[int] = None
    num_visible: Optional[int] = None
class MAE(Module):
    """MAE: Masked Autoencoders Are Scalable Vision Learners.

    Architecture:
        - **Encoder**: ViT processing only visible (unmasked) patches
        - **Decoder**: Lightweight transformer reconstructing masked patches
        - **Target**: Normalized pixel values of masked patches

    :param model_or_model_name: timm model name string or pre-instantiated nn.Module
    :param decoder_embed_dim: Decoder hidden dimension (default: 512)
    :param decoder_depth: Number of decoder blocks (default: 8)
    :param decoder_num_heads: Decoder attention heads (default: 16)
    :param mask_ratio: Fraction of patches to mask (default: 0.75)
    :param block_size: Masking block size, 1=random (default: 1)
    :param norm_pix_loss: Normalize target pixels per patch (default: True)
    :param loss_type: Loss type for MAELoss (default: 'mse')
    :param pretrained: Load pretrained encoder weights
    :param masking: Custom masking module (e.g., MultiBlockMasking).
        When provided, overrides mask_ratio and block_size.

    Example::

        # Basic usage
        model = MAE("vit_base_patch16_224", mask_ratio=0.75)
        images = torch.randn(4, 3, 224, 224)
        model.train()
        output = model(images)
        output.loss.backward()
        model.eval()
        output = model(images)  # Full reconstruction, zero loss

    Example with Lightning::

        class MAELightning(pl.LightningModule):
            def __init__(self):
                super().__init__()
                self.model = MAE("vit_base_patch16_224")

            def training_step(self, batch, batch_idx):
                images = batch[0] if isinstance(batch, (list, tuple)) else batch
                return self.model(images).loss

            def configure_optimizers(self):
                return torch.optim.AdamW(self.parameters(), lr=1.5e-4)
    """

    def __init__(
        self,
        model_or_model_name: Union[str, nn.Module] = "vit_base_patch16_224",
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        mask_ratio: float = 0.75,
        block_size: int = 1,
        norm_pix_loss: bool = True,
        loss_type: str = "mse",
        pretrained: bool = False,
        masking: Optional[nn.Module] = None,
    ):
        super().__init__()
        # Encoder with masking: a custom masking module takes precedence
        # over the (mask_ratio, block_size) convenience parameters.
        if masking is not None:
            self.masking = masking
        else:
            self.masking = PatchMasking(mask_ratio=mask_ratio, block_size=block_size)
        self.encoder = MaskedEncoder(
            model_or_model_name, masking=self.masking, pretrained=pretrained
        )
        # Derive decoder geometry from the encoder so the two always agree.
        embed_dim = self.encoder.embed_dim
        num_patches = self.encoder.default_grid_h * self.encoder.default_grid_w
        # NOTE(review): square patches are assumed here (patch_size_h used for
        # both dimensions) — confirm MaskedEncoder never yields non-square patches.
        patch_size = self.encoder.patch_size_h
        in_chans = self.encoder.patch_embed.proj.in_channels
        patch_dim = patch_size * patch_size * in_chans
        # Decoder: lightweight transformer that reconstructs pixel patches.
        self.decoder = MAEDecoder(
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            output_dim=patch_dim,
            num_patches=num_patches,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
        )
        # Loss: MSE (or loss_type) computed on masked patches only,
        # optionally against per-patch-normalized pixel targets.
        self.loss_fn = MAELoss(
            patch_size=patch_size,
            loss_type=loss_type,
            mask_only=True,
            patch_normalize=norm_pix_loss,
        )

    def forward(self, images: torch.Tensor) -> MAEOutput:
        """Forward pass.

        Training: masks patches, encodes visible, decodes all, loss on masked.
        Eval: no masking, full encode/decode, zero loss.

        :param images: Input images [B, C, H, W]
        :return: MAEOutput with loss and reconstructions
        """
        enc_out = self.encoder(images)
        # Strip prefix tokens (CLS etc.) before decoding; decoder works on
        # patch tokens only. output_masked_only=False gives full reconstruction.
        encoded_patches = enc_out.encoded[:, self.encoder.num_prefix_tokens :]
        predictions = self.decoder(
            encoded_patches,
            enc_out.mask,
            ids_keep=enc_out.ids_keep,
            output_masked_only=False,
        )
        if self.training:
            # Cast targets to the prediction dtype (e.g. under AMP).
            loss = self.loss_fn(predictions, images.to(predictions.dtype), enc_out.mask)
            # Count masked patches from the first sample — assumes the mask
            # count is identical across the batch (true for PatchMasking).
            num_masked = int(enc_out.mask.sum(dim=1)[0].item())
        else:
            # Eval applies no masking, so reconstruction loss is defined as 0.
            loss = torch.tensor(0.0, device=images.device)
            num_masked = 0
        return MAEOutput(
            loss=loss,
            predictions=predictions,
            mask=enc_out.mask,
            num_masked=num_masked,
            num_visible=enc_out.mask.shape[1] - num_masked,
        )