MAE
- class stable_pretraining.methods.MAE(model_or_model_name: str | Module = 'vit_base_patch16_224', decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mask_ratio: float = 0.75, block_size: int = 1, norm_pix_loss: bool = True, loss_type: str = 'mse', pretrained: bool = False, masking: Module | None = None)
Bases: Module
MAE: Masked Autoencoders Are Scalable Vision Learners.
- Architecture:
Encoder: ViT processing only visible (unmasked) patches
Decoder: Lightweight transformer reconstructing masked patches
Target: Normalized pixel values of masked patches
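To make the encoder's "visible patches only" input concrete, the following is a minimal sketch of per-patch random masking for a single image. It is written in NumPy for illustration and is not taken from the library; the helper name `random_masking` and the argsort-of-noise approach are assumptions, and the class's actual masking module may differ.

```python
import numpy as np


def random_masking(num_patches: int, mask_ratio: float, rng: np.random.Generator):
    """Hypothetical helper: choose which patches stay visible for one image.

    Returns (ids_keep, mask), where mask is 1 for masked patches and 0 for visible ones.
    """
    len_keep = int(num_patches * (1 - mask_ratio))
    noise = rng.random(num_patches)          # one random score per patch
    ids_shuffle = np.argsort(noise)          # a random permutation of patch indices
    ids_keep = np.sort(ids_shuffle[:len_keep])
    mask = np.ones(num_patches, dtype=np.int64)
    mask[ids_keep] = 0                       # visible patches are marked 0
    return ids_keep, mask


# A 224x224 image with 16x16 patches has 14 * 14 = 196 patches.
num_patches = (224 // 16) ** 2
ids_keep, mask = random_masking(num_patches, mask_ratio=0.75, rng=np.random.default_rng(0))
print(len(ids_keep), int(mask.sum()))  # 49 visible, 147 masked
```

With the default mask_ratio of 0.75, the encoder in this sketch would see only 49 of the 196 patch tokens, which is where most of the compute savings come from.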
- Parameters:
model_or_model_name – timm model name string or pre-instantiated nn.Module
decoder_embed_dim – Decoder hidden dimension (default: 512)
decoder_depth – Number of decoder blocks (default: 8)
decoder_num_heads – Decoder attention heads (default: 16)
mask_ratio – Fraction of patches to mask (default: 0.75)
block_size – Masking block size, 1=random (default: 1)
norm_pix_loss – Normalize target pixels per patch (default: True)
loss_type – Loss type for MAELoss (default: 'mse')
pretrained – Load pretrained encoder weights (default: False)
masking – Custom masking module (e.g., MultiBlockMasking). When provided, overrides mask_ratio and block_size.
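The effect of norm_pix_loss can be illustrated with a short NumPy sketch of a masked MSE objective. This is an assumption-laden illustration, not the library's MAELoss: the helpers `patchify` and `masked_mse`, the `eps` value, and the use of the biased per-patch variance are all hypothetical choices made here for clarity.

```python
import numpy as np


def patchify(images: np.ndarray, patch_size: int) -> np.ndarray:
    """Hypothetical helper: (N, C, H, W) -> (N, L, patch_size * patch_size * C)."""
    n, c, h, w = images.shape
    gh, gw = h // patch_size, w // patch_size
    x = images.reshape(n, c, gh, patch_size, gw, patch_size)
    x = x.transpose(0, 2, 4, 3, 5, 1)  # (N, gh, gw, p, p, C)
    return x.reshape(n, gh * gw, patch_size * patch_size * c)


def masked_mse(pred, target, mask, norm_pix_loss=True, eps=1e-6):
    """Mean squared error averaged over masked patches only (mask == 1)."""
    if norm_pix_loss:
        # Normalize each target patch by its own mean and variance.
        mean = target.mean(axis=-1, keepdims=True)
        var = target.var(axis=-1, keepdims=True)
        target = (target - mean) / np.sqrt(var + eps)
    per_patch = ((pred - target) ** 2).mean(axis=-1)  # shape (N, L)
    return (per_patch * mask).sum() / mask.sum()


rng = np.random.default_rng(0)
images = rng.normal(size=(2, 3, 32, 32))
target = patchify(images, patch_size=16)                   # shape (2, 4, 768)
mask = np.array([[1, 1, 0, 0], [1, 0, 1, 0]], dtype=float)
pred = np.zeros_like(target)                               # a trivial "prediction"
loss = masked_mse(pred, target, mask)
```

Because each normalized target patch has roughly unit variance, an all-zero prediction scores a loss close to 1 in this sketch, which gives a rough baseline against which a decoder's reconstruction loss can be read.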
Example:
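    import torch
    from stable_pretraining.methods import MAE

    # Basic usage
    model = MAE("vit_base_patch16_224", mask_ratio=0.75)
    images = torch.randn(4, 3, 224, 224)

    # Training mode: patches are masked and the reconstruction loss is returned
    model.train()
    output = model(images)
    output.loss.backward()

    # Evaluation mode
    model.eval()
    output = model(images)  # Full reconstruction, zero loss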
Example with Lightning:
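    import torch
    import lightning.pytorch as pl  # or: import pytorch_lightning as pl
    from stable_pretraining.methods import MAE

    class MAELightning(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = MAE("vit_base_patch16_224")

        def training_step(self, batch, batch_idx):
            # Accept either (images, labels) batches or a bare image tensor
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            return self.model(images).loss

        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters(), lr=1.5e-4)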