import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Literal, Union
import timm
from timm.layers import DropPath, Mlp, trunc_normal_
from loguru import logger
from .patch_masking import PatchMasking
from dataclasses import dataclass
from transformers.utils import ModelOutput
from .pos_embed import (
get_sincos_pos_embed,
get_timestep_embed,
interpolate_pos_embed,
RotaryPositionEmbedding2D,
)
class SwiGLU(nn.Module):
"""SwiGLU: Gated Linear Unit with Swish activation.
A parameter-efficient gated activation that combines the benefits of
gating mechanisms with the smooth, non-monotonic Swish activation.
Empirically improves transformer performance over standard GeLU MLPs.
Architecture
------------
Standard MLP::
x → Linear → GeLU → Linear → out
Parameters: 2 * d * h
SwiGLU::
x → Linear(W₁) → SiLU ─┐
├─ element-wise multiply → Linear(W₃) → out
x → Linear(W₂) ────────┘
Parameters: 3 * d * h' (where h' = 2h/3 to match param count)
The hidden dimension is scaled to ``2/3 * hidden_features`` so that
total parameter count matches a standard 2-layer MLP:
``3 * d * (2h/3) = 2 * d * h``
Performance Benefits
--------------------
- **Better gradient flow**: Gating provides multiplicative paths
- **Smoother optimization**: SiLU (Swish) is smooth and non-monotonic
- **Quality**: Consistently outperforms GeLU in language and vision models
:param in_features: Input dimension
:param hidden_features: Nominal hidden dimension. Actual hidden size is
``int(2 * hidden_features / 3)`` to maintain parameter parity with
standard MLPs.
:param out_features: Output dimension. Defaults to ``in_features``.
:param bias: If True, use bias in linear layers. Default False following
LLaMA/PaLM convention for better training stability.
:param drop: Dropout probability applied after gating.
Example::
# Replace standard MLP in transformer
# Old: mlp = Mlp(768, 3072, 768)
# New:
mlp = SwiGLU(768, 3072, 768)
x = torch.randn(4, 196, 768)
out = mlp(x) # [4, 196, 768]
# Parameter count comparison
standard_mlp = nn.Sequential(
nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768)
)
swiglu = SwiGLU(768, 3072, 768)
print(sum(p.numel() for p in standard_mlp.parameters())) # 4,722,432 (weights + biases)
print(sum(p.numel() for p in swiglu.parameters())) # 4,718,592 (bias-free; matches the 2*d*h weight count)
Note:
For best results, combine SwiGLU with:
- **LayerScale**: Stabilizes residual connections
- **QK-Norm**: Prevents attention explosion
- **RoPE**: Better positional generalization
This combination is used in LLaMA, PaLM, and modern vision transformers.
References:
- Shazeer, "GLU Variants Improve Transformer" (2020)
- Touvron et al., "LLaMA: Open and Efficient Foundation Language Models" (2023)
"""
def __init__(
self,
in_features: int,
hidden_features: int,
out_features: Optional[int] = None,
bias: bool = False,
drop: float = 0.0,
):
super().__init__()
out_features = out_features or in_features
# Scale hidden dim to maintain param parity with standard MLP
hidden_features = int(2 * hidden_features / 3)
self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply SwiGLU transformation."""
# SwiGLU: (SiLU(xW₁) ⊙ xW₂) W₃
return self.drop(self.w3(F.silu(self.w1(x)) * self.w2(x)))
class QKNorm(nn.Module):
"""Query-Key Normalization for attention stabilization.
Applies LayerNorm (without learnable parameters) independently to query
and key tensors before computing attention scores. This simple technique
dramatically improves training stability in deep transformers.
Why QK-Norm Works
-----------------
In deep transformers, attention logits (Q·Kᵀ) can grow unboundedly large,
causing:
- **Gradient explosion**: Large logits → extreme softmax → tiny gradients
- **Attention collapse**: All mass on single token
- **Training instability**: Requires very small learning rates
QK-Norm bounds the attention logits by normalizing Q and K to unit variance:
- Attention logits stay bounded: after normalization ``||q|| ≈ ||k|| ≈ √d``, so ``|q·k| ≤ ||q|| ||k|| = O(d)`` and the scaled logits are ``O(√d)``
- Gradients remain stable throughout training
- Enables larger learning rates and faster convergence
Implementation Details
----------------------
- Uses **LayerNorm without learnable parameters** (γ=1, β=0)
- Normalizes per-head: applied to ``[..., head_dim]`` dimension
- Zero computational overhead in modern frameworks (fused with attention)
:param head_dim: Dimension per attention head. Each head is normalized
independently to preserve multi-head diversity.
Example::
# In attention forward pass
qk_norm = QKNorm(head_dim=64)
# q, k shape: [B, num_heads, seq_len, head_dim]
q, k = qk_norm(q, k)
# Now safe to compute attention
attn = (q @ k.transpose(-2, -1)) * scale
Example integration with Attention::
class Attention(nn.Module):
def __init__(self, dim, num_heads, use_qk_norm=True):
...
if use_qk_norm:
self.qk_norm = QKNorm(dim // num_heads)
def forward(self, x):
q, k, v = self.qkv(x).chunk(3, dim=-1)
if self.use_qk_norm:
q, k = self.qk_norm(q, k)
...
Note:
QK-Norm is especially important when combined with:
- **SwiGLU**: Gated activations can amplify hidden states
- **LayerScale**: Small initial residual scale needs stable attention
- **Deep networks**: Logit growth compounds with depth
Without QK-Norm, these combinations often fail to train or require
extensive hyperparameter tuning.
References:
- Henry et al., "Query-Key Normalization for Transformers" (EMNLP 2020)
- Dehghani et al., "Scaling Vision Transformers to 22B Parameters" (2023)
- Wortsman et al., "Small-scale proxies for large-scale Transformer training instabilities" (2023)
"""
def __init__(self, head_dim: int):
super().__init__()
# LayerNorm without learnable parameters (elementwise_affine=False)
self.q_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
self.k_norm = nn.LayerNorm(head_dim, elementwise_affine=False)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Normalize query and key tensors independently.
:param q: Query tensor of shape ``[B, num_heads, seq_len, head_dim]``
or any shape with last dimension = head_dim
:param k: Key tensor of same shape as q
:return: Tuple of (normalized_q, normalized_k) with same shapes
Note:
Normalization is applied to the last dimension (head_dim).
Each head is normalized independently, preserving multi-head
representation diversity.
"""
return self.q_norm(q), self.k_norm(k)
@dataclass
class MaskedEncoderOutput(ModelOutput):
"""Output from MaskedEncoder forward pass.
:ivar encoded: Encoded token representations (B, num_prefix + N_visible, D)
:ivar mask: Binary mask where 1 = masked, 0 = visible (B, N_patches)
:ivar ids_keep: Indices of visible patches (B, N_visible)
:ivar grid_size: Patch grid dimensions (height, width)
"""
encoded: Optional[torch.Tensor] = None
mask: Optional[torch.Tensor] = None
ids_keep: Optional[torch.Tensor] = None
grid_size: Optional[Tuple[int, int]] = None
class MaskedEncoder(nn.Module):
"""Vision Transformer encoder with optional masking support.
Wraps a timm ViT model and adds flexible masking via :class:`PatchMasking`.
Handles all ViT internals: patch embedding, positional embeddings, prefix
tokens (CLS, registers), and transformer blocks.
:param model_or_model_name: timm model name string or pre-instantiated nn.Module
:param masking: PatchMasking instance. If None, no masking is applied.
:param pretrained: Load pretrained weights (only when model_or_model_name is str)
:param img_size: Override default image size
:param patch_size: Override default patch size (will reinitialize patch_embed)
:param dynamic_img_size: Enable dynamic image size support with pos_embed interpolation
Example::
from spt.backbone import PatchMasking, MaskedEncoder
masking = PatchMasking(mask_ratio=0.75, block_size=4)
encoder = MaskedEncoder(
model_or_model_name="vit_base_patch16_224",
masking=masking,
pretrained=True,
)
images = torch.randn(4, 3, 224, 224)
output = encoder(images)
print(output.encoded.shape) # (4, 1 + 49, 768) with 75% masking
print(output.mask.shape) # (4, 196)
print(output.ids_keep.shape) # (4, 49)
"""
def __init__(
self,
model_or_model_name: Union[str, nn.Module] = "vit_base_patch16_224",
masking: Optional[PatchMasking] = None,
pretrained: bool = False,
img_size: Optional[Union[int, Tuple[int, int]]] = None,
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
dynamic_img_size: bool = False,
norm_layer: Optional[nn.Module] = None,
):
super().__init__()
self.dynamic_img_size = dynamic_img_size
self.masking = masking
# === Load or use provided encoder ===
if isinstance(model_or_model_name, str):
create_kwargs = {
"pretrained": pretrained,
"num_classes": 0,
"dynamic_img_size": dynamic_img_size,
}
if img_size is not None:
create_kwargs["img_size"] = img_size
if patch_size is not None:
create_kwargs["patch_size"] = patch_size
if pretrained:
logger.warning(
f"Changing patch_size to {patch_size} will reinitialize "
f"patch_embed weights. Pretrained weights won't fully apply."
)
if norm_layer is not None:
create_kwargs["norm_layer"] = norm_layer
self.vit = timm.create_model(model_or_model_name, **create_kwargs)
else:
logger.warning(
"MaskedEncoder received a pre-instantiated nn.Module. "
"Internals assume a timm ViT model with attributes such as "
"patch_embed, pos_embed, cls_token, blocks, norm, etc. "
"If you pass a non-timm module, unexpected errors may occur."
)
self.vit = model_or_model_name
if patch_size is not None:
self._rebuild_patch_embed(patch_size, img_size)
# Remove classification head if present
if hasattr(self.vit, "head") and hasattr(self.vit.head, "in_features"):
self.vit.head = nn.Identity()
# === Cache encoder properties ===
self.embed_dim = self.vit.embed_dim
self.patch_embed = self.vit.patch_embed
ps = self.patch_embed.patch_size
self.patch_size_h, self.patch_size_w = (ps, ps) if isinstance(ps, int) else ps
gs = self.patch_embed.grid_size
self.default_grid_h, self.default_grid_w = (
(gs, gs) if isinstance(gs, int) else gs
)
self.has_class_token = (
hasattr(self.vit, "cls_token") and self.vit.cls_token is not None
)
if hasattr(self.vit, "reg_token") and self.vit.reg_token is not None:
self.num_reg_tokens = self.vit.reg_token.shape[1]
else:
self.num_reg_tokens = getattr(self.vit, "num_reg_tokens", 0)
self.num_prefix_tokens = (
1 if self.has_class_token else 0
) + self.num_reg_tokens
self.no_embed_class = getattr(self.vit, "no_embed_class", False)
def _rebuild_patch_embed(
self,
patch_size: Union[int, Tuple[int, int]],
img_size: Optional[Union[int, Tuple[int, int]]] = None,
) -> None:
"""Rebuild patch embedding with new patch size."""
from timm.layers import PatchEmbed
old = self.vit.patch_embed
if img_size is None:
og, op = old.grid_size, old.patch_size
img_size = (
(og[0] * op[0], og[1] * op[1]) if isinstance(og, tuple) else og * op
)
self.vit.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=old.proj.in_channels,
embed_dim=old.proj.out_channels,
)
if old.num_patches != self.vit.patch_embed.num_patches:
self._resize_pos_embed(self.vit.patch_embed.grid_size)
def _resize_pos_embed(self, new_grid_size: Tuple[int, int]) -> None:
"""Resize positional embeddings to new grid size."""
old_pos = self.vit.pos_embed
if old_pos is None:
return
num_prefix = self.num_prefix_tokens if not self.no_embed_class else 0
src_patches = old_pos.shape[1] - num_prefix
src_size = int(src_patches**0.5)
new_pos = interpolate_pos_embed(
old_pos, (src_size, src_size), new_grid_size, num_prefix
)
self.vit.pos_embed = nn.Parameter(new_pos)
def _get_grid_size(self, images: torch.Tensor) -> Tuple[int, int]:
"""Compute patch grid size from image dimensions."""
H, W = images.shape[-2:]
return H // self.patch_size_h, W // self.patch_size_w
def _get_pos_embed(
self, grid_h: int, grid_w: int
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Get positional embeddings, interpolating if needed for dynamic size."""
pos_embed = self.vit.pos_embed
if pos_embed is None:
return None, None
num_prefix = self.num_prefix_tokens if not self.no_embed_class else 0
if self.dynamic_img_size and (
grid_h != self.default_grid_h or grid_w != self.default_grid_w
):
src_patches = pos_embed.shape[1] - num_prefix
src_size = int(src_patches**0.5)
pos_embed = interpolate_pos_embed(
pos_embed, (src_size, src_size), (grid_h, grid_w), num_prefix
)
if self.no_embed_class:
return None, pos_embed
return (
pos_embed[:, : self.num_prefix_tokens],
pos_embed[:, self.num_prefix_tokens :],
)
def _get_prefix_tokens(self, B: int) -> Optional[torch.Tensor]:
"""Get CLS and register tokens expanded to batch size."""
tokens = []
if self.has_class_token:
tokens.append(self.vit.cls_token.expand(B, -1, -1))
if self.num_reg_tokens > 0:
tokens.append(self.vit.reg_token.expand(B, -1, -1))
return torch.cat(tokens, dim=1) if tokens else None
def forward(self, images: torch.Tensor) -> MaskedEncoderOutput:
"""Encode images with optional masking.
:param images: Input images (B, C, H, W)
:return: MaskedEncoderOutput with encoded tokens and mask info
"""
B = images.shape[0]
device = images.device
grid_h, grid_w = self._get_grid_size(images)
num_patches = grid_h * grid_w
# Patch embed + positional embed
x = self.patch_embed(images)
if x.ndim == 4:
x = x.reshape(B, -1, x.shape[-1])
prefix_pos, patch_pos = self._get_pos_embed(grid_h, grid_w)
if patch_pos is not None:
x = x + patch_pos
# Apply masking (training only)
if self.training and self.masking is not None:
mask_out = self.masking(x, grid_h, grid_w)
x = mask_out.visible
mask = mask_out.mask
ids_keep = mask_out.ids_keep
else:
mask = torch.zeros(B, num_patches, device=device)
ids_keep = (
torch.arange(num_patches, device=device).unsqueeze(0).expand(B, -1)
)
# Prepend prefix tokens
prefix = self._get_prefix_tokens(B)
if prefix is not None:
if prefix_pos is not None and not self.no_embed_class:
prefix = prefix + prefix_pos
x = torch.cat([prefix, x], dim=1)
# Transformer blocks
x = self.vit.pos_drop(x)
blocks = self.vit.blocks if hasattr(self.vit, "blocks") else self.vit.layers
if isinstance(blocks, nn.ModuleList):
for blk in blocks:
x = blk(x)
else:
x = blocks(x)
x = self.vit.norm(x)
return MaskedEncoderOutput(
encoded=x,
mask=mask,
ids_keep=ids_keep,
grid_size=(grid_h, grid_w),
)
def forward_features(self, images: torch.Tensor) -> torch.Tensor:
"""Encode without masking (for inference)."""
was_training = self.training
self.eval()
with torch.no_grad():
output = self.forward(images)
if was_training:
self.train()
return output.encoded
class Attention(nn.Module):
"""Multi-head self-attention with efficient SDPA backend.
Supports modern transformer features including Rotary Position Embeddings
(RoPE) and Query-Key Normalization (QK-Norm) for improved training stability
and positional generalization.
Uses ``F.scaled_dot_product_attention`` which automatically selects the
optimal backend:
- **Flash Attention** (when available, fastest)
- **Memory-efficient attention** (xformers-style)
- **Math fallback**
Architecture Features
---------------------
**RoPE (Rotary Position Embedding)**:
Encodes relative 2D positions via complex rotations applied to Q and K.
Unlike additive positional embeddings, RoPE:
- Naturally captures relative positions
- Generalizes to unseen sequence lengths
- Requires no extra parameters
Enable with ``use_rope=True``. Pass ``grid_size`` to ``forward()``; if omitted,
a square grid is inferred when the sequence length is a perfect square.
**QK-Norm (Query-Key Normalization)**:
Applies LayerNorm (without learnable params) to Q and K before
computing attention scores. Benefits:
- Prevents attention logit explosion in deep networks
- Stabilizes training without extra hyperparameter tuning
- Essential when combined with SwiGLU/LayerScale
Enable with ``use_qk_norm=True``.
Attention Masking
-----------------
Supports flexible attention patterns via ``attn_mask``:
- **Causal (autoregressive)**: ``torch.triu(ones, diagonal=1)``
- **Bidirectional**: ``attn_mask=None``
- **Block sparse**: Custom boolean masks
- **Leave-one-out**: ``torch.eye(N)`` (each token ignores itself)
Mask convention: ``True`` = blocked (cannot attend), ``False`` = allowed.
:param dim: Input/output embedding dimension
:param num_heads: Number of parallel attention heads. Must divide ``dim``.
:param qkv_bias: If True, add learnable bias to Q, K, V projections.
Default True following ViT convention.
:param attn_drop: Dropout probability on attention weights. Applied only
during training.
:param proj_drop: Dropout probability on output projection.
:param use_rope: Enable 2D Rotary Position Embedding. When True, position
information is encoded via rotation of Q/K inside attention rather than additive
embeddings. Pass ``grid_size`` to ``forward()`` (a square grid is inferred otherwise).
:param use_qk_norm: Enable Query-Key normalization. Applies LayerNorm
(without learnable parameters) to Q and K tensors before attention.
Recommended for deep networks or when using SwiGLU.
:param max_grid_size: Maximum spatial grid size for RoPE frequency cache.
Only used when ``use_rope=True``. Set to largest expected grid dimension.
Example::
# Standard attention
attn = Attention(dim=768, num_heads=12)
out = attn(x) # [B, N, 768]
# With RoPE for vision (requires grid_size)
attn = Attention(dim=768, num_heads=12, use_rope=True)
out = attn(x, grid_size=(14, 14))
# With QK-Norm for training stability
attn = Attention(dim=768, num_heads=12, use_qk_norm=True)
out = attn(x)
# Causal attention (autoregressive)
N = x.shape[1]
causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
out = attn(x, attn_mask=causal_mask)
# NEPA-style: RoPE + QK-Norm + causal
attn = Attention(dim=768, num_heads=12, use_rope=True, use_qk_norm=True)
causal_mask = torch.triu(
torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1
)
out = attn(x, attn_mask=causal_mask, grid_size=(14, 14))
Note:
When ``use_rope=True``, do NOT add positional embeddings to input tokens.
RoPE encodes positions internally via Q/K rotation.
References:
- RoPE: Su et al., "RoFormer: Enhanced Transformer with Rotary
Position Embedding" (2021)
- QK-Norm: Henry et al., "Query-Key Normalization for Transformers" (2020)
- Flash Attention: Dao et al., "FlashAttention: Fast and Memory-Efficient
Exact Attention" (2022)
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
use_rope: bool = False,
use_qk_norm: bool = False,
max_grid_size: int = 32,
):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"dim ({dim}) must be divisible by num_heads ({num_heads})"
)
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.use_rope = use_rope
self.use_qk_norm = use_qk_norm
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
if use_rope:
self.rope = RotaryPositionEmbedding2D(self.head_dim, max_grid_size)
if use_qk_norm:
self.qk_norm = QKNorm(self.head_dim)
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
grid_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
"""Compute multi-head self-attention.
:param x: Input tensor of shape ``[B, N, D]`` where B is batch size,
N is sequence length, and D is embedding dimension.
:param attn_mask: Optional attention mask. Supported shapes:
- ``[N, N]``: Same mask for all batches and heads
- ``[B, N, N]``: Per-batch mask, broadcast over heads
- ``[B, H, N, N]``: Full per-batch, per-head mask
Values: ``True`` = blocked (cannot attend), ``False`` = allowed.
Common patterns:
- Causal: ``torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)``
- Leave-one-out: ``torch.eye(N, dtype=torch.bool)``
:param grid_size: Spatial grid dimensions as ``(height, width)``. Used to
compute 2D rotary position embeddings when ``use_rope=True``; if omitted, a
square grid is inferred from the sequence length. For a 224x224 image with
patch_size=16, use ``grid_size=(14, 14)``.
:return: Output tensor of shape ``[B, N, D]``
:raises ValueError: If ``use_rope=True``, ``grid_size`` is None, and the
sequence length is not a perfect square
"""
B, N, C = x.shape
# Fused QKV projection
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim]
# Apply QK-Norm before RoPE (order matters for stability)
if self.use_qk_norm:
q, k = self.qk_norm(q, k)
# Apply 2D Rotary Position Embedding
if self.use_rope:
if grid_size is None:
# Attempt to infer square grid
grid_h = grid_w = int(N**0.5)
if grid_h * grid_w != N:
raise ValueError(
f"use_rope=True requires grid_size parameter for non-square "
f"sequences. Got N={N} which is not a perfect square."
)
else:
grid_h, grid_w = grid_size
q, k = self.rope(q, k, grid_h, grid_w)
# Convert boolean mask to SDPA-compatible float mask
attn_mask_sdpa = self._prepare_attn_mask(attn_mask, q.dtype, q.device)
# Efficient attention via SDPA
x = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask_sdpa,
dropout_p=self.attn_drop if self.training else 0.0,
)
# Reshape and project output
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _prepare_attn_mask(
self,
attn_mask: Optional[torch.Tensor],
dtype: torch.dtype,
device: torch.device,
) -> Optional[torch.Tensor]:
"""Convert boolean attention mask to SDPA-compatible float mask.
SDPA expects additive mask where ``-inf`` blocks attention.
Our convention: ``True`` = blocked, ``False`` = allowed.
"""
if attn_mask is None:
return None
if attn_mask.dtype == torch.bool:
mask = torch.zeros_like(attn_mask, dtype=dtype, device=device)
mask = mask.masked_fill(attn_mask, float("-inf"))
else:
mask = attn_mask.to(dtype=dtype, device=device)
# Expand to [B, H, N, N] for SDPA broadcasting
while mask.dim() < 4:
mask = mask.unsqueeze(0 if mask.dim() == 2 else 1)
return mask
class CrossAttention(nn.Module):
"""Multi-head cross-attention with efficient SDPA backend.
Queries attend to key-value pairs from a separate context sequence.
Supports attention masking to block specific query-key interactions.
:param dim: Query dimension
:param context_dim: Context dimension (defaults to dim)
:param num_heads: Number of attention heads
:param qkv_bias: Add bias to projections
:param attn_drop: Attention dropout rate
:param proj_drop: Output projection dropout rate
Example::
cross_attn = CrossAttention(dim=768, context_dim=1024, num_heads=12)
# Standard cross-attention
out = cross_attn(queries, context) # [B, N, 768]
# Masked cross-attention: block certain query-context pairs
# mask[i, j] = True means query i cannot attend to context j
mask = torch.zeros(N, M, dtype=torch.bool)
mask[:, :10] = True # block attention to first 10 context tokens
out = cross_attn(queries, context, attn_mask=mask)
"""
def __init__(
self,
dim: int,
context_dim: Optional[int] = None,
num_heads: int = 8,
qkv_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
):
super().__init__()
context_dim = context_dim or dim
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
x: torch.Tensor,
context: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass.
:param x: Query tensor [B, N, D]
:param context: Key-value tensor [B, M, context_dim]
:param attn_mask: Cross-attention mask. Can be one of:
- [N, M]: Same mask for all batches and heads
- [B, N, M]: Per-batch mask, broadcast over heads
- [B, H, N, M]: Full per-batch, per-head mask
Mask values: True = blocked (cannot attend), False = allowed.
:return: Output tensor [B, N, D]
"""
B, N, C = x.shape
M = context.shape[1]
# Query projection: [B, N, D] -> [B, H, N, head_dim]
q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
# KV projection: [B, M, D] -> [B, H, M, head_dim] x2
kv = (
self.kv(context)
.reshape(B, M, 2, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
k, v = kv.unbind(0)
# Convert mask to SDPA format if provided
attn_mask_sdpa = self._prepare_attn_mask(attn_mask, q.dtype, q.device)
# Efficient attention
x = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask_sdpa,
dropout_p=self.attn_drop if self.training else 0.0,
)
# Reshape back
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _prepare_attn_mask(
self,
attn_mask: Optional[torch.Tensor],
dtype: torch.dtype,
device: torch.device,
) -> Optional[torch.Tensor]:
"""Convert boolean attention mask to SDPA-compatible float mask.
SDPA expects additive mask where -inf blocks attention.
Our convention: True = blocked, False = allowed.
:param attn_mask: Boolean mask [N, M], [B, N, M], or [B, H, N, M]
:param dtype: Target dtype for the mask
:param device: Target device for the mask
:return: Float mask suitable for SDPA, or None
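Example (illustrative sketch of the conversion; shapes chosen arbitrarily)::
blocked = torch.zeros(4, 7, dtype=torch.bool)
blocked[:, :2] = True  # queries may not attend to the first two context tokens
sdpa_mask = self._prepare_attn_mask(blocked, torch.float32, torch.device("cpu"))
# sdpa_mask.shape == (1, 1, 4, 7); blocked entries hold -inf, allowed entries hold 0.0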
"""
if attn_mask is None:
return None
# Convert bool mask to float: True -> -inf, False -> 0
if attn_mask.dtype == torch.bool:
mask = torch.zeros_like(attn_mask, dtype=dtype, device=device)
mask = mask.masked_fill(attn_mask, float("-inf"))
else:
mask = attn_mask.to(dtype=dtype, device=device)
# Expand dimensions for broadcasting: need [B, H, N, M] for SDPA
# [N, M] -> [1, 1, N, M]
# [B, N, M] -> [B, 1, N, M]
# [B, H, N, M] -> unchanged
while mask.dim() < 4:
mask = mask.unsqueeze(0 if mask.dim() == 2 else 1)
return mask
# =============================================================================
# Transformer Block
# =============================================================================
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Apply AdaLN modulation: x * (1 + scale) + shift."""
return x * (1 + scale) + shift
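# A minimal sketch of how ``modulate`` is typically wired into an AdaLN-style block
# (``adaln_proj`` and ``cond`` are hypothetical names, not defined in this module):
#   shift, scale = adaln_proj(cond).chunk(2, dim=-1)  # each [B, 1, D]
#   x = modulate(layer_norm(x), shift, scale)         # [B, N, D]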
class MAEDecoder(nn.Module):
"""MAE-style Vision Transformer Decoder using FlexibleTransformer.
Implements the decoder component of Masked Autoencoders (MAE) [1]_ for
self-supervised visual representation learning. The decoder reconstructs
masked patches from visible patch embeddings using joint self-attention,
where visible tokens and learnable mask tokens attend to each other.
The decoder is intentionally lightweight compared to the encoder, as MAE
demonstrates that a shallow decoder is sufficient for pixel reconstruction
while keeping the encoder focused on learning semantic representations.
Architecture Overview
---------------------
1. **Input projection**: Maps encoder embeddings (embed_dim) to decoder
dimension (decoder_embed_dim)
2. **Mask token expansion**: Learnable mask tokens are placed at masked
positions
3. **Positional encoding**: Adds position information to all tokens
4. **Transformer blocks**: Joint self-attention over visible + mask tokens
5. **Output projection**: Maps to output_dim (typically patch_size² × channels)
Parameters
----------
embed_dim : int, default=768
Embedding dimension from the encoder. This is the input dimension
of visible tokens passed to the decoder.
decoder_embed_dim : int, default=512
Internal hidden dimension of the decoder transformer blocks.
Typically smaller than embed_dim for efficiency.
output_dim : int, default=768
Output dimension per token. For pixel reconstruction, this should be
``patch_size ** 2 * in_channels`` (e.g., 16×16×3 = 768 for RGB).
num_patches : int, default=196
Total number of patches T in the image (e.g., 14×14 = 196 for
224×224 images with patch_size=16).
depth : int, default=4
Number of transformer blocks in the decoder. MAE typically uses
fewer blocks than the encoder (e.g., 4-8 vs 12-24).
num_heads : int, default=16
Number of attention heads in multi-head self-attention.
mlp_ratio : float, default=4.0
Expansion ratio for the MLP hidden dimension relative to
decoder_embed_dim.
pos_embed_type : {'sincos_1d', 'sincos_2d', 'learned'}, default='sincos_2d'
Type of positional embedding:
- 'sincos_2d': Fixed 2D sinusoidal (recommended for images)
- 'sincos_1d': Fixed 1D sinusoidal
- 'learned': Learnable positional embeddings
grid_size : int, optional
Spatial grid size for 2D positional embeddings. If None, inferred
as ``int(sqrt(num_patches))``. Required for non-square grids.
drop_path_rate : float, default=0.0
Stochastic depth rate for regularization during training.
Attributes
----------
mask_token : nn.Parameter
Learnable token of shape (1, 1, embed_dim) used to represent
masked positions. Initialized with truncated normal (std=0.02).
transformer : FlexibleTransformer
Core transformer module handling attention and projections.
Notes
-----
- The mask convention follows MAE: **0 = visible/kept, 1 = masked**
- The decoder receives visible tokens and reconstructs masked positions
- For efficiency, only masked positions are predicted by default
References
----------
.. [1] He, K., et al. "Masked Autoencoders Are Scalable Vision Learners."
CVPR 2022. https://arxiv.org/abs/2111.06377
Examples
--------
**Basic Usage with MAE Encoder**
>>> import torch
>>> import torch.nn as nn
>>>
>>> # Configuration matching ViT-Base
>>> B, T = 4, 196 # batch size, num_patches (14x14)
>>> embed_dim = 768 # encoder dimension
>>> mask_ratio = 0.75 # MAE default: mask 75% of patches
>>>
>>> # Initialize decoder
>>> decoder = MAEDecoder(
... embed_dim=embed_dim,
... decoder_embed_dim=512,
... output_dim=16 * 16 * 3, # patch_size² × channels = 768
... num_patches=T,
... depth=4,
... num_heads=16,
... )
>>>
>>> # Simulate encoder output (visible tokens only)
>>> N_vis = int(T * (1 - mask_ratio)) # 49 visible patches
>>> visible_tokens = torch.randn(B, N_vis, embed_dim)
>>>
>>> # Create random mask (0=visible, 1=masked)
>>> mask = torch.zeros(B, T)
>>> for i in range(B):
... masked_indices = torch.randperm(T)[: T - N_vis]
... mask[i, masked_indices] = 1
>>>
>>> # Decode - predict masked patches only
>>> pred_masked = decoder(visible_tokens, mask, output_masked_only=True)
>>> print(pred_masked.shape) # [B, N_mask, output_dim]
torch.Size([4, 147, 768])
**Full Sequence Reconstruction**
>>> # Get predictions for ALL positions (for visualization)
>>> pred_full = decoder(visible_tokens, mask, output_masked_only=False)
>>> print(pred_full.shape) # [B, T, output_dim]
torch.Size([4, 196, 768])
**Using Full Sequence Input**
If you have the full sequence with mask tokens already inserted:
>>> full_sequence = torch.randn(B, T, embed_dim) # [B, 196, 768]
>>> pred = decoder(full_sequence, mask, output_masked_only=True)
>>> print(pred.shape)
torch.Size([4, 147, 768])
**Integration with MAE Training Loop**
>>> # Typical MAE training step (pseudocode)
>>> def mae_forward(encoder, decoder, images, mask_ratio=0.75):
... # Patchify and mask
... patches = patchify(images) # [B, T, patch_dim]
... mask = random_mask(B, T, mask_ratio) # [B, T], 0=keep, 1=mask
...
... # Encode visible patches only
... visible_patches = patches[~mask.bool()].reshape(B, -1, patch_dim)
... latent = encoder(visible_patches) # [B, N_vis, embed_dim]
...
... # Decode to predict masked patches
... pred = decoder(
... latent, mask, output_masked_only=True
... ) # [B, N_mask, output_dim]
...
... # Reconstruction loss on masked patches only
... target = patches[mask.bool()].reshape(B, -1, patch_dim)
... loss = F.mse_loss(pred, target)
... return loss
**Custom Configuration for ViT-Large**
>>> decoder_large = MAEDecoder(
... embed_dim=1024, # ViT-L encoder dim
... decoder_embed_dim=512, # Keep decoder lightweight
... output_dim=768, # 16×16×3 pixels
... num_patches=256, # 16×16 patches for 256×256 images
... depth=8, # Slightly deeper
... num_heads=16,
... pos_embed_type="sincos_2d",
... drop_path_rate=0.1, # Regularization
... )
See Also
--------
FlexibleTransformer : Core transformer implementation used internally.
"""
def __init__(
self,
embed_dim: int = 768,
decoder_embed_dim: int = 512,
output_dim: int = 768,
num_patches: int = 196,
depth: int = 4,
num_heads: int = 16,
mlp_ratio: float = 4.0,
pos_embed_type: Literal["sincos_1d", "sincos_2d", "learned"] = "sincos_2d",
grid_size: Optional[int] = None,
drop_path_rate: float = 0.0,
):
super().__init__()
self.num_patches = num_patches
# Learnable mask token
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
trunc_normal_(self.mask_token, std=0.02)
# Core transformer
self.transformer = FlexibleTransformer(
input_dim=embed_dim,
hidden_dim=decoder_embed_dim,
output_dim=output_dim,
num_patches=num_patches,
depth=depth,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
self_attn=True,
cross_attn=False,
use_adaln=False,
pos_embed_type=pos_embed_type,
grid_size=grid_size,
drop_path_rate=drop_path_rate,
zero_init_output=False,
num_prefix_tokens=0,
)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
ids_keep: Optional[torch.Tensor] = None,
output_masked_only: bool = False,
) -> torch.Tensor:
"""Forward pass.
:param x: Visible tokens [B, N_vis, D] or full sequence [B, T, D]
:param mask: Binary mask [B, T], 0=kept, 1=masked
:param ids_keep: Indices of kept (visible) patches (B, N_keep)
:param output_masked_only: If True, return predictions for masked positions only,
[B, N_mask, output_dim]. If False, return the full sequence, [B, T, output_dim].
:return: Predictions
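Example (a minimal sketch of pairing with :class:`MaskedEncoder`; assumes the
encoder's prefix tokens are stripped first and that the decoder was built with
``embed_dim`` equal to the encoder's)::
enc_out = encoder(images)  # MaskedEncoderOutput
patch_tokens = enc_out.encoded[:, encoder.num_prefix_tokens :]
pred = decoder(
patch_tokens,
enc_out.mask,
ids_keep=enc_out.ids_keep,
output_masked_only=True,
)  # [B, N_mask, output_dim]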
"""
B, T = mask.shape
mask_bool = mask.bool() # Convert once, use everywhere
n_vis_per = (~mask_bool).sum(dim=1)
n_mask_per = mask_bool.sum(dim=1)
assert torch.all(n_vis_per == n_vis_per[0]), (
"Number of visible patches must be the same for all samples"
)
N_vis = int(n_vis_per[0].item())
N_mask = int(n_mask_per[0].item())
if N_mask == 0:
# visible idx is all patches in order
visible_idx = torch.arange(T, device=mask.device).unsqueeze(0).expand(B, -1)
visible_tokens = x  # no masked patches, so x already holds all T tokens
out = self.transformer(
context=visible_tokens,
queries=visible_tokens.new_empty(B, 0, visible_tokens.shape[-1]),
context_idx=visible_idx,
query_idx=visible_idx[:, :0],
return_all=True,
)
if output_masked_only:
return out[:, :0]
return out
# Get indices (stable sort puts False/0 before True/1, so visible indices come first)
order = torch.argsort(mask_bool.int(), dim=1, stable=True)
visible_idx = order[:, :N_vis]
masked_idx = order[:, N_vis:]
# Get visible tokens
if x.shape[1] == T:
visible_tokens = torch.gather(
x, dim=1, index=visible_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])
)
else:
# x already contains only the visible tokens
if ids_keep is not None:
visible_idx = ids_keep
visible_tokens = x
# Mask tokens for masked positions
mask_tokens = self.mask_token.expand(B, N_mask, -1)
return self.transformer(
context=visible_tokens,
queries=mask_tokens,
context_idx=visible_idx,
query_idx=masked_idx,
return_all=not output_masked_only,
)
class PositionalEncoding2D(nn.Module):
"""Flexible 2D positional encoding for vision transformers."""
def __init__(
self,
embed_dim: int,
grid_size: Tuple[int, int],
pos_type: Literal["learnable", "sinusoidal", "rope", "none"] = "learnable",
num_prefix_tokens: int = 1,
learnable: Optional[
bool
] = None, # Override: force learnable even for sinusoidal
):
"""Positional encoding for 2d input.
:param embed_dim: Embedding dimension
:param grid_size: (H, W) grid size in patches
:param pos_type: Type of positional encoding
:param num_prefix_tokens: Number of prefix tokens (CLS + registers)
:param learnable: If True, make sinusoidal learnable; if None, use default
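Example (a small illustrative sketch; shapes assumed)::
pos_enc = PositionalEncoding2D(
embed_dim=768, grid_size=(14, 14), pos_type="sinusoidal", num_prefix_tokens=1
)
x = torch.randn(2, 1 + 14 * 14, 768)  # [B, prefix + H*W, D]
x = pos_enc(x)  # same shape, with positions added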
"""
super().__init__()
self.embed_dim = embed_dim
self.grid_h, self.grid_w = grid_size
self.num_patches = self.grid_h * self.grid_w
self.pos_type = pos_type
self.num_prefix_tokens = num_prefix_tokens
# Override learnable if specified
if learnable is not None:
self.is_learnable = learnable
else:
self.is_learnable = pos_type == "learnable"
if pos_type == "none":
# No positional encoding
self.pos_embed = None
elif pos_type == "learnable":
# Learnable absolute positional embeddings
self.pos_embed = nn.Parameter(
torch.zeros(1, num_prefix_tokens + self.num_patches, embed_dim)
)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
elif pos_type == "sinusoidal":
# 2D sinusoidal positional embeddings
pos_embed = self._build_sinusoidal_2d(embed_dim, self.grid_h, self.grid_w)
# Add prefix token positions (zeros or learned separately)
prefix_pos = torch.zeros(1, num_prefix_tokens, embed_dim)
pos_embed = torch.cat([prefix_pos, pos_embed], dim=1)
if self.is_learnable:
self.pos_embed = nn.Parameter(pos_embed)
else:
self.register_buffer("pos_embed", pos_embed)
elif pos_type == "rope":
# RoPE doesn't use additive embeddings
self.pos_embed = None
# Precompute RoPE frequencies
self.register_buffer(
"freqs_h", self._build_rope_freqs(embed_dim // 4, self.grid_h)
)
self.register_buffer(
"freqs_w", self._build_rope_freqs(embed_dim // 4, self.grid_w)
)
else:
raise ValueError(f"Unknown pos_type: {pos_type}")
def _build_sinusoidal_2d(
self, embed_dim: int, grid_h: int, grid_w: int
) -> torch.Tensor:
"""Build 2D sinusoidal positional embeddings."""
assert embed_dim % 4 == 0, "embed_dim must be divisible by 4 for 2D sinusoidal"
dim_h = embed_dim // 2
dim_w = embed_dim // 2
# Height positions
pos_h = torch.arange(grid_h).unsqueeze(1) # [H, 1]
dim_t_h = torch.arange(0, dim_h, 2).float() # [dim_h/2]
omega_h = 1.0 / (10000 ** (dim_t_h / dim_h))
pos_embed_h = torch.zeros(grid_h, dim_h)
pos_embed_h[:, 0::2] = torch.sin(pos_h * omega_h)
pos_embed_h[:, 1::2] = torch.cos(pos_h * omega_h)
# Width positions
pos_w = torch.arange(grid_w).unsqueeze(1) # [W, 1]
dim_t_w = torch.arange(0, dim_w, 2).float()
omega_w = 1.0 / (10000 ** (dim_t_w / dim_w))
pos_embed_w = torch.zeros(grid_w, dim_w)
pos_embed_w[:, 0::2] = torch.sin(pos_w * omega_w)
pos_embed_w[:, 1::2] = torch.cos(pos_w * omega_w)
# Combine: [H, W, D]
pos_embed_h = pos_embed_h.unsqueeze(1).expand(-1, grid_w, -1) # [H, W, dim_h]
pos_embed_w = pos_embed_w.unsqueeze(0).expand(grid_h, -1, -1) # [H, W, dim_w]
pos_embed = torch.cat([pos_embed_h, pos_embed_w], dim=-1) # [H, W, D]
pos_embed = pos_embed.reshape(1, grid_h * grid_w, embed_dim) # [1, H*W, D]
return pos_embed
def _build_rope_freqs(
self, dim: int, max_seq_len: int, base: float = 10000.0
) -> torch.Tensor:
"""Build RoPE frequency tensor."""
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
pos = torch.arange(max_seq_len)
freqs = torch.einsum("i,j->ij", pos, inv_freq) # [seq_len, dim/2]
freqs = torch.cat([freqs, freqs], dim=-1) # [seq_len, dim]
return freqs
def _apply_rope_2d(self, x: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
"""Apply 2D RoPE to patch tokens."""
B, N, D = x.shape
# Separate prefix and patch tokens
prefix = x[:, : self.num_prefix_tokens, :]
patches = x[:, self.num_prefix_tokens :, :] # [B, H*W, D]
# Reshape to 2D grid
patches = patches.reshape(B, grid_h, grid_w, D)
# Split embedding into 4 parts for 2D RoPE
d_quarter = D // 4
x1, x2, x3, x4 = patches.split(d_quarter, dim=-1)
# Get frequencies (interpolate if needed)
freqs_h = self.freqs_h[:grid_h, :d_quarter] # [H, d_quarter]
freqs_w = self.freqs_w[:grid_w, :d_quarter] # [W, d_quarter]
# Apply rotation to height dimension (x1, x2)
cos_h = torch.cos(freqs_h).unsqueeze(1) # [H, 1, d_quarter]
sin_h = torch.sin(freqs_h).unsqueeze(1) # [H, 1, d_quarter]
x1_rot = x1 * cos_h - x2 * sin_h
x2_rot = x1 * sin_h + x2 * cos_h
# Apply rotation to width dimension (x3, x4)
cos_w = torch.cos(freqs_w).unsqueeze(0) # [1, W, d_quarter]
sin_w = torch.sin(freqs_w).unsqueeze(0) # [1, W, d_quarter]
x3_rot = x3 * cos_w - x4 * sin_w
x4_rot = x3 * sin_w + x4 * cos_w
# Combine
patches = torch.cat([x1_rot, x2_rot, x3_rot, x4_rot], dim=-1)
patches = patches.reshape(B, grid_h * grid_w, D)
# Recombine with prefix (prefix tokens don't get RoPE)
return torch.cat([prefix, patches], dim=1)
def forward(
self, x: torch.Tensor, grid_size: Optional[Tuple[int, int]] = None
) -> torch.Tensor:
"""Apply positional encoding.
:param x: [B, num_prefix + num_patches, D]
:param grid_size: (H, W) if different from default (for dynamic size)
:return: x with positional encoding applied
"""
if self.pos_type == "none":
return x
grid_h = grid_size[0] if grid_size else self.grid_h
grid_w = grid_size[1] if grid_size else self.grid_w
if self.pos_type == "rope":
return self._apply_rope_2d(x, grid_h, grid_w)
# Additive positional embeddings (learnable or sinusoidal)
pos_embed = self.pos_embed
# Interpolate if dynamic size
if grid_h != self.grid_h or grid_w != self.grid_w:
pos_embed = self._interpolate(pos_embed, grid_h, grid_w)
return x + pos_embed
def _interpolate(
self, pos_embed: torch.Tensor, target_h: int, target_w: int
) -> torch.Tensor:
"""Interpolate positional embeddings to new grid size."""
prefix_pos = pos_embed[:, : self.num_prefix_tokens, :]
patch_pos = pos_embed[:, self.num_prefix_tokens :, :]
D = patch_pos.shape[-1]
patch_pos = patch_pos.reshape(1, self.grid_h, self.grid_w, D).permute(
0, 3, 1, 2
)
patch_pos = F.interpolate(
patch_pos, size=(target_h, target_w), mode="bicubic", align_corners=False
)
patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, target_h * target_w, D)
return torch.cat([prefix_pos, patch_pos], dim=1)