Source code for stable_pretraining.backbone.pos_embed

"""Positional embedding utilities for vision transformers."""

import math
from typing import Literal

import torch
import torch.nn.functional as F

__all__ = [
    "get_sincos_pos_embed",
    "get_1d_sincos_pos_embed",
    "get_2d_sincos_pos_embed",
    "interpolate_pos_embed",
    "get_timestep_embed",
]


def get_timestep_embed(
    t: torch.Tensor,
    dim: int,
    max_period: int = 10000,
) -> torch.Tensor:
    """Generate sinusoidal embeddings for continuous timesteps.

    Unlike positional embeddings for sequences, this embeds scalar timestep
    values. Used for diffusion/flow matching time conditioning.

    :param t: Timestep values (B,) or (B, 1), typically in [0, 1]
    :param dim: Embedding dimension
    :param max_period: Maximum period for frequency scaling
    :return: Timestep embeddings of shape (B, dim)
    """
    t = t.view(-1).float()
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(half, device=t.device, dtype=t.dtype)
        / half
    )
    args = t[:, None] * freqs[None, :]
    embedding = torch.cat([args.cos(), args.sin()], dim=-1)
    if dim % 2:
        embedding = F.pad(embedding, (0, 1))
    return embedding
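

# Illustrative usage sketch (not part of the module API): the helper name
# `_demo_timestep_embed` is hypothetical and only shows the expected call
# pattern and output shape for diffusion-style time conditioning.
def _demo_timestep_embed() -> torch.Tensor:
    # Embed 8 timesteps sampled uniformly in [0, 1] into 256-dim vectors.
    t = torch.rand(8)
    emb = get_timestep_embed(t, dim=256)
    assert emb.shape == (8, 256)
    return emb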


def get_1d_sincos_pos_embed(
    embed_dim: int,
    length: int,
    cls_token: bool = False,
) -> torch.Tensor:
    """Generate 1D sinusoidal positional embeddings.

    :param embed_dim: Embedding dimension
    :param length: Sequence length (number of positions)
    :param cls_token: If True, prepend a zero embedding for CLS token
    :return: Positional embeddings of shape (length, embed_dim) or
        (length + 1, embed_dim) if cls_token=True
    """
    if embed_dim <= 0:
        raise ValueError(f"embed_dim must be positive, got {embed_dim}")
    if length <= 0:
        raise ValueError(f"length must be positive, got {length}")

    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    dim = torch.arange(0, embed_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (dim / embed_dim))

    pe = torch.zeros(length, embed_dim)
    pe[:, 0::2] = torch.sin(pos * inv_freq)
    pe[:, 1::2] = torch.cos(pos * inv_freq[: embed_dim // 2])

    if cls_token:
        pe = torch.cat([torch.zeros(1, embed_dim), pe], dim=0)
    return pe
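

# Illustrative usage sketch (not part of the module API): `_demo_1d_pos_embed`
# is a hypothetical helper showing a ViT-style table for 196 patches plus a
# CLS slot.
def _demo_1d_pos_embed() -> torch.Tensor:
    # 196 patch positions + 1 zero CLS embedding, 768-dim -> (197, 768).
    pe = get_1d_sincos_pos_embed(embed_dim=768, length=196, cls_token=True)
    assert pe.shape == (197, 768)
    return pe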


def get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: int | tuple[int, int],
    cls_token: bool = False,
) -> torch.Tensor:
    """Generate 2D sinusoidal positional embeddings for image patches.

    :param embed_dim: Embedding dimension (must be divisible by 4)
    :param grid_size: Grid height/width as int (square) or (height, width) tuple
    :param cls_token: If True, prepend a zero embedding for CLS token
    :return: Positional embeddings of shape (H*W, embed_dim) or
        (H*W + 1, embed_dim) if cls_token=True
    """
    if embed_dim <= 0 or embed_dim % 4 != 0:
        raise ValueError(
            f"embed_dim must be positive and divisible by 4, got {embed_dim}"
        )
    if isinstance(grid_size, int):
        grid_h = grid_w = grid_size
    else:
        grid_h, grid_w = grid_size
    if grid_h <= 0 or grid_w <= 0:
        raise ValueError(
            f"grid dimensions must be positive, got ({grid_h}, {grid_w})"
        )

    grid_y = torch.arange(grid_h, dtype=torch.float32)
    grid_x = torch.arange(grid_w, dtype=torch.float32)
    grid = torch.meshgrid(grid_y, grid_x, indexing="ij")
    grid = torch.stack(grid, dim=-1).reshape(-1, 2)

    dim = embed_dim // 4
    omega = torch.arange(dim, dtype=torch.float32) / dim
    omega = 1.0 / (10000**omega)

    out_h = grid[:, 0:1] @ omega.unsqueeze(0)
    out_w = grid[:, 1:2] @ omega.unsqueeze(0)

    pe = torch.cat(
        [torch.sin(out_h), torch.cos(out_h), torch.sin(out_w), torch.cos(out_w)],
        dim=1,
    )
    if cls_token:
        pe = torch.cat([torch.zeros(1, embed_dim), pe], dim=0)
    return pe
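

# Illustrative usage sketch (not part of the module API): `_demo_2d_pos_embed`
# is a hypothetical helper covering the 14x14 patch grid of a 224px, patch-16 ViT.
def _demo_2d_pos_embed() -> torch.Tensor:
    # Height and width each get embed_dim // 2 channels (sin + cos) -> (1 + 14*14, 768).
    pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=(14, 14), cls_token=True)
    assert pe.shape == (197, 768)
    return pe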


def get_sincos_pos_embed(
    embed_dim: int,
    num_patches: int,
    mode: Literal["1d", "2d"] = "1d",
    grid_size: int | tuple[int, int] | None = None,
    cls_token: bool = False,
) -> torch.Tensor:
    """Unified interface for generating sinusoidal positional embeddings.

    :param embed_dim: Embedding dimension
    :param num_patches: Total number of patches (used for 1d mode)
    :param mode: Embedding type - '1d' for sequence, '2d' for image grid
    :param grid_size: Required for '2d' mode
    :param cls_token: If True, prepend a zero embedding for CLS token
    :return: Positional embeddings tensor
    """
    if mode == "2d":
        if grid_size is None:
            raise ValueError("grid_size is required for 2d mode")
        return get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token)
    return get_1d_sincos_pos_embed(embed_dim, num_patches, cls_token)
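

# Illustrative usage sketch (not part of the module API): `_demo_unified_pos_embed`
# is a hypothetical helper showing how the '2d' branch dispatches on grid_size.
def _demo_unified_pos_embed() -> torch.Tensor:
    # num_patches is ignored in '2d' mode; the grid alone fixes the sequence length.
    pe = get_sincos_pos_embed(embed_dim=384, num_patches=196, mode="2d", grid_size=14)
    assert pe.shape == (196, 384)
    return pe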


def interpolate_pos_embed(
    pos_embed: torch.Tensor,
    src_size: tuple[int, int],
    tgt_size: tuple[int, int],
    num_prefix_tokens: int = 0,
    mode: str = "bicubic",
) -> torch.Tensor:
    """Interpolate positional embeddings to a new grid size.

    :param pos_embed: Original positional embeddings of shape
        (1, num_prefix + src_h*src_w, embed_dim) or
        (num_prefix + src_h*src_w, embed_dim)
    :param src_size: Source grid size as (height, width)
    :param tgt_size: Target grid size as (height, width)
    :param num_prefix_tokens: Number of prefix tokens (CLS, registers) to preserve
    :param mode: Interpolation mode ('nearest', 'bilinear', 'bicubic', 'area')
    :return: Interpolated positional embeddings

    Example::

        old_pos = model.pos_embed  # (1, 197, 768) = 1 + 14*14
        new_pos = interpolate_pos_embed(
            old_pos, src_size=(14, 14), tgt_size=(16, 16), num_prefix_tokens=1
        )  # (1, 257, 768) = 1 + 16*16
    """
    if pos_embed.dim() not in (2, 3):
        raise ValueError(f"pos_embed must be 2D or 3D, got {pos_embed.dim()}D")

    src_h, src_w = src_size
    tgt_h, tgt_w = tgt_size
    if src_h <= 0 or src_w <= 0 or tgt_h <= 0 or tgt_w <= 0:
        raise ValueError(
            f"All grid dims must be positive, src={src_size}, tgt={tgt_size}"
        )

    squeeze_output = False
    if pos_embed.dim() == 2:
        pos_embed = pos_embed.unsqueeze(0)
        squeeze_output = True

    expected_src_len = num_prefix_tokens + src_h * src_w
    if pos_embed.shape[1] != expected_src_len:
        raise ValueError(
            f"pos_embed length {pos_embed.shape[1]} doesn't match expected {expected_src_len}"
        )

    if src_h == tgt_h and src_w == tgt_w:
        return pos_embed.squeeze(0) if squeeze_output else pos_embed

    prefix_pos = pos_embed[:, :num_prefix_tokens, :]
    patch_pos = pos_embed[:, num_prefix_tokens:, :]

    embed_dim = patch_pos.shape[-1]
    patch_pos = patch_pos.reshape(1, src_h, src_w, embed_dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(
        patch_pos,
        size=(tgt_h, tgt_w),
        mode=mode,
        align_corners=False if mode in ("bilinear", "bicubic") else None,
    )
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, tgt_h * tgt_w, embed_dim)

    result = torch.cat([prefix_pos, patch_pos], dim=1)
    return result.squeeze(0) if squeeze_output else result
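

# Illustrative usage sketch (not part of the module API):
# `_demo_interpolate_pos_embed` is a hypothetical helper resizing a 2D sin-cos
# table from a 14x14 grid (224px, patch-16) to a 24x24 grid (384px, patch-16)
# with no prefix tokens, so the unbatched (H*W, embed_dim) path is exercised.
def _demo_interpolate_pos_embed() -> torch.Tensor:
    pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)  # (196, 768)
    pe_384 = interpolate_pos_embed(pe, src_size=(14, 14), tgt_size=(24, 24))
    assert pe_384.shape == (24 * 24, 768)
    return pe_384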