SwAV
- class stable_pretraining.methods.SwAV(encoder_name: str | Module = 'vit_small_patch16_224', projector_dims: Sequence[int] = (2048, 128), n_prototypes: int = 3000, temperature: float = 0.1, sinkhorn_iterations: int = 3, epsilon: float = 0.05, low_resolution: bool = False, pretrained: bool = False, dynamic_img_size: bool = True)
Bases: Module
SwAV (Swapping Assignments between multiple Views): prototype-based online clustering for self-supervised learning (SSL).
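As a rough illustration of what "prototype-based" means here (not the library's code), the sketch below scores each projector output against a bank of prototype vectors using cosine similarity, as described in the SwAV paper. NumPy stands in for PyTorch so the snippet runs on its own, and `prototype_scores` is a hypothetical helper; the sizes simply mirror the defaults `projector_dims=(2048, 128)` and `n_prototypes=3000`.

```python
import numpy as np

def prototype_scores(z: np.ndarray, prototypes: np.ndarray) -> np.ndarray:
    """Return a (B, K) matrix of cosine similarities.

    z: (B, D) projector outputs; prototypes: (K, D) prototype vectors.
    Hypothetical helper, not part of stable_pretraining.
    """
    # L2-normalize rows so every score is a cosine similarity in [-1, 1].
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    c = prototypes / np.linalg.norm(prototypes, axis=1, keepdims=True)
    return z @ c.T

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 128))              # batch of 8, projector output dim 128
prototypes = rng.normal(size=(3000, 128))  # n_prototypes = 3000
scores = prototype_scores(z, prototypes)
print(scores.shape)  # (8, 3000)
```

Each row of `scores` says how close one sample is to every prototype; these scores feed both the cluster assignment step and the swapped-prediction softmax.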
- Parameters:
encoder_name – timm model name or pre-built nn.Module.
projector_dims – (hidden, output) dimensions for the projector (default (2048, 128)).
n_prototypes – Number of prototypes (default 3000).
temperature – Temperature for the swapped-prediction softmax (default 0.1).
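sinkhorn_iterations – Number of Sinkhorn-Knopp iterations used to compute cluster assignments (default 3).
epsilon – Entropy-regularization coefficient for Sinkhorn; smaller values give sharper assignments (default 0.05).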
low_resolution – Adapt the first convolution for low-resolution input.
pretrained – Load pretrained timm weights.
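The `sinkhorn_iterations` and `epsilon` parameters control the Sinkhorn-Knopp step that turns prototype scores into soft cluster assignments ("codes") under an equal-partition constraint. The sketch below follows the procedure described in the SwAV paper, written in NumPy for self-containment; `sinkhorn` is a hypothetical helper name and this is not necessarily how stable_pretraining implements the step.

```python
import numpy as np

def sinkhorn(scores: np.ndarray, epsilon: float = 0.05, n_iters: int = 3) -> np.ndarray:
    """Turn (B, K) prototype scores into (B, K) soft codes (hypothetical helper)."""
    # Subtracting the global max only rescales Q by a constant, which the
    # normalizations below remove; it keeps exp() from overflowing.
    q = np.exp((scores - scores.max()) / epsilon).T  # (K, B)
    k, b = q.shape
    q = q / q.sum()                                  # total mass 1
    for _ in range(n_iters):
        q = q / q.sum(axis=1, keepdims=True) / k     # each prototype gets mass 1/K
        q = q / q.sum(axis=0, keepdims=True) / b     # each sample gets mass 1/B
    return (q * b).T                                 # rows sum to 1: one code per sample

rng = np.random.default_rng(0)
scores = rng.uniform(-1.0, 1.0, size=(8, 16))  # B=8 samples, K=16 prototypes
codes = sinkhorn(scores, epsilon=0.05, n_iters=3)
print(codes.shape)  # (8, 16)
```

Dividing the scores by a small `epsilon` sharpens the assignments, while the alternating row and column normalizations push the batch to spread evenly over prototypes. In SwAV the resulting codes are treated as fixed targets, so no gradient flows through this step.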
- forward(view1: Tensor | None = None, view2: Tensor | None = None, global_views: Sequence[Tensor] | None = None, local_views: Sequence[Tensor] | None = None, images: Tensor | None = None) → SwAVOutput
SwAV forward.
Three calling conventions:
- forward(view1, view2) – original 2-view training (no multi-crop).
- forward(global_views=[...], local_views=[...]) – full multi-crop training.
- forward(images=...) – evaluation / single-image embedding extraction.
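To make the difference between the 2-view and multi-crop conventions concrete, the sketch below computes a swapped-prediction loss as described in the SwAV paper: codes exist only for the global views, and every other view (global or local) predicts each global view's code through a temperature-scaled softmax over prototype scores. It is a NumPy illustration under those assumptions; `swapped_loss` and `log_softmax` are hypothetical helpers, the codes are assumed to come from some balanced assignment step, and the exact weighting inside stable_pretraining may differ.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    # Row-wise log-softmax, shifted by the row max for numerical stability.
    x = x - x.max(axis=1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))

def swapped_loss(scores, codes, temperature=0.1):
    """Multi-crop swapped-prediction loss (hypothetical helper).

    scores: list of (B, K) prototype scores, global views first, then local views.
    codes:  list of (B, K) target codes, one per global view, aligned with scores[:len(codes)].
    """
    total, n_terms = 0.0, 0
    for i, q in enumerate(codes):            # target: the code of global view i
        for v, s in enumerate(scores):
            if v == i:
                continue                     # a view never predicts its own code
            log_p = log_softmax(s / temperature)
            total += -np.mean(np.sum(q * log_p, axis=1))  # cross-entropy, batch mean
            n_terms += 1
    return total / n_terms                   # average over all (target, predictor) pairs

B, K = 8, 16
rng = np.random.default_rng(0)
# 2 global + 4 local views, as with forward(global_views=[...], local_views=[...]).
scores = [rng.uniform(-1, 1, size=(B, K)) for _ in range(6)]
codes = [np.full((B, K), 1.0 / K) for _ in range(2)]
loss = swapped_loss(scores, codes)

# Sanity check: uniform codes against all-zero scores give exactly log(K).
loss_uniform = swapped_loss([np.zeros((B, K))] * 2, [np.full((B, K), 1.0 / K)] * 2)
print(np.isclose(loss_uniform, np.log(K)))  # True
```

With two global views and no local views this reduces to the original 2-view swap: view 1 predicts the code of view 2 and vice versa. Local crops are cheap to add because they only act as predictors and never need their own codes.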