SwAV

class stable_pretraining.methods.SwAV(encoder_name: str | Module = 'vit_small_patch16_224', projector_dims: Sequence[int] = (2048, 128), n_prototypes: int = 3000, temperature: float = 0.1, sinkhorn_iterations: int = 3, epsilon: float = 0.05, low_resolution: bool = False, pretrained: bool = False, dynamic_img_size: bool = True)

Bases: Module

SwAV: prototype-based online clustering for self-supervised learning (SSL).
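In the SwAV paper's design, each projected embedding is L2-normalized and scored against a set of learnable, unit-norm prototype vectors; those scores are what the online clustering step operates on. The sketch below illustrates that head using the defaults from the signature (projector_dims=(2048, 128), n_prototypes=3000). It assumes a Linear/BatchNorm/ReLU/Linear projector and a 384-dimensional input (the embedding width of vit_small_patch16_224). PrototypeHead is a hypothetical name, not part of stable_pretraining, and the library's own head may be built differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrototypeHead(nn.Module):
    """Hypothetical projector + prototype layer, following the SwAV paper."""

    def __init__(self, in_dim=384, projector_dims=(2048, 128), n_prototypes=3000):
        super().__init__()
        hidden, out = projector_dims
        # Assumed two-layer projector: Linear -> BatchNorm -> ReLU -> Linear.
        self.projector = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out),
        )
        # Each row of this (n_prototypes x out) weight matrix is one prototype.
        self.prototypes = nn.Linear(out, n_prototypes, bias=False)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Project and L2-normalize the encoder features.
        z = F.normalize(self.projector(features), dim=1)
        # Put the prototypes back on the unit sphere (no gradient through this).
        with torch.no_grad():
            self.prototypes.weight.copy_(F.normalize(self.prototypes.weight, dim=1))
        # Dot products of unit vectors are cosine similarities: [B, n_prototypes].
        return self.prototypes(z)


torch.manual_seed(0)
head = PrototypeHead()
scores = head(torch.randn(4, 384))  # 4 stand-in encoder feature vectors
print(scores.shape)  # torch.Size([4, 3000])
```

Because both sides are unit-norm, every score lies in [-1, 1], which keeps the later temperature and epsilon scalings on a predictable range.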

Parameters:
  • encoder_name – timm model name or a pre-built nn.Module (default 'vit_small_patch16_224').

  • projector_dims – (hidden, output) sizes of the projector (default (2048, 128)).

  • n_prototypes – Number of prototypes (default 3000).

  • temperature – Temperature for the swapped-prediction softmax (default 0.1).

  • sinkhorn_iterations – Number of Sinkhorn-Knopp iterations used to compute the soft cluster assignments (default 3).

  • epsilon – Entropy regularization coefficient of the Sinkhorn-Knopp step (default 0.05).

  • low_resolution – Adapt the first convolution for low-resolution input (default False).

  • pretrained – Load pretrained timm weights (default False).
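The epsilon and sinkhorn_iterations parameters control the Sinkhorn-Knopp step that turns prototype scores into soft cluster assignments ("codes"), with the assignment mass spread evenly across prototypes within a batch. The following is a single-process sketch modeled on the algorithm described in the SwAV paper. The function name sinkhorn_knopp is hypothetical, and the library's own implementation (for example, how it aggregates across GPUs) may differ.

```python
import torch


@torch.no_grad()
def sinkhorn_knopp(scores: torch.Tensor, epsilon: float = 0.05, n_iters: int = 3) -> torch.Tensor:
    """Turn prototype scores [B, K] into soft assignments [B, K] (sketch)."""
    # Q is K x B. Subtracting the global max keeps exp() from overflowing and
    # does not change the result, because Q is renormalized right after.
    q = torch.exp((scores - scores.max()) / epsilon).t()
    k, b = q.shape
    q /= q.sum()  # make the whole matrix sum to 1
    for _ in range(n_iters):
        # Rows: each prototype receives total mass 1/K (equal partition).
        q /= q.sum(dim=1, keepdim=True)
        q /= k
        # Columns: each sample carries total mass 1/B.
        q /= q.sum(dim=0, keepdim=True)
        q /= b
    q *= b  # rescale so each sample's assignment sums to 1
    return q.t()


torch.manual_seed(0)
scores = torch.rand(8, 16) * 2 - 1  # stand-in cosine similarities in [-1, 1]
codes = sinkhorn_knopp(scores)
print(codes.shape)  # torch.Size([8, 16])
```

A smaller epsilon makes the codes sharper (closer to hard assignments), while more iterations enforce the equal-partition constraint more exactly; the codes are computed without gradients and serve as targets.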

forward(view1: Tensor | None = None, view2: Tensor | None = None, global_views: Sequence[Tensor] | None = None, local_views: Sequence[Tensor] | None = None, images: Tensor | None = None) → SwAVOutput

Run the SwAV forward pass.

Three calling conventions are supported:

  • forward(view1, view2): the original two-view setting (no multi-crop).

  • forward(global_views=[...], local_views=[...]): full multi-crop.

  • forward(images=...): evaluation, i.e. single-view embedding extraction.
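In the SwAV paper's multi-crop scheme, codes are computed only for the high-resolution (global) views, and each code is predicted from every other view, global or local, through a temperature-scaled softmax over the prototype scores. The two-view convention is the special case with two global views and no local views. Below is a hedged sketch of that loss. swapped_prediction_loss is a hypothetical helper, and the codes are assumed to be precomputed (for example by a Sinkhorn-Knopp step) and detached; it is not a description of what SwAV.forward returns.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def swapped_prediction_loss(
    scores: Sequence[torch.Tensor],
    codes: Sequence[torch.Tensor],
    temperature: float = 0.1,
) -> torch.Tensor:
    """Multi-crop swapped-prediction loss (sketch after the SwAV paper).

    scores: prototype scores [B, K] for every view, global views first.
    codes:  target assignments [B, K] for the global views only, detached.
    """
    total = scores[0].new_zeros(())
    n_terms = 0
    for i, q in enumerate(codes):
        for v, s in enumerate(scores):
            if v == i:
                continue  # a view is never asked to predict its own code
            log_p = F.log_softmax(s / temperature, dim=1)
            # Cross-entropy between the code of view i and the prediction from view v.
            total = total - (q * log_p).sum(dim=1).mean()
            n_terms += 1
    return total / n_terms


# Two global + four local views. Zero scores give a uniform prediction, so
# the loss reduces to log(K) for any codes whose rows sum to 1.
B, K = 4, 16
scores = [torch.zeros(B, K) for _ in range(6)]
codes = [torch.softmax(torch.randn(B, K), dim=1) for _ in range(2)]
loss = swapped_prediction_loss(scores, codes)
print(round(loss.item(), 4))  # 2.7726
```

Computing codes only on global views keeps the Sinkhorn cost independent of the number of local crops, while every crop still contributes a prediction term.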