# Source code for stable_pretraining.methods.wmse

"""W-MSE: Whitening Mean-Squared Error.

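Joint-embedding SSL where the projections of the two views are batch-whitened
(ZCA whitening computed from an eigendecomposition of the batch covariance; the
paper itself uses Cholesky whitening) and the loss is the MSE between the
L2-normalised whitened projections of the two views. Whitening removes
second-order redundancy without negatives or stop-gradient.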

References:
    Ermolov et al. "Whitening for Self-Supervised Representation Learning."
    ICML 2021. https://arxiv.org/abs/2007.06346
"""

from dataclasses import dataclass
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import from_timm


@dataclass
class WMSEOutput(ModelOutput):
    """Structured output of the :class:`WMSE` SSL method.

    Returned by ``WMSE.forward``. For a single-view call only ``loss`` (a
    zero-valued 0-dim tensor) and ``embedding`` are filled and ``projection``
    stays ``None``. For a two-view call ``loss`` is the W-MSE loss,
    ``embedding`` holds the backbone features of both views concatenated
    along the batch dimension, and ``projection`` holds the matching
    projector outputs *before* whitening.
    """

    # Training loss; a zero scalar on the single-view (inference) path.
    loss: Optional[torch.Tensor] = None
    # Backbone features; for two views this is cat([view1_feats, view2_feats]).
    embedding: Optional[torch.Tensor] = None
    # Un-whitened projector outputs for both views; None for single-view calls.
    projection: Optional[torch.Tensor] = None


def _projector(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Module:
    """Build the two-layer MLP projection head.

    Layout: ``Linear(in_dim, hidden_dim) -> BatchNorm1d -> ReLU -> Linear(hidden_dim, out_dim)``.

    :param in_dim: Feature size produced by the backbone.
    :param hidden_dim: Width of the hidden layer.
    :param out_dim: Size of the projection that is later whitened.
    :return: An ``nn.Sequential`` with the four layers above, in order.
    """
    layers = [nn.Linear(in_dim, hidden_dim)]
    layers.append(nn.BatchNorm1d(hidden_dim))
    layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Linear(hidden_dim, out_dim))
    return nn.Sequential(*layers)


def _eigen_whiten(z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """ZCA-whiten a ``[B, D]`` batch of features.

    The batch is centred over dimension 0 and its covariance
    ``C = z_c^T z_c / max(B - 1, 1)`` is symmetrised and eigendecomposed with
    :func:`torch.linalg.eigh`. The result is ``z_c @ C^{-1/2}``, where
    ``C^{-1/2} = V diag(max(lambda, eps))^{-1/2} V^T``.

    Autocast is disabled and the maths runs in at least fp32 (``eigh`` has no
    fp16 kernel). fp16/bf16 inputs are upcast to fp32, while fp64 inputs stay
    in fp64 instead of being downcast. The output is cast back to the input
    dtype.

    NOTE: gradients through ``eigh`` eigenvectors become unstable when
    eigenvalues are (nearly) repeated, e.g. for rank-deficient batches with
    ``B - 1 < D``. Keep the batch comfortably larger than ``D``.

    :param z: Feature batch of shape ``[B, D]``.
    :param eps: Floor applied to the covariance eigenvalues before the
        inverse square root. It keeps rank-deficient batches finite.
    :return: Whitened batch with the same shape and dtype as ``z``.
    """
    orig_dtype = z.dtype
    # Upcast half/bfloat16 (and ints) to fp32, but never downcast fp64.
    compute_dtype = torch.promote_types(orig_dtype, torch.float32)
    # Disable autocast so the upcast actually sticks inside the region.
    with torch.amp.autocast(device_type=z.device.type, enabled=False):
        zc = z.to(compute_dtype)
        zc = zc - zc.mean(dim=0, keepdim=True)
        cov = (zc.T @ zc) / max(zc.shape[0] - 1, 1)
        # Remove round-off asymmetry before the symmetric eigensolver.
        cov = (cov + cov.T) * 0.5
        eigvals, eigvecs = torch.linalg.eigh(cov)
        inv_sqrt_vals = eigvals.clamp(min=eps).rsqrt()
        # V diag(s) V^T via column scaling; avoids building diag(s).
        inv_sqrt = (eigvecs * inv_sqrt_vals) @ eigvecs.T
        out = zc @ inv_sqrt
    return out.to(orig_dtype)


# Backward-compat alias kept for code that still refers to the old name.
# Despite the name, it is the eigendecomposition-based (ZCA) whitening defined
# above; no Cholesky factorisation is performed.
_cholesky_whiten = _eigen_whiten


class WMSE(Module):
    """W-MSE: whitening + MSE between paired views.

    Both views go through a shared backbone and projector. The two projection
    batches are concatenated, whitened jointly with :func:`_eigen_whiten`,
    L2-normalised, and compared with MSE.

    :param encoder_name: timm model name or pre-built ``nn.Module``.
    :param projector_dims: ``(hidden, output)`` for the projector (default
        ``(1024, 64)``; a small whitening dim helps stability).
    :param eps: Floor applied to the covariance eigenvalues during whitening
        (default ``1e-3``).
    :param low_resolution: Adapt first conv for low-res input.
    :param pretrained: Load pretrained timm weights.
    """

    def __init__(
        self,
        encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
        projector_dims: Sequence[int] = (1024, 64),
        eps: float = 1e-3,
        low_resolution: bool = False,
        pretrained: bool = False,
    ):
        super().__init__()
        if isinstance(encoder_name, str):
            self.backbone = from_timm(
                encoder_name,
                num_classes=0,
                low_resolution=low_resolution,
                pretrained=pretrained,
            )
        else:
            self.backbone = encoder_name

        self.embed_dim = self._probe_embed_dim(self.backbone)
        proj_hidden, proj_out = projector_dims
        self.projector = _projector(self.embed_dim, proj_hidden, proj_out)
        self.eps = eps

    @staticmethod
    def _probe_embed_dim(backbone: nn.Module, image_size: int = 224) -> int:
        """Return the backbone's output feature size via one dummy forward pass.

        The probe runs in eval mode so the all-zero input does not overwrite
        BatchNorm running statistics (``no_grad`` alone does not prevent
        that). The input is created on the backbone's own device and
        floating dtype, so pre-built modules already moved to an accelerator
        or to reduced precision work. Every submodule's original
        train/eval flag is restored afterwards.
        """
        ref = next(backbone.parameters(), None)
        device = ref.device if ref is not None else None
        dtype = ref.dtype if ref is not None and ref.is_floating_point() else None
        # Record per-submodule flags so mixed train/eval states survive.
        modes = [(module, module.training) for module in backbone.modules()]
        backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(
                    1, 3, image_size, image_size, device=device, dtype=dtype
                )
                out = backbone(dummy)
        finally:
            for module, was_training in modes:
                module.training = was_training
        return out.shape[-1]

    def forward(
        self,
        view1: torch.Tensor,
        view2: Optional[torch.Tensor] = None,
    ) -> WMSEOutput:
        """Embed one view, or compute the W-MSE loss for a pair of views.

        :param view1: First batch of images.
        :param view2: Second augmented view of the same images. When ``None``,
            only backbone embeddings of ``view1`` are returned, with a zero loss.
        :return: :class:`WMSEOutput` with ``loss``, ``embedding`` and (for two
            views) the un-whitened ``projection``.
        :raises ValueError: if the two views have different batch sizes.
        """
        if view2 is None:
            embedding = self.backbone(view1)
            return WMSEOutput(
                loss=torch.zeros((), device=embedding.device, dtype=embedding.dtype),
                embedding=embedding,
            )

        # chunk(2) below pairs rows by position, so unequal batches would
        # silently mis-pair samples.
        if view1.shape[0] != view2.shape[0]:
            raise ValueError(
                "view1 and view2 must have the same batch size, got "
                f"{view1.shape[0]} and {view2.shape[0]}"
            )

        h1 = self.backbone(view1)
        h2 = self.backbone(view2)
        z1 = self.projector(h1)
        z2 = self.projector(h2)

        # Whiten the *concatenated* batch jointly so both views share one
        # whitening transform. NOTE(review): confirm this joint variant against
        # the paper/reference code, which may whiten each view separately.
        z = torch.cat([z1, z2], dim=0)
        zw = F.normalize(_eigen_whiten(z, eps=self.eps), dim=-1)
        zw1, zw2 = zw.chunk(2, dim=0)
        loss = F.mse_loss(zw1, zw2)

        return WMSEOutput(
            loss=loss,
            embedding=torch.cat([h1, h2], dim=0),
            projection=z,
        )