Source code for stable_pretraining.methods.pirl

"""PIRL: Pretext-Invariant Representation Learning.

NCE-based memory-bank method that trains the encoder so that an image and
its *jigsaw-shuffled* version share the same representation. The original
image embedding goes into the memory bank; the shuffled-patch embedding is
the query and matched against the bank as positive.

References:
    Misra, van der Maaten. "Self-Supervised Learning of Pretext-Invariant
    Representations." CVPR 2020. https://arxiv.org/abs/1912.01991
"""

from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput

from stable_pretraining import Module
from stable_pretraining.backbone import from_timm
from stable_pretraining.callbacks.queues import UnsortedQueue


@dataclass
class PIRLOutput(ModelOutput):
    """Structured output of the :class:`PIRL` SSL method."""

    loss: torch.Tensor = None
    embedding: torch.Tensor = None


def _shuffle_patches(images: torch.Tensor, grid: int = 4) -> torch.Tensor:
    """Crop each image into a ``grid x grid`` jigsaw and randomly permute it.

    The image is resized so its side length is divisible by ``grid`` and
    then resized back, so the encoder always sees its expected size.
    """
    B, C, H, W = images.shape
    if H != W:
        raise ValueError("PIRL jigsaw assumes square images")
    target_h = (H // grid) * grid
    if target_h != H:
        images_resized = F.interpolate(
            images, size=target_h, mode="bilinear", align_corners=False
        )
    else:
        images_resized = images
    p = target_h // grid
    x = images_resized.unfold(2, p, p).unfold(3, p, p)  # [B, C, grid, grid, p, p]
    x = x.permute(0, 2, 3, 1, 4, 5).contiguous()
    x = x.view(B, grid * grid, C, p, p)
    perm = torch.stack(
        [torch.randperm(grid * grid, device=images.device) for _ in range(B)]
    )
    x = torch.gather(x, 1, perm[..., None, None, None].expand_as(x))
    x = x.view(B, grid, grid, C, p, p).permute(0, 3, 1, 4, 2, 5).contiguous()
    out = x.view(B, C, target_h, target_h)
    if target_h != H:
        out = F.interpolate(out, size=H, mode="bilinear", align_corners=False)
    return out


[docs] class PIRL(Module): """PIRL: jigsaw-invariant memory-bank SSL. :param encoder_name: timm model name or pre-built ``nn.Module``. :param projector_dim: Output projection dim (default 128). :param queue_length: Memory bank size (default 16384; paper used full dataset, but a queue works as an approximation). :param temperature: NCE temperature (default 0.07). :param lambda_pirl: Weight on the (jigsaw, original) loss vs (jigsaw, shuffled-elsewhere) (default 0.5). :param jigsaw_grid: Grid size for the jigsaw transform (default 3). :param low_resolution: Adapt first conv for low-res input. :param pretrained: Load pretrained timm weights. """ def __init__( self, encoder_name: Union[str, nn.Module] = "vit_small_patch16_224", projector_dim: int = 128, queue_length: int = 16384, temperature: float = 0.07, lambda_pirl: float = 0.5, jigsaw_grid: int = 4, low_resolution: bool = False, pretrained: bool = False, ): super().__init__() if isinstance(encoder_name, str): self.backbone = from_timm( encoder_name, num_classes=0, low_resolution=low_resolution, pretrained=pretrained, ) else: self.backbone = encoder_name with torch.no_grad(): embed_dim = self.backbone(torch.zeros(1, 3, 224, 224)).shape[-1] self.embed_dim = embed_dim self.temperature = temperature self.lambda_pirl = lambda_pirl self.jigsaw_grid = jigsaw_grid # Two projection heads: one for original images, one for the jigsaw. self.proj_image = nn.Linear(embed_dim, projector_dim, bias=False) self.proj_jigsaw = nn.Linear(embed_dim, projector_dim, bias=False) self.queue = UnsortedQueue( max_length=queue_length, shape=(projector_dim,), dtype=torch.float32 )
[docs] def forward( self, view1: torch.Tensor, view2: Optional[torch.Tensor] = None, ) -> PIRLOutput: # Eval / single-image if view2 is None: embedding = self.backbone(view1) return PIRLOutput( loss=torch.zeros((), device=embedding.device, dtype=embedding.dtype), embedding=embedding, ) # ``view1`` is treated as the original image; ``view2`` is ignored # and replaced by a jigsaw of view1 (PIRL is single-augmentation). jigsaw = _shuffle_patches(view1, grid=self.jigsaw_grid) h_img = self.backbone(view1) h_jig = self.backbone(jigsaw) z_img = F.normalize(self.proj_image(h_img), dim=-1) z_jig = F.normalize(self.proj_jigsaw(h_jig), dim=-1) with torch.no_grad(): queue_keys = self.queue.append(z_img.detach().to(torch.float32)).to( z_img.dtype ) # NCE 1: jigsaw vs (positive=image, negatives=queue) pos1 = (z_jig * z_img.detach()).sum(dim=-1, keepdim=True) neg1 = z_jig @ queue_keys.T logits1 = torch.cat([pos1, neg1], dim=1) / self.temperature tgt = torch.zeros(z_jig.shape[0], dtype=torch.long, device=z_jig.device) loss1 = F.cross_entropy(logits1, tgt) # NCE 2: image vs (positive=image-itself in queue, negatives=other queue) pos2 = (z_img * z_img.detach()).sum(dim=-1, keepdim=True) neg2 = z_img @ queue_keys.T logits2 = torch.cat([pos2, neg2], dim=1) / self.temperature loss2 = F.cross_entropy(logits2, tgt) loss = self.lambda_pirl * loss1 + (1 - self.lambda_pirl) * loss2 return PIRLOutput(loss=loss, embedding=h_img)