VICRegL


class stable_pretraining.methods.VICRegL(encoder_name: str | Module = 'vit_small_patch16_224', projector_dim: int = 2048, sim_coeff: float = 25.0, std_coeff: float = 25.0, cov_coeff: float = 1.0, alpha: float = 0.75, image_size: int = 224, pretrained: bool = False)

Bases: Module

VICRegL: VICReg (Variance-Invariance-Covariance Regularization) extended with a loss term on local features in addition to the usual global term.
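For reference, the VICReg objective from the original VICReg paper combines an invariance term s, a variance term v and a covariance term c. Assuming this class follows the paper, sim_coeff, std_coeff and cov_coeff play the roles of λ, μ and ν, and alpha mixes the global and local losses as described for the alpha parameter. Here n is the batch size, d the embedding dimension, and γ = 1 and ε = 10⁻⁴ are the paper's defaults, not parameters of this class.

```latex
s(Z, Z') = \frac{1}{n} \sum_{i=1}^{n} \lVert z_i - z'_i \rVert_2^2
v(Z) = \frac{1}{d} \sum_{j=1}^{d} \max\!\left(0,\; \gamma - \sqrt{\operatorname{Var}(z^{j}) + \epsilon}\right)
c(Z) = \frac{1}{d} \sum_{i \neq j} [C(Z)]_{i,j}^{2}, \qquad
C(Z) = \frac{1}{n-1} \sum_{i=1}^{n} (z_i - \bar{z})(z_i - \bar{z})^{\top}
\ell(Z, Z') = \lambda\, s(Z, Z') + \mu \left[ v(Z) + v(Z') \right] + \nu \left[ c(Z) + c(Z') \right]
\mathcal{L}_{\mathrm{VICRegL}} = \alpha\, \ell_{\mathrm{global}} + (1 - \alpha)\, \ell_{\mathrm{local}}
```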

Parameters:
  • encoder_name – timm ViT name (default "vit_small_patch16_224").

  • projector_dim – Output dimension of both the global and the local projector (default 2048).

  • sim_coeff – Invariance weight (default 25.0).

  • std_coeff – Variance weight (default 25.0).

  • cov_coeff – Covariance weight (default 1.0).

  • alpha – Weight of the global term when mixing the global and local losses (default 0.75, i.e. the global term gets 75% and the local term 25%).

  • image_size – Input image size in pixels (default 224).

  • pretrained – Whether to load pretrained timm weights (default False).
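The three terms weighted by sim_coeff, std_coeff and cov_coeff, and the alpha mix of global and local losses, can be sketched in dependency-free Python. This is a minimal illustration, assuming the definitions from the VICReg paper (unbiased variance, γ = 1, ε = 1e-4, variance and covariance terms summed over the two views); the library's own implementation may normalize or batch these differently. How local features are extracted and matched is not shown, so local_loss is simply taken as an input. All helper names here (invariance, variance, covariance, vicreg_loss, vicregl_loss) are hypothetical.

```python
import math
from typing import List

# Embeddings as a list of rows: n samples, each a list of d floats.
Matrix = List[List[float]]


def column_means(z: Matrix) -> List[float]:
    n = len(z)
    return [sum(col) / n for col in zip(*z)]


def invariance(z_a: Matrix, z_b: Matrix) -> float:
    """s(Z, Z'): mean over samples of the squared L2 distance between paired rows."""
    n = len(z_a)
    return sum(
        sum((a - b) ** 2 for a, b in zip(row_a, row_b))
        for row_a, row_b in zip(z_a, z_b)
    ) / n


def variance(z: Matrix, gamma: float = 1.0, eps: float = 1e-4) -> float:
    """v(Z): hinge on each dimension's std (unbiased variance), averaged over d."""
    n = len(z)
    means = column_means(z)
    penalties = [
        max(0.0, gamma - math.sqrt(sum((x - m) ** 2 for x in col) / (n - 1) + eps))
        for col, m in zip(zip(*z), means)
    ]
    return sum(penalties) / len(penalties)


def covariance(z: Matrix) -> float:
    """c(Z): squared off-diagonal covariance entries, summed, divided by d."""
    n, d = len(z), len(z[0])
    means = column_means(z)
    centered = [[x - m for x, m in zip(row, means)] for row in z]
    total = 0.0
    for i in range(d):
        for j in range(d):
            if i != j:
                c_ij = sum(row[i] * row[j] for row in centered) / (n - 1)
                total += c_ij ** 2
    return total / d


def vicreg_loss(z_a: Matrix, z_b: Matrix, sim_coeff: float = 25.0,
                std_coeff: float = 25.0, cov_coeff: float = 1.0) -> float:
    """One VICReg term; defaults mirror the VICRegL constructor's coefficients."""
    return (sim_coeff * invariance(z_a, z_b)
            + std_coeff * (variance(z_a) + variance(z_b))
            + cov_coeff * (covariance(z_a) + covariance(z_b)))


def vicregl_loss(global_loss: float, local_loss: float, alpha: float = 0.75) -> float:
    """alpha weights the global term, (1 - alpha) weights the local term."""
    return alpha * global_loss + (1 - alpha) * local_loss
```

For example, identical, well-spread, uncorrelated embeddings such as [[2.0, 0.0], [-2.0, 0.0], [0.0, 2.0], [0.0, -2.0]] give vicreg_loss(z, z) == 0.0, while embeddings collapsed to a single point are penalized by the variance term.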

forward(view1: Tensor, view2: Tensor | None = None) → VICRegLOutput

Run the model on one or two views.

Parameters:
  • view1 – First view: a batch of input images (Tensor).

  • view2 – Optional second view (Tensor); defaults to None.

Returns:

A VICRegLOutput.