MAE


class stable_pretraining.methods.MAE(model_or_model_name: str | Module = 'vit_base_patch16_224', decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mask_ratio: float = 0.75, block_size: int = 1, norm_pix_loss: bool = True, loss_type: str = 'mse', pretrained: bool = False, masking: Module | None = None)

Bases: Module

MAE: Masked Autoencoders Are Scalable Vision Learners.

Architecture:
  • Encoder: ViT processing only visible (unmasked) patches

  • Decoder: Lightweight transformer reconstructing masked patches

  • Target: Normalized pixel values of masked patches

Parameters:
  • model_or_model_name – timm model name string or pre-instantiated nn.Module

  • decoder_embed_dim – Decoder hidden dimension (default: 512)

  • decoder_depth – Number of decoder blocks (default: 8)

  • decoder_num_heads – Decoder attention heads (default: 16)

  • mask_ratio – Fraction of patches to mask (default: 0.75)

  • block_size – Masking block size; 1 gives random per-patch masking (default: 1)

  • norm_pix_loss – Normalize target pixels per patch (default: True)

  • loss_type – Loss type for MAELoss (default: 'mse')

  • pretrained – Load pretrained encoder weights (default: False)

  • masking – Custom masking module (e.g., MultiBlockMasking). When provided, overrides mask_ratio and block_size.
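
To make mask_ratio concrete: a 224×224 input split into 16×16 patches yields 196 patches, and a ratio of 0.75 leaves 49 of them visible to the encoder. The sketch below follows the per-sample random masking described in the MAE paper (shuffle the patch indices, keep the first part). It uses only the standard library, the helper name random_patch_mask is hypothetical, and the library's own masking module may round or sample differently.

```python
import random


def random_patch_mask(num_patches, mask_ratio, rng):
    """Return (visible_ids, mask) for one image.

    mask[i] is 1 for a masked patch and 0 for a visible one.
    Hypothetical helper for illustration, not the library's implementation.
    """
    num_keep = int(num_patches * (1 - mask_ratio))
    ids = list(range(num_patches))
    rng.shuffle(ids)  # random permutation of patch indices
    visible_ids = sorted(ids[:num_keep])
    visible = set(visible_ids)
    mask = [0 if i in visible else 1 for i in range(num_patches)]
    return visible_ids, mask


# 224x224 image with 16x16 patches -> 14 * 14 = 196 patches
num_patches = (224 // 16) ** 2
visible_ids, mask = random_patch_mask(
    num_patches, mask_ratio=0.75, rng=random.Random(0)
)
print(len(visible_ids), sum(mask))  # 49 visible, 147 masked
```

Only the visible indices would be passed to the encoder; the mask is what later restricts the loss to the hidden patches.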

Example:

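# Basic usage
import torch
from stable_pretraining.methods import MAE

model = MAE("vit_base_patch16_224", mask_ratio=0.75)
images = torch.randn(4, 3, 224, 224)  # [B, C, H, W]

# Training mode: patches are masked and the loss covers masked patches
model.train()
output = model(images)
output.loss.backward()

# Eval mode: no masking, full reconstruction, zero loss
model.eval()
output = model(images)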

Example with Lightning:

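import torch
import lightning.pytorch as pl  # or: import pytorch_lightning as pl
from stable_pretraining.methods import MAE


class MAELightning(pl.LightningModule):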
    def __init__(self):
        super().__init__()
        self.model = MAE("vit_base_patch16_224")

    def training_step(self, batch, batch_idx):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        return self.model(images).loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1.5e-4)

forward(images: Tensor) → MAEOutput

Forward pass.

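  • Training: masks patches, encodes only the visible ones, decodes the full patch sequence, and computes the loss on masked patches only.

  • Eval: applies no masking, encodes and decodes all patches, and returns zero loss.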

Parameters:

images – Input images [B, C, H, W]

Returns:

MAEOutput with loss and reconstructions
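
The "loss on masked" behavior, combined with norm_pix_loss, can be sketched as follows. This is a minimal standard-library illustration under stated assumptions: each target patch is normalized by its own mean and sample variance (as in the MAE paper's reference code), and the squared error is averaged over masked patches only. The helper name masked_mse and the eps value are hypothetical; the library's MAELoss may differ in details.

```python
from statistics import mean, variance


def masked_mse(pred, target, mask, norm_pix=True, eps=1e-6):
    """Mean squared error averaged over masked patches only.

    pred, target: lists of patches, each a list of floats.
    mask: list with 1 for masked patches and 0 for visible ones.
    Hypothetical helper for illustration, not the library's MAELoss.
    """
    total, count = 0.0, 0
    for p, t, m in zip(pred, target, mask):
        if not m:
            continue  # visible patches do not contribute to the loss
        if norm_pix:
            mu, var = mean(t), variance(t)  # per-patch statistics
            t = [(x - mu) / (var + eps) ** 0.5 for x in t]
        total += sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)
        count += 1
    # With nothing masked there is nothing to score; return 0.0 (sketch's choice)
    return total / count if count else 0.0


target = [[0.0, 2.0], [1.0, 3.0]]
pred = [[1.0, 2.0], [9.0, 9.0]]
mask = [1, 0]  # only the first patch is masked

print(masked_mse(pred, target, mask, norm_pix=False))  # 0.5
```

The large error on the second patch is ignored because that patch is visible. Returning 0.0 for an all-zero mask is this sketch's choice, consistent with the zero loss reported in eval mode.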