BYOL

class stable_pretraining.methods.BYOL(encoder_name: str | Module = 'vit_small_patch16_224', projector_dims: Sequence[int] = (4096, 256), predictor_dims: Sequence[int] = (4096, 256), ema_decay_start: float = 0.99, ema_decay_end: float = 1.0, low_resolution: bool = False, pretrained: bool = False)

Bases: Module

BYOL (Bootstrap Your Own Latent) self-supervised learning with an exponential-moving-average (EMA) target network.

Parameters:
  • encoder_name – timm model name or pre-built nn.Module.

  • projector_dims – (hidden, output) sizes of the 2-layer projector (default (4096, 256)).

  • predictor_dims – (hidden, output) sizes of the predictor (default (4096, 256)).

  • ema_decay_start – Initial EMA decay (default 0.99).

  • ema_decay_end – Final EMA decay (default 1.0).

  • low_resolution – Adapt the first convolution for low-resolution inputs.

  • pretrained – Load pretrained timm weights.
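
The two head sizes chain together: the projector maps encoder features to projector_dims[1], and the predictor takes that as its input. The following is a minimal sketch of the resulting Linear-layer shapes. It assumes the BYOL-paper layout (Linear, BatchNorm, ReLU, Linear) for both heads and a 384-dim pooled feature from vit_small_patch16_224. mlp_linear_shapes is a hypothetical helper, not part of stable_pretraining.

```python
# Hypothetical helper (not part of stable_pretraining): returns the
# (in_features, out_features) of the two Linear layers in an MLP head laid
# out as Linear -> BatchNorm -> ReLU -> Linear (the BYOL-paper layout).
def mlp_linear_shapes(in_dim, dims):
    hidden, output = dims
    return [(in_dim, hidden), (hidden, output)]


encoder_dim = 384  # assumed pooled feature size of vit_small_patch16_224
projector_dims = (4096, 256)
predictor_dims = (4096, 256)

projector = mlp_linear_shapes(encoder_dim, projector_dims)
# The predictor consumes projector outputs, so its input size is projector_dims[1].
predictor = mlp_linear_shapes(projector_dims[1], predictor_dims)

# The loss compares predictor outputs with target projections,
# so the two output sizes must agree.
assert predictor_dims[1] == projector_dims[1]

print(projector)  # [(384, 4096), (4096, 256)]
print(predictor)  # [(256, 4096), (4096, 256)]
```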

Note

Use TeacherStudentCallback to drive teacher EMA updates during training.
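
The internals of TeacherStudentCallback are not documented here. To illustrate what an EMA teacher update looks like when the decay moves from ema_decay_start to ema_decay_end, the sketch below uses the cosine schedule from the BYOL paper. The schedule shape the library actually uses, and the helper names ema_decay and ema_update, are assumptions. Parameters are plain lists of floats to keep the example dependency-free.

```python
import math


def ema_decay(step, total_steps, start=0.99, end=1.0):
    """Cosine schedule from `start` (step 0) to `end` (final step).

    Assumed shape, following the BYOL paper; the library may differ.
    """
    progress = min(step / max(total_steps, 1), 1.0)
    return end - (end - start) * (math.cos(math.pi * progress) + 1.0) / 2.0


def ema_update(target, online, decay):
    """In place: target <- decay * target + (1 - decay) * online."""
    for i, (t, o) in enumerate(zip(target, online)):
        target[i] = decay * t + (1.0 - decay) * o
    return target


# Usage: one teacher update at the first of 100 steps.
online = [1.0, 2.0]
target = [0.0, 0.0]
ema_update(target, online, ema_decay(0, 100))
print(target)  # roughly [0.01, 0.02]
```

With ema_decay_end = 1.0, the target stops moving at the end of training, because the update weight 1 - decay reaches zero.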

forward(view1: Tensor, view2: Tensor | None = None) → BYOLOutput

Compute a BYOL forward pass.

Parameters:
  • view1 – First augmented view of the input batch.

  • view2 – Optional second augmented view (default None).

Returns:

A BYOLOutput.
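
Given two views, BYOL trains the online prediction for each view to match the target projection of the other view. It uses the normalized MSE 2 - 2 * cos(p, z). The following dependency-free sketch shows that symmetric objective as defined in the BYOL paper. Whether BYOLOutput carries exactly this quantity is an assumption; the library might, for example, average the two directions rather than summing them. The function names are hypothetical. In a real training loop, z1 and z2 come from the EMA target and receive no gradient.

```python
import math


def l2_normalize(v, eps=1e-12):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def byol_pair_loss(p, z):
    """Normalized MSE: ||p/|p| - z/|z|||^2 = 2 - 2 * cos(p, z)."""
    p_n, z_n = l2_normalize(p), l2_normalize(z)
    return 2.0 - 2.0 * sum(a * b for a, b in zip(p_n, z_n))


def byol_symmetric_loss(p1, p2, z1, z2):
    """p1, p2: online predictions for view1 / view2.
    z1, z2: target (EMA) projections for view1 / view2, treated as constants.
    Each view's prediction regresses the *other* view's target projection."""
    return byol_pair_loss(p1, z2) + byol_pair_loss(p2, z1)


print(byol_pair_loss([1.0, 0.0], [2.0, 0.0]))  # 0.0 (same direction)
print(byol_pair_loss([1.0, 0.0], [0.0, 1.0]))  # 2.0 (orthogonal)
```

The loss depends only on direction, not magnitude, which is why scaling a target projection (as in the first call) leaves it at zero.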