Source code for stable_pretraining.backbone.vit
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Literal, Union
import timm
from timm.layers import DropPath, Mlp, trunc_normal_
from .patch_masking import PatchMasking
from dataclasses import dataclass
from .pos_embed import (
get_sincos_pos_embed,
get_1d_sincos_pos_embed,
get_2d_sincos_pos_embed,
get_timestep_embed,
interpolate_pos_embed,
)
@dataclass
class MaskedEncoderOutput:
"""Output from MaskedEncoder forward pass.
:ivar encoded: Encoded token representations (B, num_prefix + N_visible, D)
:ivar mask: Binary mask where 1 = masked, 0 = visible (B, N_patches)
:ivar ids_keep: Indices of visible patches (B, N_visible)
:ivar grid_size: Patch grid dimensions (height, width)
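    Example (a sketch, assuming ``out`` was produced by a :class:`MaskedEncoder` named ``encoder``)::
        patch_tokens = out.encoded[:, encoder.num_prefix_tokens:]  # (B, N_visible, D)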
"""
encoded: torch.Tensor
mask: torch.Tensor
ids_keep: torch.Tensor
grid_size: Tuple[int, int]
class MaskedEncoder(nn.Module):
"""Vision Transformer encoder with optional masking support.
Wraps a timm ViT model and adds flexible masking via :class:`PatchMasking`.
Handles all ViT internals: patch embedding, positional embeddings, prefix
tokens (CLS, registers), and transformer blocks.
:param model_or_model_name: timm model name string or pre-instantiated nn.Module
:param masking: PatchMasking instance. If None, no masking is applied.
:param pretrained: Load pretrained weights (only when model_or_model_name is str)
:param img_size: Override default image size
:param patch_size: Override default patch size (will reinitialize patch_embed)
:param dynamic_img_size: Enable dynamic image size support with pos_embed interpolation
Example::
        import torch
        from stable_pretraining.backbone import PatchMasking, MaskedEncoder
masking = PatchMasking(mask_ratio=0.75, block_size=4)
encoder = MaskedEncoder(
model_or_model_name="vit_base_patch16_224",
masking=masking,
pretrained=True,
)
images = torch.randn(4, 3, 224, 224)
output = encoder(images)
print(output.encoded.shape) # (4, 1 + 49, 768) with 75% masking
print(output.mask.shape) # (4, 196)
print(output.ids_keep.shape) # (4, 49)
"""
def __init__(
self,
model_or_model_name: Union[str, nn.Module] = "vit_base_patch16_224",
masking: Optional[PatchMasking] = None,
pretrained: bool = False,
img_size: Optional[Union[int, Tuple[int, int]]] = None,
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
dynamic_img_size: bool = False,
):
super().__init__()
self.dynamic_img_size = dynamic_img_size
self.masking = masking
# === Load or use provided encoder ===
if isinstance(model_or_model_name, str):
create_kwargs = {
"pretrained": pretrained,
"num_classes": 0,
"dynamic_img_size": dynamic_img_size,
}
if img_size is not None:
create_kwargs["img_size"] = img_size
if patch_size is not None:
create_kwargs["patch_size"] = patch_size
if pretrained:
print(
f"Warning: Changing patch_size to {patch_size} will reinitialize "
f"patch_embed weights. Pretrained weights won't fully apply."
)
self.vit = timm.create_model(model_or_model_name, **create_kwargs)
else:
self.vit = model_or_model_name
if patch_size is not None:
self._rebuild_patch_embed(patch_size, img_size)
# Remove classification head if present
if hasattr(self.vit, "head") and hasattr(self.vit.head, "in_features"):
self.vit.head = nn.Identity()
# === Cache encoder properties ===
self.embed_dim = self.vit.embed_dim
self.patch_embed = self.vit.patch_embed
ps = self.patch_embed.patch_size
self.patch_size_h, self.patch_size_w = (ps, ps) if isinstance(ps, int) else ps
gs = self.patch_embed.grid_size
self.default_grid_h, self.default_grid_w = (
(gs, gs) if isinstance(gs, int) else gs
)
self.num_prefix_tokens = getattr(self.vit, "num_prefix_tokens", 1)
self.has_class_token = getattr(self.vit, "has_class_token", True)
self.num_reg_tokens = getattr(self.vit, "num_reg_tokens", 0)
self.no_embed_class = getattr(self.vit, "no_embed_class", False)
def _rebuild_patch_embed(
self,
patch_size: Union[int, Tuple[int, int]],
img_size: Optional[Union[int, Tuple[int, int]]] = None,
) -> None:
"""Rebuild patch embedding with new patch size."""
from timm.layers import PatchEmbed
old = self.vit.patch_embed
if img_size is None:
og, op = old.grid_size, old.patch_size
img_size = (
(og[0] * op[0], og[1] * op[1]) if isinstance(og, tuple) else og * op
)
self.vit.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=old.proj.in_channels,
embed_dim=old.proj.out_channels,
)
if old.num_patches != self.vit.patch_embed.num_patches:
self._resize_pos_embed(self.vit.patch_embed.grid_size)
def _resize_pos_embed(self, new_grid_size: Tuple[int, int]) -> None:
"""Resize positional embeddings to new grid size."""
old_pos = self.vit.pos_embed
num_prefix = self.num_prefix_tokens if not self.no_embed_class else 0
src_patches = old_pos.shape[1] - num_prefix
src_size = int(src_patches**0.5)
new_pos = interpolate_pos_embed(
old_pos, (src_size, src_size), new_grid_size, num_prefix
)
self.vit.pos_embed = nn.Parameter(new_pos)
def _get_grid_size(self, images: torch.Tensor) -> Tuple[int, int]:
"""Compute patch grid size from image dimensions."""
H, W = images.shape[-2:]
return H // self.patch_size_h, W // self.patch_size_w
def _get_pos_embed(
self, grid_h: int, grid_w: int
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Get positional embeddings, interpolating if needed for dynamic size."""
pos_embed = self.vit.pos_embed
num_prefix = self.num_prefix_tokens if not self.no_embed_class else 0
if self.dynamic_img_size and (
grid_h != self.default_grid_h or grid_w != self.default_grid_w
):
src_patches = pos_embed.shape[1] - num_prefix
src_size = int(src_patches**0.5)
pos_embed = interpolate_pos_embed(
pos_embed, (src_size, src_size), (grid_h, grid_w), num_prefix
)
if self.no_embed_class:
return None, pos_embed
return (
pos_embed[:, : self.num_prefix_tokens],
pos_embed[:, self.num_prefix_tokens :],
)
def _get_prefix_tokens(self, B: int) -> Optional[torch.Tensor]:
"""Get CLS and register tokens expanded to batch size."""
tokens = []
if self.has_class_token:
tokens.append(self.vit.cls_token.expand(B, -1, -1))
if self.num_reg_tokens > 0:
tokens.append(self.vit.reg_token.expand(B, -1, -1))
return torch.cat(tokens, dim=1) if tokens else None
def forward(self, images: torch.Tensor) -> MaskedEncoderOutput:
"""Encode images with optional masking.
:param images: Input images (B, C, H, W)
:return: MaskedEncoderOutput with encoded tokens and mask info
"""
B = images.shape[0]
device = images.device
grid_h, grid_w = self._get_grid_size(images)
num_patches = grid_h * grid_w
# Patch embed + positional embed
x = self.patch_embed(images)
prefix_pos, patch_pos = self._get_pos_embed(grid_h, grid_w)
x = x + patch_pos
# Apply masking (training only)
if self.training and self.masking is not None:
mask_out = self.masking(x, grid_h, grid_w)
x = mask_out.visible
mask = mask_out.mask
ids_keep = mask_out.ids_keep
else:
mask = torch.zeros(B, num_patches, device=device)
ids_keep = (
torch.arange(num_patches, device=device).unsqueeze(0).expand(B, -1)
)
# Prepend prefix tokens
prefix = self._get_prefix_tokens(B)
if prefix is not None:
if prefix_pos is not None and not self.no_embed_class:
prefix = prefix + prefix_pos
x = torch.cat([prefix, x], dim=1)
# Transformer blocks
x = self.vit.pos_drop(x)
x = self.vit.blocks(x) if hasattr(self.vit, "blocks") else self.vit.layers(x)
x = self.vit.norm(x)
return MaskedEncoderOutput(
encoded=x,
mask=mask,
ids_keep=ids_keep,
grid_size=(grid_h, grid_w),
)
def forward_features(self, images: torch.Tensor) -> torch.Tensor:
"""Encode without masking (for inference)."""
was_training = self.training
self.eval()
with torch.no_grad():
output = self.forward(images)
if was_training:
self.train()
return output.encoded
def extra_repr(self) -> str:
return (
f"embed_dim={self.embed_dim}, "
f"patch_size=({self.patch_size_h}, {self.patch_size_w}), "
f"num_prefix_tokens={self.num_prefix_tokens}, "
f"has_masking={self.masking is not None}"
)
# =============================================================================
# Efficient Attention Modules
# =============================================================================
class Attention(nn.Module):
"""Multi-head self-attention with efficient SDPA backend.
Uses F.scaled_dot_product_attention which automatically selects:
- Flash Attention (when available, fastest)
- Memory-efficient attention (xformers-style)
- Math fallback
:param dim: Input dimension
:param num_heads: Number of attention heads
:param qkv_bias: Add bias to QKV projection
:param attn_drop: Attention dropout rate
:param proj_drop: Output projection dropout rate
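    Example (a minimal usage sketch; shapes are illustrative)::
        attn = Attention(dim=768, num_heads=12)
        x = torch.randn(2, 197, 768)
        out = attn(x)  # (2, 197, 768)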
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
# Fused QKV projection for efficiency
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
:param x: Input tensor [B, N, D]
:return: Output tensor [B, N, D]
"""
B, N, C = x.shape
# Fused QKV: [B, N, 3*D] -> [B, N, 3, H, head_dim] -> [3, B, H, N, head_dim]
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0)
# Efficient attention (Flash/Memory-efficient when available)
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop if self.training else 0.0,
)
# Reshape back: [B, H, N, head_dim] -> [B, N, D]
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class CrossAttention(nn.Module):
"""Multi-head cross-attention with efficient SDPA backend.
Queries attend to key-value pairs from a separate context sequence.
:param dim: Query dimension
:param context_dim: Context dimension (defaults to dim)
:param num_heads: Number of attention heads
:param qkv_bias: Add bias to projections
:param attn_drop: Attention dropout rate
:param proj_drop: Output projection dropout rate
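    Example (a minimal usage sketch; queries and context may have different lengths and widths)::
        cross = CrossAttention(dim=384, context_dim=768, num_heads=6)
        queries = torch.randn(2, 49, 384)
        context = torch.randn(2, 196, 768)
        out = cross(queries, context)  # (2, 49, 384)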
"""
def __init__(
self,
dim: int,
context_dim: Optional[int] = None,
num_heads: int = 8,
qkv_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
super().__init__()
context_dim = context_dim or dim
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
"""Forward pass.
:param x: Query tensor [B, N, D]
:param context: Key-value tensor [B, M, context_dim]
:return: Output tensor [B, N, D]
"""
B, N, C = x.shape
M = context.shape[1]
# Query projection: [B, N, D] -> [B, H, N, head_dim]
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
# KV projection: [B, M, D] -> [B, H, M, head_dim] x2
kv = (
self.kv(context)
.reshape(B, M, 2, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
k, v = kv.unbind(0)
# Efficient attention
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop if self.training else 0.0,
)
# Reshape back
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# =============================================================================
# Transformer Block
# =============================================================================
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Apply AdaLN modulation: x * (1 + scale) + shift."""
return x * (1 + scale) + shift
class TransformerBlock(nn.Module):
"""Unified transformer block with optional AdaLN-Zero conditioning.
Supports three attention configurations:
**Mode 1: Pure Cross-Attention** (`self_attn=False, cross_attn=True`)
- Queries attend to context but not to each other
- Use case: Lightweight decoder
**Mode 2: Decoder-Style** (`self_attn=True, cross_attn=True`)
- Self-attention on queries, then cross-attention to context
- Use case: Standard decoder (IJEPA predictor, etc.)
**Mode 3: Joint Attention** (`self_attn=True, cross_attn=False`)
- All tokens attend to all tokens (caller concatenates context + queries)
- Use case: Full bidirectional flow (DiT, high masking ratio)
**Conditioning:**
- `use_adaln=True`: AdaLN-Zero modulation (scale, shift, gate per operation)
- `use_adaln=False`: Standard pre-norm transformer
:param dim: Hidden dimension
:param num_heads: Number of attention heads
:param mlp_ratio: MLP hidden dim = dim * mlp_ratio
:param self_attn: Enable self-attention
:param cross_attn: Enable cross-attention
:param use_adaln: Enable AdaLN-Zero conditioning
:param drop_path: Stochastic depth rate
:param attn_drop: Attention dropout rate
:param proj_drop: Projection dropout rate
:param act_layer: Activation layer for MLP
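    Example (a minimal sketch of two of the modes; dimensions are illustrative)::
        # Mode 3: joint self-attention, no conditioning (MAE-decoder style)
        block = TransformerBlock(dim=384, num_heads=6, self_attn=True, cross_attn=False, use_adaln=False)
        x = torch.randn(2, 196, 384)
        out = block(x)  # (2, 196, 384)
        # Mode 2 with AdaLN-Zero: requires a context sequence and a conditioning vector
        block = TransformerBlock(dim=384, num_heads=6, self_attn=True, cross_attn=True, use_adaln=True)
        out = block(x, context=torch.randn(2, 49, 384), cond=torch.randn(2, 384))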
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
self_attn: bool = True,
cross_attn: bool = True,
use_adaln: bool = True,
drop_path: float = 0.0,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
act_layer: type = nn.GELU,
):
super().__init__()
self.use_self_attn = self_attn
self.use_cross_attn = cross_attn
self.use_adaln = use_adaln
if not self_attn and not cross_attn:
raise ValueError("At least one of self_attn or cross_attn must be True")
# Self-attention
if self_attn:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=not use_adaln)
self.attn = Attention(
dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
# Cross-attention
if cross_attn:
self.norm2_q = nn.LayerNorm(dim, elementwise_affine=not use_adaln)
self.norm2_kv = nn.LayerNorm(dim, elementwise_affine=not use_adaln)
self.cross_attn = CrossAttention(
dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=proj_drop
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
# MLP
self.norm3 = nn.LayerNorm(dim, elementwise_affine=not use_adaln)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
)
self.drop_path3 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
# AdaLN modulation MLP
if use_adaln:
# 3 params (shift, scale, gate) per operation
num_ops = int(self_attn) + int(cross_attn) + 1 # +1 for MLP
self.num_mods = num_ops * 3
self.adaLN_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, self.num_mods * dim),
)
# Zero-init for identity initialization
nn.init.zeros_(self.adaLN_mlp[1].weight)
nn.init.zeros_(self.adaLN_mlp[1].bias)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
cond: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass.
:param x: Input tensor [B, N, D]
:param context: Context for cross-attention [B, M, D] (required if cross_attn=True)
:param cond: Conditioning tensor [B, D] (required if use_adaln=True)
:return: Output tensor [B, N, D]
"""
if self.use_cross_attn and context is None:
raise ValueError("context required when cross_attn=True")
if self.use_adaln and cond is None:
raise ValueError("cond required when use_adaln=True")
if self.use_adaln:
# Get modulation parameters: [B, num_mods * D] -> list of [B, 1, D]
mods = self.adaLN_mlp(cond).chunk(self.num_mods, dim=-1)
mods = [m.unsqueeze(1) for m in mods]
i = 0
# Self-attention with AdaLN
if self.use_self_attn:
shift, scale, gate = mods[i], mods[i + 1], mods[i + 2]
i += 3
x = x + gate * self.drop_path1(
self.attn(modulate(self.norm1(x), shift, scale))
)
# Cross-attention with AdaLN
if self.use_cross_attn:
shift, scale, gate = mods[i], mods[i + 1], mods[i + 2]
i += 3
q = modulate(self.norm2_q(x), shift, scale)
kv = self.norm2_kv(context)
x = x + gate * self.drop_path2(self.cross_attn(q, kv))
# MLP with AdaLN
shift, scale, gate = mods[i], mods[i + 1], mods[i + 2]
x = x + gate * self.drop_path3(
self.mlp(modulate(self.norm3(x), shift, scale))
)
else:
# Standard pre-norm transformer (no conditioning)
if self.use_self_attn:
x = x + self.drop_path1(self.attn(self.norm1(x)))
if self.use_cross_attn:
x = x + self.drop_path2(
self.cross_attn(self.norm2_q(x), self.norm2_kv(context))
)
x = x + self.drop_path3(self.mlp(self.norm3(x)))
return x
class FlexibleTransformer(nn.Module):
"""Flexible transformer supporting multiple architectures.
Unified backbone for:
- **MAE decoder**: `self_attn=True, cross_attn=False, use_adaln=False`
- **IJEPA predictor**: `self_attn=True, cross_attn=True, use_adaln=False`
- **DiT / Flow**: `self_attn=True, cross_attn=True/False, use_adaln=True`
- **MaskGIT**: `self_attn=True, cross_attn=False, use_adaln=True, add_mask_token=True`
:param input_dim: Input embedding dimension (from encoder)
:param hidden_dim: Internal transformer dimension
:param output_dim: Output dimension
:param num_patches: Total number of patches (for positional embeddings)
:param depth: Number of transformer blocks
:param num_heads: Number of attention heads
:param mlp_ratio: MLP hidden dim multiplier
:param self_attn: Enable self-attention in blocks
:param cross_attn: Enable cross-attention in blocks
:param use_adaln: Enable AdaLN-Zero conditioning
:param pos_embed_type: 'sincos_1d', 'sincos_2d', or 'learned'
:param grid_size: Grid size for 2D positional embeddings
:param drop_path_rate: Stochastic depth rate (linearly increases through layers)
:param attn_drop: Attention dropout rate
:param proj_drop: Projection dropout rate
:param zero_init_output: Zero-initialize output projection
:param num_prefix_tokens: Number of prefix tokens (e.g., CLS token)
:param add_mask_token: Enable learnable [MASK] token for masked prediction.
When enabled, use `context_mask` and/or `query_mask` in forward() to
replace tokens at specified positions with the [MASK] token.
Example::
# MAE decoder
decoder = FlexibleTransformer(
768,
512,
768,
196,
depth=8,
self_attn=True,
cross_attn=False,
use_adaln=False,
)
out = decoder(context, queries, context_idx, query_idx)
# IJEPA predictor
predictor = FlexibleTransformer(
768,
384,
768,
196,
depth=6,
self_attn=True,
cross_attn=True,
use_adaln=False,
)
out = predictor(context, queries, context_idx, query_idx)
# DiT-style flow matching
flow = FlexibleTransformer(
768,
384,
768,
196,
depth=12,
self_attn=True,
cross_attn=False,
use_adaln=True,
)
out = flow(context, queries, context_idx, query_idx, t=timesteps)
# MaskGIT-style: variable number of masks per sample
maskgit = FlexibleTransformer(
768,
512,
768,
196,
depth=8,
self_attn=True,
cross_attn=False,
use_adaln=True,
add_mask_token=True,
)
# Each sample can have different number of masked positions
# context_mask[b, i] = True means replace context[b, i] with [MASK]
context_mask = torch.rand(B, num_patches) < mask_ratio # Variable per sample!
out = maskgit(
context=all_patches, # [B, 196, D]
queries=all_patches[:, :0], # [B, 0, D] empty
context_idx=torch.arange(196).expand(B, -1), # [B, 196]
query_idx=torch.empty(B, 0, dtype=torch.long),
context_mask=context_mask, # [B, 196] bool, variable True count
t=timesteps,
return_all=True,
) # Returns [B, 196, output_dim]
# BERT-style MLM: mask random tokens in sequence
bert = FlexibleTransformer(
768,
768,
768,
512,
depth=12,
self_attn=True,
cross_attn=False,
use_adaln=False,
add_mask_token=True,
)
# Random 15% masking, different positions per sample
context_mask = torch.rand(B, seq_len) < 0.15
out = bert(
context=token_embeddings,
queries=token_embeddings[:, :0],
context_idx=position_ids,
query_idx=torch.empty(B, 0, dtype=torch.long),
context_mask=context_mask,
return_all=True,
)
"""
def __init__(
self,
input_dim: int = 768,
hidden_dim: int = 384,
output_dim: int = 768,
num_patches: int = 196,
depth: int = 4,
num_heads: int = 6,
mlp_ratio: float = 4.0,
self_attn: bool = True,
cross_attn: bool = True,
use_adaln: bool = True,
pos_embed_type: Literal["sincos_1d", "sincos_2d", "learned"] = "sincos_2d",
grid_size: Optional[int | tuple[int, int]] = None,
drop_path_rate: float = 0.0,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
zero_init_output: bool = True,
num_prefix_tokens: int = 1,
add_mask_token: bool = False,
):
super().__init__()
if hidden_dim % num_heads != 0:
raise ValueError(
f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
)
self.hidden_dim = hidden_dim
self.num_prefix_tokens = num_prefix_tokens
self.use_cross_attn = cross_attn
self.use_adaln = use_adaln
self.add_mask_token = add_mask_token
# Input/output projections
self.context_proj = nn.Linear(input_dim, hidden_dim)
self.query_proj = nn.Linear(input_dim, hidden_dim)
self.output_proj = nn.Linear(hidden_dim, output_dim)
if zero_init_output:
nn.init.zeros_(self.output_proj.weight)
nn.init.zeros_(self.output_proj.bias)
# Positional embeddings
if pos_embed_type == "sincos_2d":
if grid_size is None:
grid_size = int(num_patches**0.5)
if grid_size**2 != num_patches:
raise ValueError(
f"num_patches ({num_patches}) must be a perfect square for sincos_2d"
)
pe = get_sincos_pos_embed(
hidden_dim, num_patches, mode="2d", grid_size=grid_size
)
self.register_buffer("pos_embed", pe.unsqueeze(0))
elif pos_embed_type == "sincos_1d":
pe = get_sincos_pos_embed(hidden_dim, num_patches, mode="1d")
self.register_buffer("pos_embed", pe.unsqueeze(0))
else: # learned
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_dim))
trunc_normal_(self.pos_embed, std=0.02)
# Prefix token positional embeddings
if num_prefix_tokens > 0:
self.prefix_pos_embed = nn.Parameter(
torch.zeros(1, num_prefix_tokens, hidden_dim)
)
nn.init.normal_(self.prefix_pos_embed, std=0.02)
# Learnable mask token (shared for context and query masking)
if add_mask_token:
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
nn.init.normal_(self.mask_token, std=0.02)
# Time embedding MLP (only needed for AdaLN)
if use_adaln:
self.time_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.GELU(),
nn.Linear(hidden_dim * 4, hidden_dim),
)
# Transformer blocks with linearly increasing drop path
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList(
[
TransformerBlock(
dim=hidden_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
self_attn=self_attn,
cross_attn=cross_attn,
use_adaln=use_adaln,
drop_path=dpr[i],
attn_drop=attn_drop,
proj_drop=proj_drop,
)
for i in range(depth)
]
)
self.final_norm = nn.LayerNorm(hidden_dim)
def _gather_pos(self, idx: torch.Tensor, num_prefix: int = 0) -> torch.Tensor:
"""Gather positional embeddings for given indices."""
B = idx.shape[0]
if num_prefix > 0:
prefix_pos = self.prefix_pos_embed.expand(B, -1, -1)
patch_idx = idx[:, num_prefix:]
patch_pos = torch.gather(
self.pos_embed.expand(B, -1, -1),
dim=1,
index=patch_idx.unsqueeze(-1).expand(-1, -1, self.hidden_dim),
)
return torch.cat([prefix_pos, patch_pos], dim=1)
else:
idx = idx.unsqueeze(-1).expand(-1, -1, self.hidden_dim)
return torch.gather(self.pos_embed.expand(B, -1, -1), 1, idx)
def forward(
self,
context: torch.Tensor,
queries: torch.Tensor,
context_idx: torch.Tensor,
query_idx: torch.Tensor,
t: Optional[torch.Tensor] = None,
num_prefix: Optional[int] = None,
return_all: bool = False,
context_mask: Optional[torch.Tensor] = None,
query_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass.
:param context: Context token embeddings [B, N_ctx, input_dim]
:param queries: Query token embeddings [B, N_qry, input_dim]
:param context_idx: Patch indices for context tokens [B, N_ctx]
:param query_idx: Patch indices for query tokens [B, N_qry]
:param t: Timesteps for conditioning [B] (required if use_adaln=True)
:param num_prefix: Override for number of prefix tokens in context
:param return_all: If True and using joint attention (cross_attn=False),
return all tokens unshuffled to original position order.
Output shape: [B, N_ctx + N_qry, output_dim].
Ignored for cross-attention modes.
:param context_mask: Boolean mask indicating which context tokens to replace
with [MASK] token [B, N_ctx]. True = replace with mask. Each sample can
have a different number of True values. Requires add_mask_token=True.
:param query_mask: Boolean mask indicating which query tokens to replace
with [MASK] token [B, N_qry]. True = replace with mask. Each sample can
have a different number of True values. Requires add_mask_token=True.
:return: Output embeddings. Shape depends on mode:
- cross_attn=True: [B, N_qry, output_dim]
- cross_attn=False, return_all=False: [B, N_qry, output_dim]
- cross_attn=False, return_all=True: [B, N_ctx + N_qry, output_dim].
"""
# Validate mask token usage
if context_mask is not None or query_mask is not None:
if not self.add_mask_token:
raise ValueError(
"context_mask or query_mask provided but "
"add_mask_token=False at initialization"
)
if num_prefix is None:
num_prefix = self.num_prefix_tokens
# Project context and optionally replace masked positions with [MASK] token
context = self.context_proj(context)
if context_mask is not None:
mask_tokens = self.mask_token.expand_as(context)
context = torch.where(context_mask.unsqueeze(-1), mask_tokens, context)
context = context + self._gather_pos(context_idx, num_prefix)
# Project queries and optionally replace masked positions with [MASK] token
queries = self.query_proj(queries)
if query_mask is not None:
mask_tokens = self.mask_token.expand_as(queries)
queries = torch.where(query_mask.unsqueeze(-1), mask_tokens, queries)
queries = queries + self._gather_pos(query_idx)
# Time conditioning (only for AdaLN mode)
cond = None
if self.use_adaln:
if t is None:
raise ValueError("Timestep t required when use_adaln=True")
cond = self.time_mlp(get_timestep_embed(t, self.hidden_dim))
n_context = context.shape[1]
n_queries = queries.shape[1]
if self.use_cross_attn:
# Cross-attention mode: queries attend to context
for block in self.blocks:
queries = block(queries, context=context, cond=cond)
return self.output_proj(self.final_norm(queries))
# Joint attention mode
x = torch.cat([context, queries], dim=1)
for block in self.blocks:
x = block(x, cond=cond)
x = self.final_norm(x)
if return_all:
# Unshuffle to original positions
B = context_idx.shape[0]
T = n_context + n_queries
out = torch.empty(B, T, self.hidden_dim, device=x.device, dtype=x.dtype)
out.scatter_(
dim=1,
index=context_idx.unsqueeze(-1).expand(-1, -1, self.hidden_dim),
src=x[:, :n_context],
)
out.scatter_(
dim=1,
index=query_idx.unsqueeze(-1).expand(-1, -1, self.hidden_dim),
src=x[:, n_context:],
)
return self.output_proj(out)
# Return only query outputs
if n_queries == 0:
B = context.shape[0]
return torch.empty(
B, 0, self.output_proj.out_features, device=x.device, dtype=x.dtype
)
return self.output_proj(x[:, -n_queries:])
class TransformerPredictor(nn.Module):
"""Lightweight transformer predictor using TransformerBlock.
A flexible predictor module commonly used in masked image modeling (e.g., MAE,
I-JEPA). Processes context tokens and optionally includes learnable register/query
tokens for aggregation.
:param input_dim: Dimension of input context tokens
:param hidden_dim: Internal dimension of transformer layers
:param output_dim: Dimension of output tokens
:param depth: Number of transformer layers
:param num_heads: Number of attention heads
:param num_registers: Number of learnable register/query tokens to prepend
:param mlp_ratio: MLP hidden dimension multiplier
:param drop_path_rate: Stochastic depth rate
:param pos_embed_type: Type of positional embedding (None, 'sincos_1d', 'sincos_2d', 'learned')
:param max_seq_len: Maximum sequence length (required if pos_embed_type='learned')
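    Example (a minimal sketch of an I-JEPA-style predictor over 49 visible patches of a 14x14 grid)::
        predictor = TransformerPredictor(
            input_dim=768, hidden_dim=384, output_dim=768, depth=6,
            num_heads=6, num_registers=1, pos_embed_type="sincos_2d",
        )
        context = torch.randn(4, 49, 768)                        # visible tokens from the encoder
        ids_keep = torch.arange(49).unsqueeze(0).expand(4, -1)   # their patch indices
        out = predictor(context, ids_keep=ids_keep, grid_size=(14, 14))
        print(out.shape)  # (4, 1 + 49, 768)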
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
depth: int,
num_heads: int = 6,
num_registers: int = 0,
mlp_ratio: float = 4.0,
drop_path_rate: float = 0.0,
pos_embed_type: Literal["sincos_1d", "sincos_2d", "learned"] | None = None,
max_seq_len: int | None = None,
):
super().__init__()
self.hidden_dim = hidden_dim
self.num_registers = num_registers
self.pos_embed_type = pos_embed_type
# Projections
self.input_proj = nn.Linear(input_dim, hidden_dim)
self.output_proj = nn.Linear(hidden_dim, output_dim)
# Register tokens
if num_registers > 0:
self.register_tokens = nn.Parameter(
torch.zeros(1, num_registers, hidden_dim)
)
self.register_pos_embed = nn.Parameter(
torch.zeros(1, num_registers, hidden_dim)
)
nn.init.normal_(self.register_tokens, std=0.02)
nn.init.normal_(self.register_pos_embed, std=0.02)
        # Learned positional embeddings (sincos variants are generated on-the-fly in _get_pos_embed)
if pos_embed_type == "learned":
assert max_seq_len is not None, "max_seq_len required for learned pos_embed"
self.pos_embed = nn.Parameter(torch.zeros(1, max_seq_len, hidden_dim))
nn.init.normal_(self.pos_embed, std=0.02)
# Transformer blocks
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList(
[
TransformerBlock(
dim=hidden_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
self_attn=True,
cross_attn=False,
use_adaln=False,
drop_path=dpr[i],
)
for i in range(depth)
]
)
self.norm = nn.LayerNorm(hidden_dim)
def _get_pos_embed(
self,
ids_keep: torch.Tensor,
grid_size: tuple[int, int] | None,
) -> torch.Tensor:
"""Gather or generate positional embeddings."""
B, N = ids_keep.shape
device, dtype = ids_keep.device, self.input_proj.weight.dtype
if self.pos_embed_type == "learned":
return torch.gather(
self.pos_embed.expand(B, -1, -1),
dim=1,
index=ids_keep.unsqueeze(-1).expand(-1, -1, self.hidden_dim),
)
# Generate sincos on-the-fly
if self.pos_embed_type == "sincos_1d":
max_pos = int(ids_keep.max().item()) + 1
pe = get_1d_sincos_pos_embed(self.hidden_dim, max_pos)
else: # sincos_2d
pe = get_2d_sincos_pos_embed(self.hidden_dim, grid_size)
pe = pe.to(device=device, dtype=dtype).unsqueeze(0).expand(B, -1, -1)
return torch.gather(
pe, 1, ids_keep.unsqueeze(-1).expand(-1, -1, self.hidden_dim)
)
def forward(
self,
context: torch.Tensor,
pos_embed: torch.Tensor | None = None,
ids_keep: torch.Tensor | None = None,
grid_size: tuple[int, int] | None = None,
) -> torch.Tensor:
"""Forward pass.
:param context: Context tokens [B, N, input_dim]
:param pos_embed: External positional embeddings [B, N, input_dim] (when pos_embed_type=None)
:param ids_keep: Indices of kept positions [B, N] (when pos_embed_type is not None)
:param grid_size: Grid size (H, W) for sincos_2d
:return: Output tokens [B, num_registers + N, output_dim]
"""
B = context.shape[0]
# Project to hidden dim
x = self.input_proj(context)
# Add positional embeddings
if self.pos_embed_type is not None:
x = x + self._get_pos_embed(ids_keep, grid_size)
elif pos_embed is not None:
x = x + self.input_proj(pos_embed)
# Prepend registers
if self.num_registers > 0:
registers = self.register_tokens.expand(B, -1, -1) + self.register_pos_embed
x = torch.cat([registers, x], dim=1)
# Transformer blocks
for block in self.blocks:
x = block(x)
return self.output_proj(self.norm(x))
class MAEDecoder(nn.Module):
"""MAE-style Vision Transformer Decoder using FlexibleTransformer.
Implements the decoder component of Masked Autoencoders (MAE) [1]_ for
self-supervised visual representation learning. The decoder reconstructs
masked patches from visible patch embeddings using joint self-attention,
where visible tokens and learnable mask tokens attend to each other.
The decoder is intentionally lightweight compared to the encoder, as MAE
demonstrates that a shallow decoder is sufficient for pixel reconstruction
while keeping the encoder focused on learning semantic representations.
Architecture Overview
---------------------
1. **Input projection**: Maps encoder embeddings (embed_dim) to decoder
dimension (decoder_embed_dim)
2. **Mask token expansion**: Learnable mask tokens are placed at masked
positions
3. **Positional encoding**: Adds position information to all tokens
4. **Transformer blocks**: Joint self-attention over visible + mask tokens
5. **Output projection**: Maps to output_dim (typically patch_size² × channels)
Parameters
----------
embed_dim : int, default=768
Embedding dimension from the encoder. This is the input dimension
of visible tokens passed to the decoder.
decoder_embed_dim : int, default=512
Internal hidden dimension of the decoder transformer blocks.
Typically smaller than embed_dim for efficiency.
output_dim : int, default=768
Output dimension per token. For pixel reconstruction, this should be
``patch_size ** 2 * in_channels`` (e.g., 16×16×3 = 768 for RGB).
num_patches : int, default=196
Total number of patches T in the image (e.g., 14×14 = 196 for
224×224 images with patch_size=16).
depth : int, default=4
Number of transformer blocks in the decoder. MAE typically uses
fewer blocks than the encoder (e.g., 4-8 vs 12-24).
num_heads : int, default=16
Number of attention heads in multi-head self-attention.
mlp_ratio : float, default=4.0
Expansion ratio for the MLP hidden dimension relative to
decoder_embed_dim.
pos_embed_type : {'sincos_1d', 'sincos_2d', 'learned'}, default='sincos_2d'
Type of positional embedding:
- 'sincos_2d': Fixed 2D sinusoidal (recommended for images)
- 'sincos_1d': Fixed 1D sinusoidal
- 'learned': Learnable positional embeddings
grid_size : int, optional
Spatial grid size for 2D positional embeddings. If None, inferred
as ``int(sqrt(num_patches))``. Required for non-square grids.
drop_path_rate : float, default=0.0
Stochastic depth rate for regularization during training.
    Attributes
    ----------
mask_token : nn.Parameter
Learnable token of shape (1, 1, embed_dim) used to represent
masked positions. Initialized with truncated normal (std=0.02).
transformer : FlexibleTransformer
Core transformer module handling attention and projections.
    Notes
    -----
- The mask convention follows MAE: **0 = visible/kept, 1 = masked**
- The decoder receives visible tokens and reconstructs masked positions
- For efficiency, only masked positions are predicted by default
    References
    ----------
.. [1] He, K., et al. "Masked Autoencoders Are Scalable Vision Learners."
CVPR 2022. https://arxiv.org/abs/2111.06377
    Examples
    --------
**Basic Usage with MAE Encoder**
>>> import torch
>>> import torch.nn as nn
>>>
>>> # Configuration matching ViT-Base
>>> B, T = 4, 196 # batch size, num_patches (14x14)
>>> embed_dim = 768 # encoder dimension
>>> mask_ratio = 0.75 # MAE default: mask 75% of patches
>>>
>>> # Initialize decoder
>>> decoder = MAEDecoder(
... embed_dim=embed_dim,
... decoder_embed_dim=512,
... output_dim=16 * 16 * 3, # patch_size² × channels = 768
... num_patches=T,
... depth=4,
... num_heads=16,
... )
>>>
>>> # Simulate encoder output (visible tokens only)
>>> N_vis = int(T * (1 - mask_ratio)) # 49 visible patches
>>> visible_tokens = torch.randn(B, N_vis, embed_dim)
>>>
>>> # Create random mask (0=visible, 1=masked)
>>> mask = torch.zeros(B, T)
>>> for i in range(B):
... masked_indices = torch.randperm(T)[: T - N_vis]
... mask[i, masked_indices] = 1
>>>
>>> # Decode - predict masked patches only
>>> pred_masked = decoder(visible_tokens, mask, output_masked_only=True)
>>> print(pred_masked.shape) # [B, N_mask, output_dim]
torch.Size([4, 147, 768])
**Full Sequence Reconstruction**
>>> # Get predictions for ALL positions (for visualization)
>>> pred_full = decoder(visible_tokens, mask, output_masked_only=False)
>>> print(pred_full.shape) # [B, T, output_dim]
torch.Size([4, 196, 768])
**Using Full Sequence Input**
    If you pass the full-length sequence, the decoder gathers the visible tokens itself (values at masked positions are ignored):
>>> full_sequence = torch.randn(B, T, embed_dim) # [B, 196, 768]
>>> pred = decoder(full_sequence, mask, output_masked_only=True)
>>> print(pred.shape)
torch.Size([4, 147, 768])
**Integration with MAE Training Loop**
>>> # Typical MAE training step (pseudocode)
>>> def mae_forward(encoder, decoder, images, mask_ratio=0.75):
... # Patchify and mask
... patches = patchify(images) # [B, T, patch_dim]
... mask = random_mask(B, T, mask_ratio) # [B, T], 0=keep, 1=mask
...
... # Encode visible patches only
... visible_patches = patches[~mask.bool()].reshape(B, -1, patch_dim)
... latent = encoder(visible_patches) # [B, N_vis, embed_dim]
...
... # Decode to predict masked patches
... pred = decoder(
... latent, mask, output_masked_only=True
... ) # [B, N_mask, output_dim]
...
... # Reconstruction loss on masked patches only
... target = patches[mask.bool()].reshape(B, -1, patch_dim)
... loss = F.mse_loss(pred, target)
... return loss
**Custom Configuration for ViT-Large**
>>> decoder_large = MAEDecoder(
... embed_dim=1024, # ViT-L encoder dim
... decoder_embed_dim=512, # Keep decoder lightweight
... output_dim=768, # 16×16×3 pixels
... num_patches=256, # 16×16 patches for 256×256 images
... depth=8, # Slightly deeper
... num_heads=16,
... pos_embed_type="sincos_2d",
... drop_path_rate=0.1, # Regularization
... )
    See Also
    --------
FlexibleTransformer : Core transformer implementation used internally.
"""
def __init__(
self,
embed_dim: int = 768,
decoder_embed_dim: int = 512,
output_dim: int = 768,
num_patches: int = 196,
depth: int = 4,
num_heads: int = 16,
mlp_ratio: float = 4.0,
pos_embed_type: Literal["sincos_1d", "sincos_2d", "learned"] = "sincos_2d",
grid_size: Optional[int] = None,
drop_path_rate: float = 0.0,
):
super().__init__()
self.num_patches = num_patches
# Learnable mask token
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
trunc_normal_(self.mask_token, std=0.02)
# Core transformer
self.transformer = FlexibleTransformer(
input_dim=embed_dim,
hidden_dim=decoder_embed_dim,
output_dim=output_dim,
num_patches=num_patches,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
self_attn=True,
cross_attn=False,
use_adaln=False,
pos_embed_type=pos_embed_type,
grid_size=grid_size,
drop_path_rate=drop_path_rate,
zero_init_output=False,
num_prefix_tokens=0,
)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
output_masked_only: bool = False,
) -> torch.Tensor:
"""Forward pass.
:param x: Visible tokens [B, N_vis, D] or full sequence [B, T, D]
:param mask: Binary mask [B, T], 0=kept, 1=masked
:param output_masked_only: If True, return [B, N_mask, D].
If False, return [B, T, D].
:return: Predictions
"""
B, T = mask.shape
mask_bool = mask.bool() # Convert once, use everywhere
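        # Assumes every sample in the batch has the same number of visible patches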
N_vis = (~mask_bool).sum(dim=1)[0].int().item()
N_mask = T - N_vis
# Get indices (sort False/0 before True/1, so visible indices come first)
visible_idx = torch.argsort(mask_bool.int(), dim=1, stable=True)[:, :N_vis]
masked_idx = torch.argsort((~mask_bool).int(), dim=1, stable=True)[:, :N_mask]
# Get visible tokens
if x.shape[1] == T:
visible_tokens = torch.gather(
x, dim=1, index=visible_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])
)
else:
visible_tokens = x
# Mask tokens for masked positions
mask_tokens = self.mask_token.expand(B, N_mask, -1)
return self.transformer(
context=visible_tokens,
queries=mask_tokens,
context_idx=visible_idx,
query_idx=masked_idx,
return_all=not output_masked_only,
)
class PositionalEncoding2D(nn.Module):
"""Flexible 2D positional encoding for vision transformers."""
def __init__(
self,
embed_dim: int,
grid_size: Tuple[int, int],
pos_type: Literal["learnable", "sinusoidal", "rope", "none"] = "learnable",
num_prefix_tokens: int = 1,
learnable: Optional[
bool
] = None, # Override: force learnable even for sinusoidal
):
"""Positional encoding for 2d input.
:param embed_dim: Embedding dimension
:param grid_size: (H, W) grid size in patches
:param pos_type: Type of positional encoding
:param num_prefix_tokens: Number of prefix tokens (CLS + registers)
:param learnable: If True, make sinusoidal learnable; if None, use default
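        Example (a minimal sketch for a 14x14 patch grid with a single CLS token)::
            pos = PositionalEncoding2D(768, (14, 14), pos_type="sinusoidal")
            tokens = torch.randn(2, 1 + 196, 768)
            tokens = pos(tokens)                  # additive sinusoidal embeddings
            pos_rope = PositionalEncoding2D(768, (14, 14), pos_type="rope")
            tokens = pos_rope(tokens)             # rotary encoding applied to patch tokens only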
"""
super().__init__()
self.embed_dim = embed_dim
self.grid_h, self.grid_w = grid_size
self.num_patches = self.grid_h * self.grid_w
self.pos_type = pos_type
self.num_prefix_tokens = num_prefix_tokens
# Override learnable if specified
if learnable is not None:
self.is_learnable = learnable
else:
self.is_learnable = pos_type == "learnable"
if pos_type == "none":
# No positional encoding
self.pos_embed = None
elif pos_type == "learnable":
# Learnable absolute positional embeddings
self.pos_embed = nn.Parameter(
torch.zeros(1, num_prefix_tokens + self.num_patches, embed_dim)
)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
elif pos_type == "sinusoidal":
# 2D sinusoidal positional embeddings
pos_embed = self._build_sinusoidal_2d(embed_dim, self.grid_h, self.grid_w)
# Add prefix token positions (zeros or learned separately)
prefix_pos = torch.zeros(1, num_prefix_tokens, embed_dim)
pos_embed = torch.cat([prefix_pos, pos_embed], dim=1)
if self.is_learnable:
self.pos_embed = nn.Parameter(pos_embed)
else:
self.register_buffer("pos_embed", pos_embed)
elif pos_type == "rope":
# RoPE doesn't use additive embeddings
self.pos_embed = None
# Precompute RoPE frequencies
self.register_buffer(
"freqs_h", self._build_rope_freqs(embed_dim // 4, self.grid_h)
)
self.register_buffer(
"freqs_w", self._build_rope_freqs(embed_dim // 4, self.grid_w)
)
else:
raise ValueError(f"Unknown pos_type: {pos_type}")
def _build_sinusoidal_2d(
self, embed_dim: int, grid_h: int, grid_w: int
) -> torch.Tensor:
"""Build 2D sinusoidal positional embeddings."""
assert embed_dim % 4 == 0, "embed_dim must be divisible by 4 for 2D sinusoidal"
dim_h = embed_dim // 2
dim_w = embed_dim // 2
# Height positions
pos_h = torch.arange(grid_h).unsqueeze(1) # [H, 1]
dim_t_h = torch.arange(0, dim_h, 2).float() # [dim_h/2]
omega_h = 1.0 / (10000 ** (dim_t_h / dim_h))
pos_embed_h = torch.zeros(grid_h, dim_h)
pos_embed_h[:, 0::2] = torch.sin(pos_h * omega_h)
pos_embed_h[:, 1::2] = torch.cos(pos_h * omega_h)
# Width positions
pos_w = torch.arange(grid_w).unsqueeze(1) # [W, 1]
dim_t_w = torch.arange(0, dim_w, 2).float()
omega_w = 1.0 / (10000 ** (dim_t_w / dim_w))
pos_embed_w = torch.zeros(grid_w, dim_w)
pos_embed_w[:, 0::2] = torch.sin(pos_w * omega_w)
pos_embed_w[:, 1::2] = torch.cos(pos_w * omega_w)
# Combine: [H, W, D]
pos_embed_h = pos_embed_h.unsqueeze(1).expand(-1, grid_w, -1) # [H, W, dim_h]
pos_embed_w = pos_embed_w.unsqueeze(0).expand(grid_h, -1, -1) # [H, W, dim_w]
pos_embed = torch.cat([pos_embed_h, pos_embed_w], dim=-1) # [H, W, D]
pos_embed = pos_embed.reshape(1, grid_h * grid_w, embed_dim) # [1, H*W, D]
return pos_embed
def _build_rope_freqs(
self, dim: int, max_seq_len: int, base: float = 10000.0
) -> torch.Tensor:
"""Build RoPE frequency tensor."""
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
pos = torch.arange(max_seq_len)
freqs = torch.einsum("i,j->ij", pos, inv_freq) # [seq_len, dim/2]
freqs = torch.cat([freqs, freqs], dim=-1) # [seq_len, dim]
return freqs
def _apply_rope_2d(self, x: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
"""Apply 2D RoPE to patch tokens."""
B, N, D = x.shape
# Separate prefix and patch tokens
prefix = x[:, : self.num_prefix_tokens, :]
patches = x[:, self.num_prefix_tokens :, :] # [B, H*W, D]
# Reshape to 2D grid
patches = patches.reshape(B, grid_h, grid_w, D)
# Split embedding into 4 parts for 2D RoPE
d_quarter = D // 4
x1, x2, x3, x4 = patches.split(d_quarter, dim=-1)
        # Slice precomputed frequencies to the current grid (no interpolation is performed)
freqs_h = self.freqs_h[:grid_h, :d_quarter] # [H, d_quarter]
freqs_w = self.freqs_w[:grid_w, :d_quarter] # [W, d_quarter]
# Apply rotation to height dimension (x1, x2)
cos_h = torch.cos(freqs_h).unsqueeze(1) # [H, 1, d_quarter]
sin_h = torch.sin(freqs_h).unsqueeze(1) # [H, 1, d_quarter]
x1_rot = x1 * cos_h - x2 * sin_h
x2_rot = x1 * sin_h + x2 * cos_h
# Apply rotation to width dimension (x3, x4)
cos_w = torch.cos(freqs_w).unsqueeze(0) # [1, W, d_quarter]
sin_w = torch.sin(freqs_w).unsqueeze(0) # [1, W, d_quarter]
x3_rot = x3 * cos_w - x4 * sin_w
x4_rot = x3 * sin_w + x4 * cos_w
# Combine
patches = torch.cat([x1_rot, x2_rot, x3_rot, x4_rot], dim=-1)
patches = patches.reshape(B, grid_h * grid_w, D)
# Recombine with prefix (prefix tokens don't get RoPE)
return torch.cat([prefix, patches], dim=1)
def forward(
self, x: torch.Tensor, grid_size: Optional[Tuple[int, int]] = None
) -> torch.Tensor:
"""Apply positional encoding.
:param x: [B, num_prefix + num_patches, D]
:param grid_size: (H, W) if different from default (for dynamic size)
:return: x with positional encoding applied
"""
if self.pos_type == "none":
return x
grid_h = grid_size[0] if grid_size else self.grid_h
grid_w = grid_size[1] if grid_size else self.grid_w
if self.pos_type == "rope":
return self._apply_rope_2d(x, grid_h, grid_w)
# Additive positional embeddings (learnable or sinusoidal)
pos_embed = self.pos_embed
# Interpolate if dynamic size
if grid_h != self.grid_h or grid_w != self.grid_w:
pos_embed = self._interpolate(pos_embed, grid_h, grid_w)
return x + pos_embed
def _interpolate(
self, pos_embed: torch.Tensor, target_h: int, target_w: int
) -> torch.Tensor:
"""Interpolate positional embeddings to new grid size."""
prefix_pos = pos_embed[:, : self.num_prefix_tokens, :]
patch_pos = pos_embed[:, self.num_prefix_tokens :, :]
D = patch_pos.shape[-1]
patch_pos = patch_pos.reshape(1, self.grid_h, self.grid_w, D).permute(
0, 3, 1, 2
)
patch_pos = F.interpolate(
patch_pos, size=(target_h, target_w), mode="bicubic", align_corners=False
)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, target_h * target_w, D)
return torch.cat([prefix_pos, patch_pos], dim=1)