Source code for stable_pretraining.methods.byol

"""BYOL: Bootstrap Your Own Latent.

Self-supervised learning without negative pairs. The student network has a
backbone, a projector, and a predictor; the teacher (EMA of student) has
only a backbone and projector. The student predicts the teacher's projection
of the other view; the loss is symmetric MSE between L2-normalised vectors.

References:
    Grill et al. "Bootstrap Your Own Latent: A New Approach to
    Self-Supervised Learning." NeurIPS 2020.
    https://arxiv.org/abs/2006.07733
"""

from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import TeacherStudentWrapper, from_timm
from stable_pretraining.losses import BYOLLoss


@dataclass
class BYOLOutput(ModelOutput):
    """Output from BYOL forward pass.

    :ivar loss: Symmetric BYOL loss (0 in eval mode)
    :ivar embedding: Teacher backbone features (always detached)
    :ivar prediction: Student predictor output [2B, P] (None in eval)
    :ivar target: Teacher projector output [2B, P] (None in eval, detached)
    """

    loss: torch.Tensor = None
    embedding: torch.Tensor = None
    prediction: Optional[torch.Tensor] = None
    target: Optional[torch.Tensor] = None


def _byol_mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    """2-layer Linear-BN-ReLU-Linear MLP used for projector and predictor."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, out_dim),
    )


[docs] class BYOL(Module): """BYOL self-supervised learning with EMA target network. :param encoder_name: timm model name or pre-built ``nn.Module``. :param projector_dims: ``(hidden, output)`` for the 2-layer projector (default ``(4096, 256)``). :param predictor_dims: ``(hidden, output)`` for the predictor (default ``(4096, 256)``). :param ema_decay_start: Initial EMA decay (default 0.99). :param ema_decay_end: Final EMA decay (default 1.0). :param low_resolution: Adapt first conv for low-res input. :param pretrained: Load pretrained timm weights. Note: Use :class:`~stable_pretraining.callbacks.TeacherStudentCallback` to drive teacher EMA updates during training. """ def __init__( self, encoder_name: Union[str, nn.Module] = "vit_small_patch16_224", projector_dims: Sequence[int] = (4096, 256), predictor_dims: Sequence[int] = (4096, 256), ema_decay_start: float = 0.99, ema_decay_end: float = 1.0, low_resolution: bool = False, pretrained: bool = False, ): super().__init__() if isinstance(encoder_name, str): base_backbone = from_timm( encoder_name, num_classes=0, low_resolution=low_resolution, pretrained=pretrained, ) else: base_backbone = encoder_name with torch.no_grad(): embed_dim = base_backbone(torch.zeros(1, 3, 224, 224)).shape[-1] self.embed_dim = embed_dim if len(projector_dims) != 2 or len(predictor_dims) != 2: raise ValueError( "projector_dims and predictor_dims must be (hidden, output) tuples" ) proj_hidden, proj_out = projector_dims pred_hidden, pred_out = predictor_dims if pred_out != proj_out: raise ValueError( f"predictor output dim ({pred_out}) must match projector output dim ({proj_out})" ) self.backbone = TeacherStudentWrapper( base_backbone, warm_init=True, base_ema_coefficient=ema_decay_start, final_ema_coefficient=ema_decay_end, ) self.projector = TeacherStudentWrapper( _byol_mlp(embed_dim, proj_hidden, proj_out), warm_init=True, base_ema_coefficient=ema_decay_start, final_ema_coefficient=ema_decay_end, ) self.predictor = _byol_mlp(proj_out, pred_hidden, pred_out) self.byol_loss = BYOLLoss()
[docs] def forward( self, view1: torch.Tensor, view2: Optional[torch.Tensor] = None, ) -> BYOLOutput: if view2 is None: with torch.no_grad(): embedding = self.backbone.forward_teacher(view1) return BYOLOutput( loss=torch.zeros((), device=embedding.device, dtype=embedding.dtype), embedding=embedding.detach(), prediction=None, target=None, ) # Student path: backbone -> projector -> predictor s1 = self.backbone.forward_student(view1) s2 = self.backbone.forward_student(view2) zs1 = self.projector.forward_student(s1) zs2 = self.projector.forward_student(s2) p1 = self.predictor(zs1) p2 = self.predictor(zs2) # Teacher path: detached, EMA target with torch.no_grad(): t1 = self.backbone.forward_teacher(view1) t2 = self.backbone.forward_teacher(view2) zt1 = self.projector.forward_teacher(t1) zt2 = self.projector.forward_teacher(t2) loss = (self.byol_loss(p1, zt2) + self.byol_loss(p2, zt1)) / 2 return BYOLOutput( loss=loss, embedding=torch.cat([t1, t2], dim=0).detach(), prediction=torch.cat([p1, p2], dim=0), target=torch.cat([zt1, zt2], dim=0).detach(), )