SimSiam

class stable_pretraining.methods.SimSiam(encoder_name: str | Module = 'vit_small_patch16_224', projector_dim: int = 2048, predictor_hidden_dim: int = 512, low_resolution: bool = False, pretrained: bool = False)

Bases: Module

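SimSiam: simple Siamese self-supervised learning (SSL) with a stop-gradient. Two augmented views of each image pass through a shared encoder and projector, and a predictor on each branch is trained to match the other branch's projection, which is treated as a constant (stop-gradient).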

Parameters:
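  • encoder_name – timm model name or pre-built nn.Module (default 'vit_small_patch16_224').

  • projector_dim – Hidden and output dimension of the projector (default 2048).

  • predictor_hidden_dim – Hidden (bottleneck) dimension of the predictor (default 512).

  • low_resolution – If True, adapt the first convolution for 32×32 inputs (default False).

  • pretrained – If True, load pretrained timm weights (default False).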
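The projector_dim and predictor_hidden_dim parameters size two small MLP heads that sit on top of the encoder. The sketch below is a minimal PyTorch illustration that loosely follows the head layout described in the SimSiam paper: a projector whose hidden and output widths are both projector_dim, and a bottleneck predictor projector_dim → predictor_hidden_dim → projector_dim. The helper names build_projector and build_predictor, the exact layer stack, and the 384-dim encoder output are assumptions made for illustration. They are not taken from this class's implementation.

```python
import torch
from torch import nn


def build_projector(in_dim: int, projector_dim: int = 2048) -> nn.Sequential:
    # Hypothetical helper: 3-layer MLP in the style of the SimSiam paper.
    # projector_dim is used for both the hidden and the output width.
    return nn.Sequential(
        nn.Linear(in_dim, projector_dim, bias=False),
        nn.BatchNorm1d(projector_dim),
        nn.ReLU(),
        nn.Linear(projector_dim, projector_dim, bias=False),
        nn.BatchNorm1d(projector_dim),
        nn.ReLU(),
        nn.Linear(projector_dim, projector_dim, bias=False),
        nn.BatchNorm1d(projector_dim, affine=False),  # output layer: BN, no ReLU
    )


def build_predictor(projector_dim: int = 2048,
                    predictor_hidden_dim: int = 512) -> nn.Sequential:
    # Hypothetical helper: bottleneck MLP, projector_dim -> hidden -> projector_dim.
    return nn.Sequential(
        nn.Linear(projector_dim, predictor_hidden_dim, bias=False),
        nn.BatchNorm1d(predictor_hidden_dim),
        nn.ReLU(),
        nn.Linear(predictor_hidden_dim, projector_dim),
    )


# Assumption: encoder features of width 384 (the embedding size of timm's
# vit_small_patch16_224); a batch of 8 random feature vectors stands in for them.
feats = torch.randn(8, 384)
projector = build_projector(384)
predictor = build_predictor()

z = projector(feats)  # projection, later used as the (stop-gradient) target
p = predictor(z)      # prediction, trained to match the other view's z
print(z.shape, p.shape)  # torch.Size([8, 2048]) torch.Size([8, 2048])
```

The predictor's narrow bottleneck (512 inside a 2048-wide head by default) matches the configuration reported in the SimSiam paper.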

forward(view1: Tensor, view2: Tensor | None = None) → SimSiamOutput

Run the model on one or two views.

Parameters:
  • view1 – First view (a batch of images).

  • view2 – Second, differently augmented view of the same images (optional, default None).

Returns:

A SimSiamOutput holding the model's outputs.
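SimSiam trains on a pair of views with a symmetric negative cosine similarity. Each view's prediction is compared with the other view's projection, and a stop-gradient is applied to that projection. The sketch below is a minimal version of that objective in PyTorch. simsiam_loss is a hypothetical helper, and this page does not say whether forward computes the loss itself or which fields SimSiamOutput exposes.

```python
import torch
import torch.nn.functional as F


def simsiam_loss(p1: torch.Tensor, p2: torch.Tensor,
                 z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: symmetric negative cosine similarity.

    p1, p2: predictor outputs for view1 / view2.
    z1, z2: projector outputs for view1 / view2.
    """
    # .detach() is the stop-gradient: the target projection is treated as a constant.
    d12 = -F.cosine_similarity(p1, z2.detach(), dim=-1).mean()
    d21 = -F.cosine_similarity(p2, z1.detach(), dim=-1).mean()
    return 0.5 * (d12 + d21)


# Toy tensors standing in for the predictor/projector outputs of two views.
torch.manual_seed(0)
z1 = torch.randn(4, 16, requires_grad=True)
z2 = torch.randn(4, 16, requires_grad=True)
p1 = torch.randn(4, 16, requires_grad=True)
p2 = torch.randn(4, 16, requires_grad=True)

loss = simsiam_loss(p1, p2, z1, z2)
loss.backward()

# Gradients reach the predictions but not the detached targets.
print(p1.grad is not None, z1.grad is None)  # True True
```

The loss lies in [-1, 1] and reaches -1 when every prediction points in the same direction as the other view's projection. According to the SimSiam paper, removing the stop-gradient makes the representations collapse to a constant.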