CMAE

class stable_pretraining.methods.CMAE(encoder_name: str | Module = 'vit_small_patch16_224', patch_size: int = 16, mask_ratio: float = 0.75, projector_dim: int = 256, contrast_weight: float = 1.0, ema_decay_start: float = 0.99, ema_decay_end: float = 1.0, image_size: int = 224, in_channels: int = 3, pretrained: bool = False)

Bases: torch.nn.Module

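CMAE (Contrastive Masked Autoencoder): combines an MAE pixel-reconstruction loss with a contrastive loss computed against an exponential-moving-average (EMA) target branch; the contrastive term is scaled by contrast_weight.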
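A minimal sketch of how the two terms could combine, assuming the total loss is the pixel loss plus contrast_weight times the contrastive loss, an InfoNCE contrastive term with a temperature of 0.07, and a pixel loss computed only on masked patches as in MAE. The helpers masked_pixel_mse, info_nce and cmae_style_loss are hypothetical plain-Python illustrations, not part of stable_pretraining; the library's actual contrastive loss, temperature and pixel targets are not stated here.

```python
import math
from typing import List, Sequence

Vector = List[float]


def masked_pixel_mse(pred: Sequence[Vector], target: Sequence[Vector],
                     mask: Sequence[bool]) -> float:
    # Average per-patch MSE, counting only patches hidden from the encoder (MAE-style).
    total, count = 0.0, 0
    for p, t, hidden in zip(pred, target, mask):
        if hidden:
            total += sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)
            count += 1
    return total / max(count, 1)


def _l2_normalize(v: Vector) -> Vector:
    # Unit-length copy of v (zero vectors are left as zeros).
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def info_nce(online: Sequence[Vector], target: Sequence[Vector],
             temperature: float = 0.07) -> float:
    # online[i] should match target[i]; the other targets in the batch act as negatives.
    # temperature=0.07 is an assumed value, not taken from the library.
    q = [_l2_normalize(v) for v in online]
    k = [_l2_normalize(v) for v in target]
    loss = 0.0
    for i, qi in enumerate(q):
        logits = [sum(a * b for a, b in zip(qi, kj)) / temperature for kj in k]
        peak = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
        loss += log_sum_exp - logits[i]  # cross-entropy with the positive at index i
    return loss / len(q)


def cmae_style_loss(pred, pixels, mask, online_proj, target_proj,
                    contrast_weight: float = 1.0) -> float:
    # Hypothetical combination: pixel term + contrast_weight * contrastive term.
    return (masked_pixel_mse(pred, pixels, mask)
            + contrast_weight * info_nce(online_proj, target_proj))
```

Setting contrast_weight to 0 reduces this sketch to a plain MAE-style objective.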

Parameters:
  • encoder_name – timm ViT model name, or an already-constructed encoder module (default "vit_small_patch16_224").

  • patch_size – Side length of each square patch, in pixels (default 16).

  • mask_ratio – Fraction of image patches that are masked (default 0.75, as in MAE).

  • projector_dim – Hidden and output dimension of the contrastive projector (default 256).

  • contrast_weight – Weight on the contrastive term (default 1.0).

  • ema_decay_start – Initial EMA decay rate (default 0.99).

  • ema_decay_end – Final EMA decay rate (default 1.0); a decay of 1.0 leaves the EMA target unchanged.

  • image_size – Input image height and width, in pixels (default 224).

  • in_channels – Number of input image channels (default 3).

  • pretrained – Whether to load pretrained timm weights for the encoder (default False).
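With the defaults, a 224-pixel image cut into 16-pixel patches gives (224 / 16)² = 14 × 14 = 196 patches, and mask_ratio=0.75 hides 147 of them, leaving 49 visible. The sketch below reproduces that arithmetic and an MAE-style uniform random mask. The functions num_patches and random_mask are hypothetical helpers, and it is an assumption that this class selects masked patches the same way.

```python
import random
from typing import List, Tuple


def num_patches(image_size: int = 224, patch_size: int = 16) -> int:
    # Number of non-overlapping square patches in a square image.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side * per_side


def random_mask(n: int, mask_ratio: float = 0.75,
                seed: int = 0) -> Tuple[List[int], List[int]]:
    # Shuffle patch indices, keep the first n * (1 - mask_ratio) visible, mask the rest.
    rng = random.Random(seed)
    order = list(range(n))
    rng.shuffle(order)
    n_keep = int(n * (1 - mask_ratio))
    visible = sorted(order[:n_keep])
    masked = sorted(order[n_keep:])
    return visible, masked


n = num_patches()                 # (224 // 16) ** 2 patches
visible, masked = random_mask(n)  # split into visible / masked index lists
```

Only the visible patches would be fed to the encoder in an MAE-style setup, which is where most of the compute saving comes from.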
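ema_decay_start and ema_decay_end indicate that the EMA decay moves from 0.99 toward 1.0 during training. A minimal sketch, assuming a BYOL-style cosine schedule over a hypothetical total_steps count (the schedule shape this class actually uses is not documented here), together with the standard element-wise EMA update. Parameters are plain float lists for illustration.

```python
import math
from typing import List


def ema_decay(step: int, total_steps: int,
              start: float = 0.99, end: float = 1.0) -> float:
    # Assumed cosine schedule: returns `start` at step 0 and `end` at total_steps.
    if total_steps <= 0:
        return end
    progress = min(step / total_steps, 1.0)
    return end - (end - start) * (math.cos(math.pi * progress) + 1) / 2


def ema_update(target: List[float], online: List[float],
               decay: float) -> List[float]:
    # target <- decay * target + (1 - decay) * online, element-wise.
    return [decay * t + (1.0 - decay) * o for t, o in zip(target, online)]


# Illustrative loop over hypothetical parameter values and step counts.
target_params = [0.0, 0.0]
online_params = [1.0, -1.0]
total_steps = 10
for step in range(total_steps + 1):
    target_params = ema_update(target_params, online_params,
                               ema_decay(step, total_steps))
```

At decay 1.0 the update returns the target unchanged, so under such a schedule the target branch is effectively frozen by the end of training.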

forward(view1: Tensor, view2: Tensor | None = None) → CMAEOutput

Compute the CMAE outputs for a batch of images.

Parameters:
  • view1 – First view: a batch of input images.

  • view2 – Optional second view of the same images (default None).

Returns:

The model output, as a CMAEOutput.