PIRL


class stable_pretraining.methods.PIRL(encoder_name: str | Module = 'vit_small_patch16_224', projector_dim: int = 128, queue_length: int = 16384, temperature: float = 0.07, lambda_pirl: float = 0.5, jigsaw_grid: int = 4, low_resolution: bool = False, pretrained: bool = False)

Bases: Module

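PIRL (Pretext-Invariant Representation Learning): self-supervised learning (SSL) of representations that are invariant to a jigsaw transform, trained contrastively against a memory bank.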
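This page does not show the jigsaw transform itself. As a rough, dependency-free illustration of what a grid jigsaw does, the sketch below splits a square image into grid × grid tiles and permutes them. The helper name `jigsaw_shuffle` and the nested-list image are assumptions made for illustration; the library's own transform presumably works on tensors and may differ in detail.

```python
import random


def jigsaw_shuffle(image, grid, rng):
    """Split a square 2-D image (list of rows) into grid x grid tiles and permute them.

    Hypothetical helper for illustration; not part of stable_pretraining.
    """
    size = len(image)
    if size % grid != 0:
        raise ValueError("image size must be divisible by grid")
    tile = size // grid

    # Collect tiles in row-major order; each tile is a list of `tile` short rows.
    tiles = [
        [row[c * tile:(c + 1) * tile] for row in image[r * tile:(r + 1) * tile]]
        for r in range(grid)
        for c in range(grid)
    ]

    perm = list(range(grid * grid))
    rng.shuffle(perm)

    # Reassemble: output slot k receives input tile perm[k].
    out = []
    for r in range(grid):
        for y in range(tile):
            row = []
            for c in range(grid):
                row.extend(tiles[perm[r * grid + c]][y])
            out.append(row)
    return out, perm


# 8x8 toy image whose pixel values are their own row-major index.
image = [[y * 8 + x for x in range(8)] for y in range(8)]
shuffled, perm = jigsaw_shuffle(image, grid=4, rng=random.Random(0))
```

The shuffled image contains exactly the same pixels as the input, only rearranged tile by tile, which is the property an invariance objective relies on.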

Parameters:
  • encoder_name – timm model name or pre-built nn.Module (default 'vit_small_patch16_224').

  • projector_dim – Output projection dim (default 128).

  • queue_length – Memory bank size (default 16384). The paper's memory bank covers the full dataset; a fixed-length queue serves as an approximation.

  • temperature – NCE temperature (default 0.07).

  • lambda_pirl – Weight on the (jigsaw, original) loss relative to the (jigsaw, shuffled-elsewhere) loss (default 0.5).

  • jigsaw_grid – Grid size for the jigsaw transform (default 4).

  • low_resolution – Adapt the first convolution for low-resolution input (default False).

  • pretrained – Load pretrained timm weights (default False).
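To make the roles of queue_length, temperature, and lambda_pirl more concrete, here is a minimal pure-Python sketch of a contrastive loss against a FIFO queue of negatives. It makes several assumptions not confirmed by this page: it uses the common softmax (InfoNCE) form rather than the paper's exact NCE expression, it treats lambda_pirl as a convex weight between two loss terms, and the names `info_nce`, `combine`, and `random_embedding` are hypothetical. The class's actual loss may be computed differently.

```python
import math
import random
from collections import deque


def l2_normalize(v):
    """Scale a vector to unit length so dot products are cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def info_nce(query, positive, negatives, temperature=0.07):
    """Softmax contrastive loss: -log p(positive | query) against the negatives."""
    logits = [dot(query, positive) / temperature]
    logits += [dot(query, neg) / temperature for neg in negatives]
    top = max(logits)  # subtract the max before exponentiating for stability
    log_sum = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum - logits[0]


def combine(loss_a, loss_b, lam=0.5):
    """Assumed convex combination of two loss terms (hypothetical helper)."""
    return lam * loss_a + (1.0 - lam) * loss_b


rng = random.Random(0)
dim, queue_length = 8, 32  # tiny sizes for illustration only


def random_embedding():
    return l2_normalize([rng.gauss(0.0, 1.0) for _ in range(dim)])


# FIFO memory bank: with maxlen set, appending to a full deque drops the oldest entry.
queue = deque((random_embedding() for _ in range(queue_length)), maxlen=queue_length)

original = random_embedding()
# Stand-in for the embedding of a jigsaw view: the original plus small noise.
jigsaw = l2_normalize([x + 0.1 * rng.gauss(0.0, 1.0) for x in original])

# Two placeholder terms; which pairings the class really uses is not stated here.
loss_a = info_nce(jigsaw, original, queue)
loss_b = info_nce(original, jigsaw, queue)
total = combine(loss_a, loss_b, lam=0.5)

oldest = queue[0]
queue.append(original)  # enqueue the newest embedding, evicting `oldest`
```

A lower temperature sharpens the softmax, so hard negatives in the queue dominate the loss. A longer queue supplies more negatives per step at the cost of memory and of holding older, staler embeddings.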

forward(view1: Tensor, view2: Tensor | None = None) → PIRLOutput

Same as torch.nn.Module.forward().

Parameters:
  • view1 – First input view (Tensor).

  • view2 – Optional second input view (Tensor or None; default None).

Returns:

A PIRLOutput instance.