VICReg

class stable_pretraining.methods.VICReg(encoder_name: str | Module = 'vit_small_patch16_224', projector_dims: Sequence[int] = (8192, 8192, 8192), sim_coeff: float = 25.0, std_coeff: float = 25.0, cov_coeff: float = 1.0, low_resolution: bool = False, pretrained: bool = False)

Bases: Module

VICReg: Variance-Invariance-Covariance Regularization for self-supervised learning.
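For reference, the objective from the VICReg paper combines three terms computed on the projector outputs Z and Z' (n samples, d dimensions) of two augmented views. The paper's defaults λ = μ = 25 and ν = 1 match the sim_coeff, std_coeff and cov_coeff defaults here, so presumably sim_coeff ↔ λ, std_coeff ↔ μ and cov_coeff ↔ ν; that mapping is an assumption, and the exact normalization used by this class is not documented on this page.

```latex
\mathcal{L}(Z, Z') = \lambda\, s(Z, Z') + \mu\,\big[v(Z) + v(Z')\big] + \nu\,\big[c(Z) + c(Z')\big]

s(Z, Z') = \frac{1}{n}\sum_{i=1}^{n} \lVert z_i - z'_i \rVert_2^2
\qquad
v(Z) = \frac{1}{d}\sum_{j=1}^{d} \max\!\Big(0,\ \gamma - \sqrt{\operatorname{Var}(z^{j}) + \epsilon}\Big)

c(Z) = \frac{1}{d}\sum_{i \neq j} \big[C(Z)\big]_{i,j}^{2},
\qquad
C(Z) = \frac{1}{n-1}\sum_{i=1}^{n} (z_i - \bar{z})(z_i - \bar{z})^{\top},
\qquad \gamma = 1
```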

Parameters:
  • encoder_name – timm model name or pre-built nn.Module.

  • projector_dims – Hidden and output dimensions of the projector. The default (8192, 8192, 8192) matches the ResNet-50 recipe from the VICReg paper.

  • sim_coeff – Invariance term weight (default 25.0).

  • std_coeff – Variance term weight (default 25.0).

  • cov_coeff – Covariance term weight (default 1.0).

  • low_resolution – If True, adapt the first convolution for 32×32 inputs (CIFAR-style).

  • pretrained – If True, load pretrained timm weights for the encoder.
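A minimal PyTorch sketch of how the three weights combine into a single loss, assuming the projector outputs for the two views are batches z1 and z2 of shape (n, d). The function vicreg_loss is hypothetical and not part of stable_pretraining; its normalization (element-wise mean squared error for the invariance term, per-dimension averages for the variance and covariance terms) follows the original VICReg reference code, and this library's own loss may differ in details.

```python
import torch
import torch.nn.functional as F


def vicreg_loss(
    z1: torch.Tensor,
    z2: torch.Tensor,
    sim_coeff: float = 25.0,
    std_coeff: float = 25.0,
    cov_coeff: float = 1.0,
    eps: float = 1e-4,
) -> torch.Tensor:
    """Hypothetical VICReg loss on two (n, d) batches of projector outputs (n > 1)."""
    n, d = z1.shape

    # Invariance: mean squared error between the embeddings of the two views.
    sim_loss = F.mse_loss(z1, z2)

    # Center every embedding dimension over the batch.
    z1 = z1 - z1.mean(dim=0)
    z2 = z2 - z2.mean(dim=0)

    # Variance: hinge that pushes each dimension's std up to at least 1.
    std1 = torch.sqrt(z1.var(dim=0) + eps)
    std2 = torch.sqrt(z2.var(dim=0) + eps)
    std_loss = F.relu(1.0 - std1).mean() / 2 + F.relu(1.0 - std2).mean() / 2

    # Covariance: penalize squared off-diagonal entries of each covariance matrix.
    cov1 = (z1.T @ z1) / (n - 1)
    cov2 = (z2.T @ z2) / (n - 1)
    off_diag = ~torch.eye(d, dtype=torch.bool, device=z1.device)
    cov_loss = cov1[off_diag].pow(2).sum() / d + cov2[off_diag].pow(2).sum() / d

    return sim_coeff * sim_loss + std_coeff * std_loss + cov_coeff * cov_loss
```

Because the variance hinge requires each dimension to keep a standard deviation near 1, constant (collapsed) embeddings are penalized even when the two views agree exactly; with the default weights, an all-zero batch yields roughly 25 × 0.99 = 24.75.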

forward(view1: Tensor, view2: Tensor | None = None) → VICRegOutput

Run the VICReg forward pass on one or two augmented views.

Parameters:
  • view1 – First augmented view of the input batch.

  • view2 – Second augmented view of the same batch. Optional; defaults to None.

Returns:

A VICRegOutput.