MSN


class stable_pretraining.methods.MSN(encoder_name: str | Module = 'vit_small_patch16_224', projector_hidden_dim: int = 2048, projector_bottleneck_dim: int = 256, n_prototypes: int = 1024, mask_ratio: float = 0.6, temperature_student: float = 0.1, temperature_teacher: float = 0.025, me_max_weight: float = 1.0, ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, image_size: int = 224, pretrained: bool = False)

Bases: Module

MSN (Masked Siamese Networks): a DINO-style self-supervised learning method in which the student sees patch-masked views.

Parameters:
  • encoder_name – timm ViT name (default "vit_small_patch16_224").

  • projector_hidden_dim – Hidden dimension of the projector (default 2048).

  • projector_bottleneck_dim – Bottleneck dimension of the projector (default 256).

  • n_prototypes – Number of prototypes (default 1024).

  • mask_ratio – Patch mask ratio for the student (default 0.6).

  • temperature_student – Student softmax temperature (default 0.1).

  • temperature_teacher – Teacher softmax temperature (default 0.025).

  • me_max_weight – Mean-entropy maximisation weight (default 1.0).

  • ema_decay_start – Initial EMA decay for the backbone and head (default 0.996).

  • ema_decay_end – Final EMA decay (default 1.0).

  • image_size – Input image size (default 224).

  • pretrained – Whether to load pretrained timm weights (default False).
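
The following is a minimal, standard-library sketch of how the defaults above could interact; it is not the library's implementation. Everything in it is an assumption made for illustration: the helper names (num_visible_patches, ema_decay, msn_style_loss) are hypothetical, the patch size of 16 is read off the default encoder name, the EMA decay is interpolated linearly between ema_decay_start and ema_decay_end, and the loss is taken to be a teacher-to-student cross-entropy minus me_max_weight times the entropy of the mean student prediction.

```python
import math


def num_visible_patches(image_size=224, patch_size=16, mask_ratio=0.6):
    """Patches left for the student after masking (rounding rule is an assumption)."""
    total = (image_size // patch_size) ** 2  # 14 x 14 = 196 for the defaults
    return total - int(total * mask_ratio)


def ema_decay(step, total_steps, start=0.996, end=1.0):
    """Linearly interpolated EMA decay (assumed schedule, not confirmed by the docs)."""
    frac = min(max(step / max(total_steps, 1), 0.0), 1.0)
    return start + frac * (end - start)


def softmax(logits, temperature):
    """Numerically stable softmax over temperature-scaled logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    norm = sum(exps)
    return [e / norm for e in exps]


def msn_style_loss(student_logits, teacher_logits,
                   temperature_student=0.1, temperature_teacher=0.025,
                   me_max_weight=1.0):
    """Cross-entropy to teacher targets minus weighted mean-entropy term.

    Both arguments are lists of per-sample prototype scores (hypothetical layout).
    """
    eps = 1e-12
    n = len(student_logits)
    p_student = [softmax(row, temperature_student) for row in student_logits]
    p_teacher = [softmax(row, temperature_teacher) for row in teacher_logits]

    # Cross-entropy between sharpened teacher targets and student predictions.
    cross_entropy = -sum(
        t * math.log(s + eps)
        for ps, pt in zip(p_student, p_teacher)
        for s, t in zip(ps, pt)
    ) / n

    # Entropy of the mean student prediction; maximising it spreads
    # assignments across prototypes.
    k = len(p_student[0])
    mean_p = [sum(row[j] for row in p_student) / n for j in range(k)]
    mean_entropy = -sum(p * math.log(p + eps) for p in mean_p)

    return cross_entropy - me_max_weight * mean_entropy


if __name__ == "__main__":
    print(num_visible_patches())
    print(ema_decay(0, 1000), ema_decay(1000, 1000))
    scores = [[1.0, 0.0], [0.0, 1.0]]
    print(msn_style_loss(scores, scores))
```

With a lower teacher temperature than student temperature, the teacher distribution is sharper than the student's, which is the usual motivation for defaults such as 0.025 versus 0.1.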

forward(view1: Tensor | None = None, view2: Tensor | None = None, images: Tensor | None = None) → MSNOutput

Forward pass; overrides torch.nn.Module.forward(). All three inputs are optional and default to None.

Parameters:
  • view1 – Optional tensor (default None).

  • view2 – Optional tensor (default None).

  • images – Optional tensor (default None).

Returns:

An MSNOutput instance.