# Source code for stable_pretraining.methods.barlow_twins

"""Barlow Twins: Self-Supervised Learning via Redundancy Reduction.

Drives the cross-correlation matrix between projected embeddings of two
augmented views toward the identity, decorrelating features while staying
invariant to augmentations.

References:
    Zbontar, Jing, Misra, LeCun, Deny. "Barlow Twins: Self-Supervised
    Learning via Redundancy Reduction." ICML 2021.
    https://arxiv.org/abs/2103.03230
"""

from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import from_timm
from stable_pretraining.losses import BarlowTwinsLoss


@dataclass
class BarlowTwinsOutput(ModelOutput):
    """Output from Barlow Twins forward pass.

    Every field defaults to ``None``; which fields carry a tensor depends on
    whether ``BarlowTwins.forward`` was called with one view or two.
    """

    # Barlow Twins loss for a two-view call; a zero scalar for a single-view call.
    loss: Optional[torch.Tensor] = None
    # Backbone output. For a two-view call, both views concatenated along dim 0.
    embedding: Optional[torch.Tensor] = None
    # Projector output of both views concatenated along dim 0; ``None`` for a
    # single-view call, where the projector is not run.
    projection: Optional[torch.Tensor] = None


def _build_barlow_projector(in_dim: int, hidden_dims: Sequence[int]) -> nn.Module:
    """Assemble the Barlow Twins projector MLP.

    Every entry of ``hidden_dims`` except the last contributes a
    ``Linear(bias=False) -> BatchNorm1d -> ReLU`` stage; the last entry
    contributes a single bias-free ``Linear`` with no normalisation or
    activation after it.

    :param in_dim: Feature size fed into the first linear layer.
    :param hidden_dims: Widths of the successive layers; the final value is
        the projector's output width.
    :return: An ``nn.Sequential`` holding the layers in order.
    :raises ValueError: If ``hidden_dims`` is empty.
    """
    if len(hidden_dims) == 0:
        raise ValueError("hidden_dims must contain at least one entry")

    # Everything before the final width gets the full Linear/BN/ReLU stage.
    *stage_widths, out_width = hidden_dims

    modules = []
    fan_in = in_dim
    for width in stage_widths:
        modules += [
            nn.Linear(fan_in, width, bias=False),
            nn.BatchNorm1d(width),
            nn.ReLU(inplace=True),
        ]
        fan_in = width

    # Output layer: plain bias-free projection.
    modules.append(nn.Linear(fan_in, out_width, bias=False))
    return nn.Sequential(*modules)


class BarlowTwins(Module):
    """Barlow Twins self-supervised learning.

    Holds a backbone encoder, an MLP projector built on top of the backbone's
    output width, and the Barlow Twins loss applied to the projections of two
    views.

    :param encoder_name: timm model name or pre-built ``nn.Module``.
    :param projector_dims: Hidden + output dims (default ``(8192, 8192, 8192)``
        matches the ResNet50 recipe).
    :param lambd: Off-diagonal weight in the cross-correlation loss
        (default ``5.1e-3`` from the paper).
    :param low_resolution: Adapt first conv for low-res input. Only used when
        ``encoder_name`` is a string.
    :param pretrained: Load pretrained timm weights. Only used when
        ``encoder_name`` is a string.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        projector_dims: Sequence[int] = (8192, 8192, 8192),
        lambd: float = 5.1e-3,
        low_resolution: bool = False,
        pretrained: bool = False,
    ):
        super().__init__()

        if isinstance(encoder_name, str):
            self.backbone = from_timm(
                encoder_name,
                num_classes=0,
                low_resolution=low_resolution,
                pretrained=pretrained,
            )
        else:
            # A ready-made module is used as-is.
            self.backbone = encoder_name

        self.embed_dim = self._probe_embed_dim(self.backbone)
        self.projector = _build_barlow_projector(self.embed_dim, list(projector_dims))
        self.barlow_loss = BarlowTwinsLoss(lambd=lambd)

    @staticmethod
    def _probe_embed_dim(backbone: nn.Module) -> int:
        """Return the size of the last output dimension of ``backbone``.

        A single all-zero ``(1, 3, 224, 224)`` input is pushed through the
        backbone without gradients. The probe runs in eval mode so that it
        cannot update BatchNorm running statistics or apply dropout, and the
        previous train/eval flag of every submodule is restored afterwards.
        The dummy input is created on the device (and floating dtype) of the
        backbone's first parameter so an already-moved backbone still works.
        """
        tensor_kwargs = {}
        first_param = next(backbone.parameters(), None)
        if first_param is not None:
            tensor_kwargs["device"] = first_param.device
            if first_param.is_floating_point():
                tensor_kwargs["dtype"] = first_param.dtype
        dummy = torch.zeros(1, 3, 224, 224, **tensor_kwargs)

        # Remember per-module flags: a caller may hand in a backbone with only
        # some submodules frozen in eval mode.
        saved_modes = [(module, module.training) for module in backbone.modules()]
        backbone.eval()
        try:
            with torch.no_grad():
                return backbone(dummy).shape[-1]
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

    def forward(
        self,
        view1: torch.Tensor,
        view2: Optional[torch.Tensor] = None,
    ) -> BarlowTwinsOutput:
        """Encode one or two views and, for two views, compute the loss.

        :param view1: Input passed to the backbone.
        :param view2: Optional second input. When ``None`` only the backbone
            is run on ``view1``.
        :return: ``BarlowTwinsOutput``. For a single view: ``embedding`` is
            the backbone output, ``loss`` is a zero scalar on the same device
            and dtype, and ``projection`` is ``None``. For two views:
            ``loss`` is the Barlow Twins loss between the two projections,
            and ``embedding`` / ``projection`` are the per-view tensors
            concatenated along dim 0 (``view1`` first).
        """
        if view2 is None:
            embedding = self.backbone(view1)
            return BarlowTwinsOutput(
                loss=torch.zeros((), device=embedding.device, dtype=embedding.dtype),
                embedding=embedding,
                projection=None,
            )

        # Each view goes through the backbone and projector separately.
        h1 = self.backbone(view1)
        h2 = self.backbone(view2)
        z1 = self.projector(h1)
        z2 = self.projector(h2)
        loss = self.barlow_loss(z1, z2)

        return BarlowTwinsOutput(
            loss=loss,
            embedding=torch.cat([h1, h2], dim=0),
            projection=torch.cat([z1, z2], dim=0),
        )