"""BEiT: BERT Pre-Training of Image Transformers.
Masked image modeling where the encoder predicts a *discrete* visual-token
ID at every masked patch. The original paper uses a frozen DALL-E
discrete VAE as the tokenizer; BEiT v2 distils a pretrained CLIP via a
VQ-VAE; BEiT v3 uses a multi-modal MoE.
This implementation is tokenizer-agnostic: pass any callable that maps
``[B, C, H, W]`` images to ``[B, N]`` ``int64`` token IDs in
``[0, vocab_size)``. No tokenizer is bundled to keep the dependency
surface small. A simple in-repo placeholder is provided as
:func:`patch_kmeans_tokenizer` for testing.
References:
Bao et al. "BEiT: BERT Pre-Training of Image Transformers."
ICLR 2022. https://arxiv.org/abs/2106.08254
"""
from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from stable_pretraining import Module
from stable_pretraining.backbone import patchify
@dataclass
class BEiTOutput(ModelOutput):
"""Structured output of the :class:`BEiT` SSL method."""
loss: torch.Tensor = None
embedding: torch.Tensor = None
logits: Optional[torch.Tensor] = None
targets: Optional[torch.Tensor] = None
mask: Optional[torch.Tensor] = None
def patch_kmeans_tokenizer(vocab_size: int, patch_size: int, in_channels: int = 3):
"""Return a dummy tokenizer that hashes flattened patches into buckets.
Uses a fixed random projection to ``vocab_size`` buckets. Useful for
end-to-end smoke tests only — replace with a DALL-E or VQ-VAE
tokenizer for real training.
"""
proj = None
def _tok(images: torch.Tensor) -> torch.Tensor:
nonlocal proj
patches = patchify(images, (in_channels, patch_size, patch_size)) # [B, N, P]
if proj is None or proj.shape[0] != patches.shape[-1]:
g = torch.Generator(device="cpu").manual_seed(0)
proj = torch.randn(patches.shape[-1], vocab_size, generator=g).to(
patches.device
)
scores = patches.float() @ proj.to(patches.device)
return scores.argmax(dim=-1) # [B, N]
return _tok
[docs]
class BEiT(Module):
"""BEiT masked image modeling with a discrete visual tokenizer.
:param encoder_name: timm ViT name (default ``"vit_small_patch16_224"``).
:param tokenizer: Callable ``images -> [B, N] int64`` returning visual
token IDs. If ``None``, defaults to :func:`patch_kmeans_tokenizer`
(placeholder; not SOTA).
:param vocab_size: Number of visual tokens (default 8192, matches DALL-E).
:param patch_size: Patch size of the encoder (default 16).
:param mask_ratio: Fraction of patches masked (default 0.4, BEiT used 0.4).
:param image_size: Input size (default 224).
:param pretrained: Load pretrained timm weights.
"""
def __init__(
self,
encoder_name: Union[str, nn.Module] = "vit_small_patch16_224",
tokenizer: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
vocab_size: int = 8192,
patch_size: int = 16,
mask_ratio: float = 0.4,
image_size: int = 224,
pretrained: bool = False,
):
super().__init__()
if isinstance(encoder_name, str):
import timm
self.encoder = timm.create_model(
encoder_name, num_classes=0, pretrained=pretrained
)
else:
self.encoder = encoder_name
with torch.no_grad():
seq = self.encoder.forward_features(
torch.zeros(1, 3, image_size, image_size)
)
self._has_cls = (
hasattr(self.encoder, "cls_token")
and self.encoder.cls_token is not None
and seq.shape[1] > 1
)
embed_dim = seq.shape[-1]
self.embed_dim = embed_dim
self.vocab_size = vocab_size
self.patch_size = patch_size
self.mask_ratio = mask_ratio
self.image_size = image_size
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.init.trunc_normal_(self.mask_token, std=0.02)
# Predict visual-token IDs at each (masked) patch position.
self.head = nn.Linear(embed_dim, vocab_size)
self.tokenizer = tokenizer or patch_kmeans_tokenizer(
vocab_size=vocab_size, patch_size=patch_size
)
def _encode_with_mask(
self, images: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
vit = self.encoder
x = vit.patch_embed(images)
m = mask.unsqueeze(-1)
x = x * (1 - m) + self.mask_token.expand_as(x) * m
if self._has_cls:
cls = vit.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat([cls, x], dim=1)
x = x + vit.pos_embed
x = vit.pos_drop(x)
x = vit.blocks(x)
x = vit.norm(x)
if self._has_cls:
x = x[:, 1:]
return x
def _random_mask(self, B, N, device):
n_mask = int(round(N * self.mask_ratio))
noise = torch.rand(B, N, device=device)
order = noise.argsort(dim=1)
mask = torch.zeros(B, N, device=device)
mask.scatter_(1, order[:, :n_mask], 1.0)
return mask
[docs]
def forward(self, images: torch.Tensor) -> BEiTOutput:
B = images.shape[0]
if not self.training:
features = self.encoder.forward_features(images)
cls = features[:, 0] if self._has_cls else features.mean(dim=1)
return BEiTOutput(
loss=torch.zeros((), device=images.device, dtype=images.dtype),
embedding=cls,
)
# Tokenize the full image (no augmentation needed for token targets).
with torch.no_grad():
targets = self.tokenizer(images) # [B, N]
N = targets.shape[1]
mask = self._random_mask(B, N, device=images.device)
encoded = self._encode_with_mask(images, mask) # [B, N, D]
logits = self.head(encoded) # [B, N, V]
# Cross-entropy over masked positions only.
flat_mask = mask.bool().view(-1)
flat_logits = logits.reshape(-1, self.vocab_size)[flat_mask]
flat_targets = targets.reshape(-1)[flat_mask]
loss = F.cross_entropy(flat_logits, flat_targets)
# Embedding for online probes: mean of unmasked encoded tokens.
unmasked = (1 - mask).unsqueeze(-1)
embedding = (encoded * unmasked).sum(dim=1) / unmasked.sum(dim=1).clamp(min=1.0)
return BEiTOutput(
loss=loss,
embedding=embedding.detach(),
logits=logits,
targets=targets,
mask=mask,
)