stable_pretraining.backbone package#
Submodules#
stable_pretraining.backbone.aggregator module#
Modular tensor aggregation module for feeding multi-scale/multi-layer features to MLPs.
Commonly used for:
SSL linear probes using multiple transformer layers
Multi-scale feature fusion
Combining features from different network stages
- class stable_pretraining.backbone.aggregator.TensorAggregator(input_spec: str | List[str] | Dict[str, str], adaptive_pool_size: int = 1)[source]#
Bases: Module
Aggregates multi-dimensional tensors into 2D format for MLP input.
Pure aggregation module with NO trainable parameters. Handles various input formats and aggregation strategies.
- Parameters:
input_spec – Specification of input format and aggregation modes:
str: Single aggregation mode for all tensors (e.g., "mean")
List[str]: Per-tensor aggregation modes for list inputs
Dict[str, str]: Per-key aggregation modes for dict inputs
adaptive_pool_size – Output size for adaptive pooling (default: 1)
- Aggregation Modes:
“mean”: Spatial/temporal mean pooling
“max”: Spatial/temporal max pooling
“cls”: Take first token (for transformers with [CLS] token)
“flatten”: Flatten all dimensions after batch
“adaptive”: Adaptive average pooling to fixed size
Examples
>>> # Single tensor with mean pooling
>>> agg = TensorAggregator("mean")
>>> x = torch.randn(4, 768, 14, 14)
>>> out = agg(x)  # Shape: (4, 768)
>>> # SSL: last 4 transformer layers with CLS token
>>> agg = TensorAggregator(["cls", "cls", "cls", "cls"])
>>> layers = [torch.randn(4, 197, 768) for _ in range(4)]
>>> out = agg(layers)  # Shape: (4, 3072) = 768 * 4
>>> # Multi-scale features
>>> agg = TensorAggregator({"layer1": "cls", "layer2": "mean", "conv": "mean"})
>>> out = agg(
...     {
...         "layer1": torch.randn(4, 197, 768),
...         "layer2": torch.randn(4, 197, 768),
...         "conv": torch.randn(4, 512, 14, 14),
...     }
... )  # Shape: (4, 2048)
- compute_output_dim(input_shapes: tuple | List[tuple] | Dict[str, tuple]) int[source]#
Compute the output dimension given input shapes.
- Parameters:
input_shapes – Shape(s) of input tensor(s) (excluding batch dim)
- Returns:
Total output features
Examples
>>> agg = TensorAggregator(["cls", "mean"])
>>> agg.compute_output_dim([(197, 768), (197, 768)])
1536
>>> agg = TensorAggregator({"l1": "cls", "conv": "mean"})
>>> agg.compute_output_dim({"l1": (197, 768), "conv": (512, 14, 14)})
1280
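The dimension arithmetic behind these examples can be sketched in plain Python. This is an illustrative sketch only, not the library's implementation; `sketch_output_dim` is a hypothetical helper, and the rules below are inferred from the documented examples ("cls" keeps the token embedding dim, "mean"/"max" keep the feature dim for (T, D) inputs and the channel dim for (C, H, W) inputs, "flatten" keeps everything).

```python
from math import prod

# Illustrative sketch (not the library's code) of how compute_output_dim
# could derive feature counts per aggregation mode. Shapes exclude the
# batch dimension, matching the examples above.
def sketch_output_dim(modes, shapes):
    total = 0
    for mode, shape in zip(modes, shapes):
        if mode == "cls":
            total += shape[-1]  # first token's embedding dim
        elif mode in ("mean", "max"):
            # (T, D) -> D for transformers; (C, H, W) -> C for convnets
            total += shape[-1] if len(shape) == 2 else shape[0]
        elif mode == "flatten":
            total += prod(shape)  # every non-batch element survives
        else:
            raise ValueError(f"unknown mode: {mode}")
    return total

print(sketch_output_dim(["cls", "mean"], [(197, 768), (197, 768)]))     # 1536
print(sketch_output_dim(["cls", "mean"], [(197, 768), (512, 14, 14)]))  # 1280
```

Both printed values match the doctest outputs above, which is the sanity check one would use before wiring the aggregator into an MLP head.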
stable_pretraining.backbone.convmixer module#
- class stable_pretraining.backbone.convmixer.ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6, kernel_size=9, patch_size=7)[source]#
Bases: Module
ConvMixer model.
A simple and efficient convolutional architecture that operates directly on patches.
- Parameters:
in_channels (int, optional) – Number of input channels. Defaults to 3.
num_classes (int, optional) – Number of output classes. Defaults to 10.
dim (int, optional) – Hidden dimension size. Defaults to 64.
depth (int, optional) – Number of ConvMixer blocks. Defaults to 6.
kernel_size (int, optional) – Kernel size for depthwise convolution. Defaults to 9.
patch_size (int, optional) – Patch embedding size. Defaults to 7.
Note
Introduced in [Trockman and Kolter, 2022].
- forward(xb)[source]#
Forward pass through the ConvMixer model.
- Parameters:
xb (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
- Returns:
Output logits of shape (batch_size, num_classes).
- Return type:
torch.Tensor
stable_pretraining.backbone.mlp module#
- class stable_pretraining.backbone.mlp.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#
Bases: Sequential
This block implements the multi-layer perceptron (MLP) module.
- Parameters:
in_channels (int) – Number of channels of the input
hidden_channels (List[int]) – List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None, this layer won't be used. Default: None
activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None, this layer won't be used. Default: torch.nn.ReLU
inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.
bias (bool) – Whether to use bias in the linear layer. Default: True
dropout (float) – The probability for the dropout layer. Default: 0.0
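The layer ordering implied by these parameters can be made concrete with a small sketch. This mirrors the torchvision-style MLP convention (Linear → Norm → Activation → Dropout for every hidden layer except the last, then Linear → Dropout for the final layer); `layer_plan` is a hypothetical helper, not part of the package.

```python
# Illustrative sketch of the layer sequence this MLP builds. Strings stand
# in for the actual torch.nn modules so the plan is easy to inspect.
def layer_plan(in_channels, hidden_channels, use_norm=False, dropout=0.0):
    plan, prev = [], in_channels
    for i, h in enumerate(hidden_channels):
        last = i == len(hidden_channels) - 1
        plan.append(f"Linear({prev}, {h})")
        if not last:
            if use_norm:
                plan.append(f"Norm({h})")
            plan.append("Activation")
        if dropout > 0:
            plan.append(f"Dropout(p={dropout})")
        prev = h
    return plan

print(layer_plan(64, [128, 10], use_norm=True, dropout=0.5))
```

Note that normalization and activation are skipped on the final layer, so the MLP's output is a plain affine map (plus dropout), which is the usual choice when the output feeds a loss.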
stable_pretraining.backbone.patch_masking module#
Patch masking strategies for masked image modeling.
- class stable_pretraining.backbone.patch_masking.IJEPAMaskOutput(context_idx: Tensor = None, target_idx: Tensor = None, target_block_masks: List[Tensor] = None, mask: Tensor = None)[source]#
Bases: ModelOutput
Output from I-JEPA masking operation.
- Variables:
context_idx – Indices of context (visible) patches [B, N_ctx]
target_idx – Combined indices of all target patches [B, N_tgt]
target_block_masks – Per-block boolean masks [M x [B, N]], True = in this block
mask – Full mask where 1 = target, 0 = context [B, N]
- class stable_pretraining.backbone.patch_masking.IJEPAMasking(num_targets: int = 4, target_scale: Tuple[float, float] = (0.15, 0.2), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), context_scale: Tuple[float, float] = (0.85, 1.0), allow_target_overlap: bool = False)[source]#
Bases: Module
I-JEPA multi-block masking for joint-embedding predictive architectures.
Samples M non-overlapping target blocks and a context region that excludes all targets. This is the key masking strategy from I-JEPA [1]. Strategy:
Sample M target blocks with specified scale and aspect ratio
Context = all patches NOT in any target block
Optionally subsample context to specified ratio
- Parameters:
num_targets – Number of target blocks to sample (default: 4)
target_scale – (min, max) fraction of patches per target block
target_aspect_ratio – (min, max) aspect ratio of target blocks
context_scale – (min, max) fraction of non-target patches to keep as context
allow_target_overlap – Allow target blocks to overlap (default: False)
- Example:
masking = IJEPAMasking(
    num_targets=4,
    target_scale=(0.15, 0.2),
    target_aspect_ratio=(0.75, 1.5),
    context_scale=(0.85, 1.0),
)
# x: patch embeddings [B, N, D]
output = masking(x, grid_h=14, grid_w=14)
context_patches = x.gather(
    1, output.context_idx.unsqueeze(-1).expand(-1, -1, D)
)
target_patches = x.gather(1, output.target_idx.unsqueeze(-1).expand(-1, -1, D))
References
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x: Tensor, grid_h: int, grid_w: int) IJEPAMaskOutput[source]#
Apply I-JEPA masking.
- Parameters:
x – Patch embeddings [B, N, D] where N = grid_h * grid_w
grid_h – Height of patch grid
grid_w – Width of patch grid
- Returns:
IJEPAMaskOutput with context/target information
Note
Always returns exactly num_targets block masks. If overlap prevention makes it impossible to fit all blocks, some masks will be empty (all False). The combined target_idx only includes patches from non-empty blocks.
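How a single target block's geometry follows from target_scale and target_aspect_ratio can be sketched in plain Python. This is an assumed illustration of the standard multi-block sampling recipe (block area from the scale range, height/width from the aspect-ratio range), not the library's exact code; `sample_block` is a hypothetical helper.

```python
import math
import random

# Illustrative sketch: sample one target block on a grid_h x grid_w patch
# grid. The block covers roughly scale * N patches with the drawn aspect
# ratio, clamped to fit inside the grid.
def sample_block(grid_h, grid_w, scale=(0.15, 0.2), aspect=(0.75, 1.5), rng=random):
    n = grid_h * grid_w
    area = rng.uniform(*scale) * n   # target number of patches in the block
    ar = rng.uniform(*aspect)        # height / width ratio
    h = max(1, min(grid_h, round(math.sqrt(area * ar))))
    w = max(1, min(grid_w, round(math.sqrt(area / ar))))
    top = rng.randint(0, grid_h - h)
    left = rng.randint(0, grid_w - w)
    return top, left, h, w

top, left, h, w = sample_block(14, 14)
print(h * w)  # roughly 0.15-0.2 of 196 patches, up to rounding and clamping
```

Sampling M such blocks and rejecting overlaps (when allow_target_overlap=False) is what can leave some blocks empty, as the note above explains.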
- class stable_pretraining.backbone.patch_masking.MaskingOutput(visible: Tensor = None, mask: Tensor = None, ids_restore: Tensor = None, ids_keep: Tensor = None)[source]#
Bases: ModelOutput
Output from patch masking operation.
- Variables:
visible – Visible patch embeddings (B, N_keep, D)
mask – Binary mask where 1 = masked, 0 = visible (B, N)
ids_restore – Indices to restore original order (B, N)
ids_keep – Indices of kept (visible) patches (B, N_keep)
- class stable_pretraining.backbone.patch_masking.MultiBlockMasking(num_targets: int = 4, context_scale: Tuple[float, float] = (0.85, 1.0), target_scale: Tuple[float, float] = (0.15, 0.2), context_aspect_ratio: Tuple[float, float] = (1.0, 1.0), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5))[source]#
Bases: Module
Multi-block masking for SALT Stage 1 (VPixel).
Generates one large context block and M target blocks using multi_block_mask(), then makes the context disjoint from the targets. Returns a MaskingOutput compatible with MaskedEncoder.
- Parameters:
num_targets – Number of target blocks (default: 4)
context_scale – (min, max) scale for context block (default: (0.85, 1.0))
target_scale – (min, max) scale for each target block (default: (0.15, 0.2))
context_aspect_ratio – (min, max) aspect ratio for context (default: (1.0, 1.0))
target_aspect_ratio – (min, max) aspect ratio for targets (default: (0.75, 1.5))
Example:
masking = MultiBlockMasking(num_targets=4)
output = masking(patch_embeddings, grid_h=14, grid_w=14)
visible_patches = output.visible  # (B, N_keep, D)
mask = output.mask  # (B, N), 1=masked, 0=visible
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x: Tensor, grid_h: int, grid_w: int) MaskingOutput[source]#
Apply multi-block masking to patch embeddings.
- Parameters:
x – Patch embeddings [B, N, D] where N = grid_h * grid_w
grid_h – Height of patch grid
grid_w – Width of patch grid
- Returns:
MaskingOutput with visible patches and mask info
- class stable_pretraining.backbone.patch_masking.PatchMasking(mask_ratio: float = 0.75, block_size: int = 1, crop_ratio: float = 0.0, crop_aspect_ratio: tuple[float, float] = (0.75, 1.33))[source]#
Bases: Module
Flexible patch masking module for masked image modeling.
Supports three masking strategies that are selected stochastically:
Random: Uniformly random patch selection (when block_size=1)
Block: Square blocks of adjacent patches (when block_size > 1)
Crop: Rectangular crop region, remaining patches masked (when crop_ratio > 0)
Strategy selection per sample:
With probability crop_ratio, use crop masking
Otherwise, if block_size > 1, use block masking
Otherwise, use random masking
- Parameters:
mask_ratio – Fraction of patches to mask, in [0, 1)
block_size – Size of square blocks for block masking (1 = random masking)
crop_ratio – Probability of using crop masking vs block/random
crop_aspect_ratio – (min, max) aspect ratio range for crop regions
Example:
masking = PatchMasking(mask_ratio=0.75, block_size=4)
output = masking(patch_embeddings, grid_h=14, grid_w=14)
visible_patches = output.visible  # (B, N_keep, D)
mask = output.mask  # (B, N), 1=masked, 0=visible
ids_keep = output.ids_keep  # (B, N_keep)
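The random-masking branch (block_size=1) can be sketched without torch. This is a minimal pure-Python illustration of the standard MAE-style recipe — shuffle patch indices by random noise, keep the first N_keep, and record ids_restore so the original order can be recovered — not the library's implementation; `random_masking` here is a hypothetical single-sample helper.

```python
import random

# Illustrative sketch of MAE-style random masking for one sample.
def random_masking(n_patches, mask_ratio, rng=random):
    n_keep = int(n_patches * (1 - mask_ratio))
    noise = [rng.random() for _ in range(n_patches)]
    ids_shuffle = sorted(range(n_patches), key=noise.__getitem__)
    ids_keep = ids_shuffle[:n_keep]          # visible patches
    ids_restore = [0] * n_patches            # inverse of the shuffle
    for rank, patch_idx in enumerate(ids_shuffle):
        ids_restore[patch_idx] = rank
    kept = set(ids_keep)
    mask = [0 if i in kept else 1 for i in range(n_patches)]  # 1 = masked
    return ids_keep, ids_restore, mask

ids_keep, ids_restore, mask = random_masking(196, 0.75)
print(len(ids_keep), sum(mask))  # 49 147
```

ids_restore is what lets a decoder scatter mask tokens back into their original positions after the encoder has only seen the visible patches.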
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x: Tensor, grid_h: int, grid_w: int) MaskingOutput[source]#
Apply masking to patch embeddings.
- Parameters:
x – Patch embeddings of shape (B, N, D) where N = grid_h * grid_w
grid_h – Height of the patch grid
grid_w – Width of the patch grid
- Returns:
MaskingOutput containing visible patches and mask information
- Raises:
ValueError – If x.shape[1] != grid_h * grid_w
ValueError – If input tensor has wrong number of dimensions
stable_pretraining.backbone.pos_embed module#
Positional embedding utilities for vision transformers.
- stable_pretraining.backbone.pos_embed.get_1d_sincos_pos_embed(embed_dim: int, length: int, cls_token: bool = False) Tensor[source]#
Generate 1D sinusoidal positional embeddings.
- Parameters:
embed_dim – Embedding dimension
length – Sequence length (number of positions)
cls_token – If True, prepend a zero embedding for CLS token
- Returns:
Positional embeddings of shape (length, embed_dim) or (length + 1, embed_dim) if cls_token=True
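The construction can be illustrated in plain Python. This sketch assumes the standard sin-cos layout (first half sine, second half cosine, frequencies 1/10000^(i/half)); the real function returns a Tensor, and lists are used here only for illustration.

```python
import math

# Illustrative pure-Python sketch of 1D sin-cos positional embeddings.
def sincos_1d(embed_dim, length):
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    return [
        [math.sin(pos * w) for w in omega] + [math.cos(pos * w) for w in omega]
        for pos in range(length)
    ]

emb = sincos_1d(8, 16)
print(len(emb), len(emb[0]))  # 16 8
print(emb[0][:4])  # position 0: all sine terms are 0.0
```

Position 0 always maps to zeros in the sine half and ones in the cosine half, a quick sanity check for any sin-cos implementation.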
- stable_pretraining.backbone.pos_embed.get_2d_sincos_pos_embed(embed_dim: int, grid_size: int | tuple[int, int], cls_token: bool = False) Tensor[source]#
Generate 2D sinusoidal positional embeddings for image patches.
- Parameters:
embed_dim – Embedding dimension (must be divisible by 4)
grid_size – Grid height/width as int (square) or (height, width) tuple
cls_token – If True, prepend a zero embedding for CLS token
- Returns:
Positional embeddings of shape (H*W, embed_dim) or (H*W + 1, embed_dim) if cls_token=True
- stable_pretraining.backbone.pos_embed.get_sincos_pos_embed(embed_dim: int, num_patches: int, mode: Literal['1d', '2d'] = '1d', grid_size: int | tuple[int, int] | None = None, cls_token: bool = False) Tensor[source]#
Unified interface for generating sinusoidal positional embeddings.
- Parameters:
embed_dim – Embedding dimension
num_patches – Total number of patches (used for 1d mode)
mode – Embedding type - ‘1d’ for sequence, ‘2d’ for image grid
grid_size – Required for ‘2d’ mode
cls_token – If True, prepend a zero embedding for CLS token
- Returns:
Positional embeddings tensor
- stable_pretraining.backbone.pos_embed.get_timestep_embed(t: Tensor, dim: int, max_period: int = 10000) Tensor[source]#
Generate sinusoidal embeddings for continuous timesteps.
Unlike positional embeddings for sequences, this embeds scalar timestep values. Used for diffusion/flow matching time conditioning.
- Parameters:
t – Timestep values (B,) or (B, 1), typically in [0, 1]
dim – Embedding dimension
max_period – Maximum period for frequency scaling
- Returns:
Timestep embeddings of shape (B, dim)
- stable_pretraining.backbone.pos_embed.interpolate_pos_embed(pos_embed: Tensor, src_size: tuple[int, int], tgt_size: tuple[int, int], num_prefix_tokens: int = 0, mode: str = 'bicubic') Tensor[source]#
Interpolate positional embeddings to a new grid size.
- Parameters:
pos_embed – Original positional embeddings of shape (1, num_prefix + src_h*src_w, embed_dim) or (num_prefix + src_h*src_w, embed_dim)
src_size – Source grid size as (height, width)
tgt_size – Target grid size as (height, width)
num_prefix_tokens – Number of prefix tokens (CLS, registers) to preserve
mode – Interpolation mode (‘nearest’, ‘bilinear’, ‘bicubic’, ‘area’)
- Returns:
Interpolated positional embeddings
Example:
old_pos = model.pos_embed  # (1, 197, 768) = 1 + 14*14
new_pos = interpolate_pos_embed(
    old_pos, src_size=(14, 14), tgt_size=(16, 16), num_prefix_tokens=1
)  # (1, 257, 768) = 1 + 16*16
stable_pretraining.backbone.probe module#
- class stable_pretraining.backbone.probe.AutoLinearClassifier(name, embedding_dim, num_classes, pooling=None, weight_decay=[0], lr_scaling=[1], normalization=['none', 'norm', 'bn'], dropout=[0, 0.5], label_smoothing=[0, 1])[source]#
Bases: Module
Linear classifier using either the CLS token or mean pooling, with a configurable normalization layer.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.
norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.
- norm#
Instantiated normalization layer, or None.
- Type:
nn.Module or None
- fc#
Linear layer mapping pooled representation to class logits.
- Type:
nn.Linear
- Forward Args:
- x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D).
If 3D, pooling and normalization are applied. If 2D, input is used directly (no pooling or normalization).
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = LinearProbe(
...     embedding_dim=128,
...     num_classes=10,
...     pooling="mean",
...     norm_layer=torch.nn.LayerNorm,
... )
>>> x = torch.randn(32, 20, 128)
>>> logits = probe(x)  # shape: (32, 10)
>>> x2 = torch.randn(32, 128)
>>> logits2 = probe(x2)  # shape: (32, 10)
- forward(x, y=None, pl_module=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Moduleinstance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.probe.AutoTuneMLP(in_features: int, out_features: int, hidden_features: List[int] | List[List[int]], name: str, loss_fn: Callable, additional_weight_decay: float | List[float] = [0], lr_scaling: float | List[float] = [1], normalization: str | List[str] = ['none'], dropout: float | List[float] = [0], activation: str | List[str] = ['relu'])[source]#
Bases: Module
Automatically creates multiple MLP variants with different hyperparameter combinations.
This module creates a grid of MLPs with different configurations (dropout, normalization, learning rates, architectures, etc.) to enable parallel hyperparameter tuning.
- Parameters:
in_features – Number of input features
out_features – Number of output features
hidden_features – Architecture specification. Can be:
List[int]: Single architecture, e.g., [256, 128]
List[List[int]]: Multiple architectures, e.g., [[256, 128], [512, 256, 128]]
[]: Empty list for a linear model (no hidden layers)
name – Base name for this AutoTuneMLP instance
loss_fn – Loss function to compute loss
additional_weight_decay – List of weight decay values to try
lr_scaling – List of learning rate scaling factors to try
normalization – List of normalization types [‘none’, ‘norm’, ‘bn’]
dropout – List of dropout rates to try
activation – List of activation functions [‘relu’, ‘leaky_relu’, ‘tanh’]
Examples
>>> # Single architecture
>>> model = AutoTuneMLP(128, 10, [256, 128], "clf", nn.CrossEntropyLoss())
>>> # Multiple architectures
>>> model = AutoTuneMLP(
...     128, 10, [[256], [256, 128], [512, 256]], "clf", nn.CrossEntropyLoss()
... )
>>> # Linear model (no hidden layers)
>>> model = AutoTuneMLP(128, 10, [], "linear_clf", nn.CrossEntropyLoss())
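The variant grid itself is just the Cartesian product of the hyperparameter lists. The sketch below illustrates that (the naming scheme is assumed for illustration; the library's actual variant IDs may differ):

```python
from itertools import product

# Illustrative sketch: each hyperparameter combination becomes one variant.
def variant_ids(name, weight_decays, lr_scales, norms, dropouts):
    return [
        f"{name}_wd{wd}_lr{lr}_{nm}_do{dp}"
        for wd, lr, nm, dp in product(weight_decays, lr_scales, norms, dropouts)
    ]

ids = variant_ids("clf", [0, 1e-4], [1], ["none", "bn"], [0, 0.5])
print(len(ids))  # 2 * 1 * 2 * 2 = 8
```

This is why the lists should be kept short: the number of trained heads grows multiplicatively with every extra option.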
- forward(x: Tensor, y: Tensor | None = None) Dict[str, Tensor][source]#
Forward pass through all MLP variants.
- Parameters:
x – Input tensor of shape (batch_size, in_features)
y – Optional target tensor for loss computation
- Returns:
Dictionary with predictions and losses for each variant Format: {‘pred/{variant_id}’: tensor, ‘loss/{variant_id}’: tensor}
- get_best_variant(metric_dict: Dict[str, float], lower_is_better: bool = True) str[source]#
Get the best performing variant based on metrics.
- Parameters:
metric_dict – Dictionary mapping variant_id to metric values
lower_is_better – If True, lower metric is better (e.g., loss). If False, higher is better (e.g., accuracy)
- Returns:
ID of the best performing variant
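The selection logic reduces to an argmin/argmax over the metric dictionary. A minimal sketch (`best_variant` is an illustrative stand-in for the method described above):

```python
# Minimal sketch of the described selection: smallest metric wins for
# losses, largest wins for accuracies.
def best_variant(metric_dict, lower_is_better=True):
    pick = min if lower_is_better else max
    return pick(metric_dict, key=metric_dict.get)

metrics = {"clf_a": 0.42, "clf_b": 0.31, "clf_c": 0.55}
print(best_variant(metrics))                         # clf_b (lowest loss)
print(best_variant(metrics, lower_is_better=False))  # clf_c (highest value)
```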
- get_variant(key: str) Module[source]#
Get a specific MLP variant by key.
- Parameters:
key – Variant ID
- Returns:
The MLP module
- Raises:
KeyError – If key doesn’t exist
- class stable_pretraining.backbone.probe.LinearProbe(embedding_dim, num_classes, pooling='cls', norm_layer=None)[source]#
Bases: Module
Linear probe using either the CLS token or mean pooling, with a configurable normalization layer.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.
norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.
- norm#
Instantiated normalization layer, or None.
- Type:
nn.Module or None
- fc#
Linear layer mapping pooled representation to class logits.
- Type:
nn.Linear
- Forward Args:
- x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D).
If 3D, pooling and normalization are applied. If 2D, input is used directly (no pooling or normalization).
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = LinearProbe(
...     embedding_dim=128,
...     num_classes=10,
...     pooling="mean",
...     norm_layer=torch.nn.LayerNorm,
... )
>>> x = torch.randn(32, 20, 128)
>>> logits = probe(x)  # shape: (32, 10)
>>> x2 = torch.randn(32, 128)
>>> logits2 = probe(x2)  # shape: (32, 10)
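The two pooling strategies are simple to state exactly. A pure-Python sketch for one sample (nested lists stand in for a (T, D) token matrix; `pool` is an illustrative helper, not part of the package):

```python
# "cls" takes the first token; "mean" averages over the sequence dimension.
def pool(tokens, mode):
    if mode == "cls":
        return tokens[0]
    d = len(tokens[0])
    return [sum(tok[j] for tok in tokens) / len(tokens) for j in range(d)]

tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # T=3, D=2
print(pool(tokens, "cls"))   # [1.0, 2.0]
print(pool(tokens, "mean"))  # [3.0, 4.0]
```

Either way the result is a single D-dimensional vector per sample, which the normalization layer (if any) and the final linear layer then map to class logits.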
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Moduleinstance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.probe.MultiHeadAttentiveProbe(embedding_dim: int, num_classes: int, num_heads: int = 4)[source]#
Bases: Module
A multi-head attentive probe for sequence representations.
This module applies multiple attention heads to a sequence of embeddings, pools the sequence into a fixed-size representation per head, concatenates the results, and projects to a set of output classes.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
num_heads (int) – Number of attention heads. Default: 4.
- ln#
Layer normalization applied to the input.
- Type:
torch.nn.LayerNorm
- attn_vectors#
Learnable attention vectors for each head, shape (num_heads, embedding_dim).
- Type:
torch.nn.Parameter
- fc#
Final linear layer mapping concatenated head outputs to class logits.
- Type:
torch.nn.Linear
- Forward Args:
- x (torch.Tensor): Input tensor of shape (N, T, D), where
N = batch size, T = sequence length, D = embedding_dim.
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = MultiHeadAttentiveProbe(
...     embedding_dim=128, num_classes=10, num_heads=4
... )
>>> x = torch.randn(32, 20, 128)  # batch of 32, sequence length 20
>>> logits = probe(x)  # shape: (32, 10)
- forward(x: Tensor)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Moduleinstance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
stable_pretraining.backbone.resnet9 module#
- class stable_pretraining.backbone.resnet9.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#
Bases: Sequential
This block implements the multi-layer perceptron (MLP) module.
- Parameters:
in_channels (int) – Number of channels of the input
hidden_channels (List[int]) – List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None, this layer won't be used. Default: None
activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None, this layer won't be used. Default: torch.nn.ReLU
inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.
bias (bool) – Whether to use bias in the linear layer. Default: True
dropout (float) – The probability for the dropout layer. Default: 0.0
- class stable_pretraining.backbone.resnet9.ResidualBlock(in_channels, out_channels, kernel_size, padding, stride)[source]#
Bases: Module
A residual block as defined by He et al.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Moduleinstance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.resnet9.Resnet9(num_classes, num_channels, *args, **kwargs)[source]#
Bases: Module
A residual network.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Moduleinstance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
stable_pretraining.backbone.utils module#
- class stable_pretraining.backbone.utils.EfficientMaskedTimmViT(vit: Module)[source]#
Bases: Module
Optimized Vision Transformer wrapper that efficiently handles NaN patches.
This module is designed to work with timm ViT models and provides:
Per-sample NaN masking (different NaN patterns per image in batch)
Fast path for the same masking pattern across the batch
Support for class tokens (cls_token), distillation tokens (dist_token), and register tokens
Compatibility with various timm ViT architectures (vit_*, deit_*, beit_*, etc.)
Minimal overhead when no masking is present
Key optimizations:
Early exit when no NaN patches are detected
Simpler indexing for identical masking patterns
Cached batch indices for repeated operations
Zero-copy operations where possible
- Parameters:
vit – A timm Vision Transformer model instance
- Raises:
ValueError – If samples have different numbers of NaN patches
ValueError – If all patches are NaN
RuntimeError – If the model structure is incompatible
Example
>>> import timm
>>> vit = timm.create_model(
...     "vit_base_patch16_224", pretrained=False, reg_tokens=4
... )
>>> masked_vit = EfficientMaskedTimmViT(vit)
>>>
>>> # Create input with some NaN patches
>>> x = torch.randn(4, 3, 224, 224)
>>> output = masked_vit(x)
- Performance:
Same pattern masking: ~0-5% overhead vs different patterns
No masking: <2% overhead vs original model
50% masking: ~1.5x speedup
90% masking: ~2.5-3x speedup
Note
All samples in a batch must have the same NUMBER of NaN patches, but the LOCATION of NaN patches can differ per sample.
Register tokens (DINOv2 style) do NOT receive positional embeddings.
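The equal-count constraint in the first note can be checked cheaply before the forward pass. The sketch below illustrates that check in plain Python (an assumed illustration, not the wrapper's code; nested lists stand in for a (B, N, D) tensor, and a patch counts as NaN if any of its values is NaN):

```python
import math

# Illustrative sketch of the constraint: every sample must have the same
# NUMBER of NaN patches, though their locations may differ per sample.
def nan_patch_count(batch):
    counts = [
        sum(any(math.isnan(v) for v in patch) for patch in sample)
        for sample in batch
    ]
    if len(set(counts)) != 1:
        raise ValueError("samples have different numbers of NaN patches")
    return counts[0]

nan = float("nan")
batch = [
    [[nan, 0.0], [1.0, 2.0], [3.0, 4.0]],  # sample 0 masks patch 0
    [[1.0, 2.0], [nan, nan], [3.0, 4.0]],  # sample 1 masks patch 1
]
print(nan_patch_count(batch))  # 1
```

Equal counts are what keep the visible-patch tensor rectangular after gathering, which is why the wrapper can use batched indexing instead of per-sample loops.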
- clear_cache()[source]#
Clear the cached batch indices.
Useful if you want to free memory after processing different batch sizes. The cache will be rebuilt as needed during forward passes.
- forward(x: Tensor) Tensor[source]#
Forward pass through the masked ViT.
This method implements an optimized forward pass with the following features:
Early exit for inputs without NaN patches (fast path)
Optimized indexing for the same masking pattern across the batch
Per-sample masking support with advanced indexing
Automatic NaN replacement for partial NaN patches
Support for register tokens (DINOv2 style)
- Parameters:
x – Input tensor, either:
Raw images: shape (B, C, H, W)
Pre-patchified: shape (B, N, D) where N is the number of patches
- Returns:
Model output (logits if head exists, features otherwise)
- Return type:
torch.Tensor
- Raises:
ValueError – If samples have different numbers of NaN patches
ValueError – If all patches are NaN
- Performance Notes:
No NaN patches: Uses fast path with <2% overhead
Same pattern: Optimized indexing, ~0-5% overhead vs different patterns
Different patterns: Uses advanced indexing, ~10-35% slower at high masking
- class stable_pretraining.backbone.utils.EmbeddingOutput(last_hidden_state: Any = None, hidden_states: dict[str, Tensor] = None)[source]#
Bases: ModelOutput
HuggingFace-style output container for model embeddings.
- Variables:
last_hidden_state (Any) – The final output from the backbone model.
hidden_states (dict[str, Tensor]) – Dictionary mapping layer names to their intermediate outputs.
- class stable_pretraining.backbone.utils.EvalOnly(backbone: Module)[source]#
Bases: Module
Wrapper that forces a module to remain in evaluation mode.
- forward(*args, **kwargs)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Moduleinstance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.utils.FeaturesConcat(agg: callable, names: str | Iterable[str] = None)[source]#
Bases: Module
Aggregates and concatenates features from a dictionary input, then classifies.
- Parameters:
names (List[str]) – Keys to extract from the input dictionary. If not given, everything in the dict/list is aggregated.
- forward(inputs: dict | Iterable)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Moduleinstance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.utils.HiddenStateExtractor(backbone: Module, module_names: list[str])[source]#
Bases: Module
Wrapper that captures intermediate embeddings from specified layers.
Returns outputs in HuggingFace Transformers style with last_hidden_state for the final backbone output and hidden_states for intermediate layers.
- Parameters:
backbone – The neural network module to wrap.
module_names – List of module names to capture (e.g., [‘layer1’, ‘encoder.block1’]). Supports nested modules using dot notation.
- Returns:
last_hidden_state: Final backbone output
hidden_states: Dict mapping module names to their outputs
- Return type:
EmbeddingOutput with
Example
>>> model = HiddenStateExtractor(
...     torchvision.models.swin_v2_s(),
...     ["features.0", "features.2", "features.4"],
... )
>>> output = model(images)
>>> output.last_hidden_state  # final output
>>> output.hidden_states["features.2"]  # intermediate layer
- Raises:
ValueError – If any module name is not found in the backbone.
- forward(*args, **kwargs) EmbeddingOutput[source]#
Run forward pass and return embeddings in HuggingFace style.
- class stable_pretraining.backbone.utils.TeacherStudentWrapper(student: Module, warm_init: bool = True, base_ema_coefficient: float = 0.996, final_ema_coefficient: float = 1)[source]#
Bases: Module
Backbone wrapper that implements teacher-student distillation via EMA.
This is a wrapper for backbones that creates a teacher model as an exponential moving average (EMA) of the student model. It should be passed as the backbone to stable_pretraining.Module and accessed via forward_student() and forward_teacher() methods in your custom forward function.
The teacher model is updated by taking a running average of the student’s parameters and buffers. When ema_coefficient == 0.0, the teacher and student are literally the same object, saving memory but forward passes through the teacher will not produce any gradients.
- Usage example:
backbone = ResNet18()
wrapped_backbone = TeacherStudentWrapper(backbone)
module = ssl.Module(
    backbone=wrapped_backbone,
    projector=projector,
    forward=forward_with_teacher_student,
    ...
)
- Parameters:
student (torch.nn.Module) – The student model whose parameters will be tracked.
warm_init (bool, optional) – If True, performs an initialization step to match the student’s parameters immediately. Default is True.
base_ema_coefficient (float, optional) – EMA decay factor at the start of training. This value will be updated following a cosine schedule. Should be in [0, 1]. A value of 0.0 means the teacher is fully updated to the student’s parameters on every step, while a value of 1.0 means the teacher remains unchanged. Default is 0.996.
final_ema_coefficient (float, optional) – EMA decay factor at the end of training. Default is 1.
- forward(*args, **kwargs)[source]#
Forward pass through either the student or teacher network.
You can choose which model to run in the default forward. Commonly the teacher is evaluated, so we default to that.
- forward_student(*args, **kwargs)[source]#
Forward pass through the student network. Gradients will flow normally.
- forward_teacher(*args, **kwargs)[source]#
Forward pass through the teacher network.
By default, the teacher network does not require grad. If ema_coefficient == 0, then teacher==student, so we wrap in torch.no_grad() to ensure no gradients flow.
- update_ema_coefficient(epoch: int, total_epochs: int)[source]#
Update the EMA coefficient following a cosine schedule.
- The EMA coefficient is updated following a cosine schedule:
ema_coefficient = final_ema_coefficient - 0.5 * (final_ema_coefficient - base_ema_coefficient) * (1 + cos(epoch / total_epochs * pi))
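As a quick sanity check, the schedule above can be evaluated in plain Python (a standalone sketch; the actual method updates the wrapper's ema_coefficient in place):

```python
import math

def cosine_ema_schedule(epoch, total_epochs, base=0.996, final=1.0):
    """EMA coefficient following the cosine schedule above."""
    return final - 0.5 * (final - base) * (
        1 + math.cos(epoch / total_epochs * math.pi)
    )

# Starts at base_ema_coefficient, anneals toward final_ema_coefficient:
start = cosine_ema_schedule(0, 100)    # ≈ 0.996 (base)
end = cosine_ema_schedule(100, 100)    # ≈ 1.0 (final)
```

At the halfway point the coefficient sits exactly between the two endpoints, since cos(π/2) = 0.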
- update_teacher()[source]#
Perform one EMA update step on the teacher’s parameters.
- The update rule is:
teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param
This is done in a no_grad context to ensure the teacher’s parameters do not accumulate gradients, but the student remains fully trainable.
Everything is updated, including buffers (e.g. batch norm running averages).
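The update rule, shown here on plain floats for illustration (the real method iterates over all parameters and buffers inside a no_grad context):

```python
def ema_update(teacher, student, ema_coefficient):
    """One EMA step: teacher moves toward student by (1 - ema_coefficient)."""
    return ema_coefficient * teacher + (1 - ema_coefficient) * student

# With the student held constant at 1.0, the teacher converges geometrically:
t = 0.0
for _ in range(3):
    t = ema_update(t, 1.0, 0.9)
# t goes 0.1 → 0.19 → 0.271
```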
- stable_pretraining.backbone.utils.from_huggingface(model_name, pretrained, attn_implementation='sdpa', **kwargs)[source]#
Loads a Hugging Face Transformers base model, optionally with pretrained weights, and returns the backbone model.
This function wraps the Hugging Face transformers library to load a model specified by model_name. It supports loading either pretrained weights or initializing from configuration only. The returned object is the model’s backbone (model.base_model), which is useful for extracting the core architecture without task-specific heads.
- Parameters:
model_name (str) – The Hugging Face model repository identifier or local path. Examples include “bert-base-uncased”, “facebook/opt-1.3b”, or a local directory containing model files.
pretrained (bool) – If True, loads pretrained weights via AutoModel.from_pretrained. If False, initializes the model from configuration only via AutoConfig.from_pretrained and AutoModel.from_config.
attn_implementation (str, optional) – The attention backend to use. Supported values include “sdpa” (default), “eager”, “flash_attention_2”, etc., as supported by the installed version of transformers and your hardware. This is forwarded to the underlying model constructor.
**kwargs – Additional keyword arguments forwarded to AutoModel.from_pretrained or AutoConfig.from_pretrained. Common options include:
- revision (str): Model version or branch to use.
- cache_dir (str): Directory to cache downloaded models.
- trust_remote_code (bool): Allow loading custom code from model repo.
- torch_dtype (str or torch.dtype): Data type for model weights.
- device_map (str or dict): Device placement for model parameters.
- And others supported by Hugging Face Transformers.
- Returns:
The base (backbone) model instance, typically accessible via model.base_model. For some architectures, this may be the model itself.
- Return type:
transformers.PreTrainedModel
- Raises:
ImportError – If the transformers library is not installed.
OSError – If the model or configuration cannot be found or downloaded.
ValueError – If invalid arguments are provided.
Exception – Propagates any other exceptions raised by Hugging Face Transformers.
Notes
The returned base_model may differ depending on the architecture. For some models, base_model is the same as the full model.
The availability of certain attention implementations (e.g., “flash_attention_2”) depends on your hardware, installed libraries, and the version of transformers.
Ensure that your environment meets the requirements for the selected attention backend.
Examples
>>> # Load a pretrained BERT model with default attention
>>> model = from_huggingface("bert-base-uncased", pretrained=True)
>>> # Initialize a model from config only, specifying a revision and device
>>> model = from_huggingface(
...     "facebook/opt-1.3b",
...     pretrained=False,
...     revision="main",
...     device_map="auto",
... )
>>> # Load a pretrained model using flash attention (if supported)
>>> model = from_huggingface(
...     "meta-llama/Llama-2-7b-hf",
...     pretrained=True,
...     attn_implementation="flash_attention_2",
... )
- stable_pretraining.backbone.utils.from_torchvision(model_name, low_resolution=False, **kwargs)[source]#
Load a backbone model.
If num_classes is provided, the last layer is replaced by a linear layer of output size num_classes. Otherwise, the last layer is replaced by an identity layer.
- Parameters:
model_name (str) – Name of the backbone model. Supported models are:
- Any model from torchvision.models
- “Resnet9”
- “ConvMixer”
low_resolution (bool, optional) – Whether to adapt the resolution of the model (for CIFAR typically). By default False.
**kwargs – Additional keyword arguments for the model. Special handling:
- in_channels (int): Number of input channels. If provided for ResNet models, the first conv layer will be modified to accept this many channels. Default is 3.
- Returns:
The neural network model.
- Return type:
torch.nn.Module
- stable_pretraining.backbone.utils.get_children_modules(model: Module, parent_name: str, L: int = 1, partial_match: bool = False) List[str][source]#
Extracts the unique qualified names of submodules located L levels below a matching parent_name.
- Parameters:
model – The root nn.Module.
parent_name – The string or path component to match (e.g., ‘blocks’).
L – Number of levels after the parent_name to include in the result.
partial_match – If True, match parent_name by substring containment (in); otherwise require exact equality (==).
- Returns:
Sorted list of unique qualified module names at depth L after the parent_name.
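The selection logic can be sketched without torch by operating directly on dotted names such as those produced by model.named_modules() (an illustrative, dependency-free approximation; the real helper walks the module tree itself):

```python
def select_children(names, parent_name, L=1, partial_match=False):
    """Collect unique qualified names L levels below parent_name."""
    out = set()
    for name in names:
        parts = name.split(".")
        for i, part in enumerate(parts):
            hit = parent_name in part if partial_match else part == parent_name
            if hit and len(parts) >= i + 1 + L:
                # Keep the path truncated L levels past the match.
                out.add(".".join(parts[: i + 1 + L]))
                break
    return sorted(out)

names = ["", "blocks", "blocks.0", "blocks.0.mlp", "blocks.1", "head"]
print(select_children(names, "blocks", L=1))  # ['blocks.0', 'blocks.1']
```

Note how "blocks.0.mlp" contributes only its truncated prefix "blocks.0", which is why the result is deduplicated through a set.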
- stable_pretraining.backbone.utils.get_output_shape(model: Module, *inputs, **kwargs) Any[source]#
Infers the output shapes of a PyTorch nn.Module by forwarding fake inputs on the ‘meta’ device using FakeTensorMode.
Handles arbitrary nested output structures (lists, dicts, tuples, sets, namedtuples, dataclasses), preserving their structure but replacing torch.Tensor objects with their .shape. This function temporarily replaces the model’s parameters and buffers with fake tensors on the ‘meta’ device, converts all tensor inputs and keyword arguments to ‘meta’, and runs the forward pass under FakeTensorMode. After execution, the original parameters and buffers are restored. No real computation or memory allocation occurs.
- Parameters:
model (torch.nn.Module) – The PyTorch module to evaluate. Must be on a real device (e.g., CPU).
*inputs – Positional arguments to pass to the model’s forward method. All torch.Tensor inputs are converted to ‘meta’.
**kwargs – Keyword arguments to pass to the model’s forward method. All torch.Tensor values are converted to ‘meta’.
- Returns:
- The output structure from the model’s forward pass, with all torch.Tensor objects replaced by their .shape.
Non-tensor objects are left unchanged.
- Return type:
Any
Notes
Supports nested output structures: dict, list, tuple, set, namedtuple, and dataclasses.
No real memory is allocated; all tensors are on the ‘meta’ device.
Not thread-safe: concurrent calls may interfere with parameter/buffer swapping.
Requires PyTorch 1.11+ for FakeTensorMode.
If the model contains custom buffers or state, ensure they are handled appropriately.
Raises exceptions if model forward fails or if parameters/buffers cannot be swapped.
Non-tensor outputs are returned unchanged.
Example
shapes = get_output_shape(model, input1, input2, key1=kwarg1)
# shapes has the same structure as the model’s output, but with torch.Size in place of tensors.
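The structure-preserving replacement can be illustrated without torch: anything exposing a .shape attribute is swapped for that shape, while containers are rebuilt around it (a simplified sketch; the real function also handles sets, namedtuples, and dataclasses, and runs under FakeTensorMode):

```python
def shapes_of(obj):
    """Recursively replace shape-bearing objects with their .shape."""
    if hasattr(obj, "shape"):
        return obj.shape
    if isinstance(obj, dict):
        return {k: shapes_of(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(shapes_of(v) for v in obj)
    return obj  # non-tensor leaves pass through unchanged

class FakeTensor:  # stand-in for a tensor on the 'meta' device
    def __init__(self, *shape):
        self.shape = shape

out = {"cls": FakeTensor(4, 768), "patches": [FakeTensor(4, 196, 768), "aux"]}
print(shapes_of(out))  # {'cls': (4, 768), 'patches': [(4, 196, 768), 'aux']}
```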
- stable_pretraining.backbone.utils.register_lr_scale_hook(module, lr_scale, weight_decay=0.0)[source]#
Registers a hook that scales gradients and applies weight decay during backward pass.
- Parameters:
module – PyTorch module/layer
lr_scale – Scaling factor for the learning rate (scales gradients)
weight_decay – L2 penalty coefficient (default: 0.0)
- Returns:
The same module (for chaining)
- Return type:
torch.nn.Module
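The effect of such a hook can be sketched as a pure gradient transform (illustrative only; in PyTorch this transform would be registered per-parameter, e.g. via Tensor.register_hook, so it runs during the backward pass):

```python
def scaled_grad(grad, param, lr_scale, weight_decay=0.0):
    """Scale the gradient and fold in an L2 (weight decay) term."""
    return grad * lr_scale + weight_decay * param

# A parameter value of 5.0 with raw gradient 2.0, trained at 0.1x the base
# learning rate and weight decay 0.01:
g = scaled_grad(2.0, param=5.0, lr_scale=0.1, weight_decay=0.01)
# g = 0.2 + 0.05 = 0.25
```

Scaling the gradient by lr_scale is equivalent to scaling the learning rate for this module only, which is why the hook is useful for layer-wise learning-rate decay.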
- stable_pretraining.backbone.utils.set_embedding_dim(module, dim, bias=True, expected_input_shape: tuple | list | None = None, expected_output_shape: tuple | list | None = None)[source]#
- stable_pretraining.backbone.utils.vit_hf(size: str = 'tiny', patch_size: int = 16, image_size: int = 224, pretrained: bool = False, use_mask_token: bool = True, **kwargs) Module[source]#
Create a Vision Transformer using HuggingFace transformers.
This provides a clean, well-maintained ViT implementation with native support for:
- Masking via the bool_masked_pos parameter
- Learnable mask token
- Easy access to CLS and patch tokens
- Parameters:
size – Model size - “tiny”, “small”, “base”, or “large”
patch_size – Patch size (default: 16)
image_size – Input image size (default: 224)
pretrained – Load pretrained weights from HuggingFace Hub
use_mask_token – Whether to include learnable mask token (needed for iBOT)
**kwargs – Additional ViTConfig parameters
- Returns:
HuggingFace ViTModel
Example
>>> backbone = vit_hf("tiny", use_mask_token=True)
>>> x = torch.randn(2, 3, 224, 224)
>>>
>>> # Without masking
>>> output = backbone(x)
>>> cls_token = output.last_hidden_state[:, 0, :]
>>> patch_tokens = output.last_hidden_state[:, 1:, :]
>>>
>>> # With masking (for iBOT student)
>>> masks = torch.zeros(2, 196, dtype=torch.bool)
>>> masks[:, :59] = True  # Mask 30%
>>> output = backbone(x, bool_masked_pos=masks)
stable_pretraining.backbone.vit module#
- class stable_pretraining.backbone.vit.Attention(dim: int, num_heads: int = 8, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, use_rope: bool = False, use_qk_norm: bool = False, max_grid_size: int = 32)[source]#
Bases:
Module
Multi-head self-attention with efficient SDPA backend.
Supports modern transformer features including Rotary Position Embeddings (RoPE) and Query-Key Normalization (QK-Norm) for improved training stability and positional generalization.
Uses F.scaled_dot_product_attention, which automatically selects the optimal backend:
- Flash Attention (when available, fastest)
- Memory-efficient attention (xformers-style)
- Math fallback
Architecture Features#
- RoPE (Rotary Position Embedding):
Encodes relative 2D positions via complex rotations applied to Q and K. Unlike additive positional embeddings, RoPE:
- Naturally captures relative positions
- Generalizes to unseen sequence lengths
- Requires no extra parameters
Enable with use_rope=True. Requires grid_size in forward().
- QK-Norm (Query-Key Normalization):
Applies LayerNorm (without learnable params) to Q and K before computing attention scores. Benefits:
- Prevents attention logit explosion in deep networks
- Stabilizes training without extra hyperparameter tuning
- Essential when combined with SwiGLU/LayerScale
Enable with use_qk_norm=True.
Attention Masking#
Supports flexible attention patterns via attn_mask:
- Causal (autoregressive): torch.triu(ones, diagonal=1)
- Bidirectional: attn_mask=None
- Block sparse: custom boolean masks
- Leave-one-out: torch.eye(N) (each token ignores itself)
Mask convention: True = blocked (cannot attend), False = allowed.
- param dim:
Input/output embedding dimension
- param num_heads:
Number of parallel attention heads. Must divide dim.
- param qkv_bias:
If True, add learnable bias to Q, K, V projections. Default True following ViT convention.
- param attn_drop:
Dropout probability on attention weights. Applied only during training.
- param proj_drop:
Dropout probability on output projection.
- param use_rope:
Enable 2D Rotary Position Embedding. When True, position information is encoded via rotation in attention rather than additive embeddings. Requires the grid_size parameter in forward().
- param use_qk_norm:
Enable Query-Key normalization. Applies LayerNorm (without learnable parameters) to Q and K tensors before attention. Recommended for deep networks or when using SwiGLU.
- param max_grid_size:
Maximum spatial grid size for RoPE frequency cache. Only used when use_rope=True. Set to the largest expected grid dimension.
Example:
# Standard attention
attn = Attention(dim=768, num_heads=12)
out = attn(x)  # [B, N, 768]

# With RoPE for vision (requires grid_size)
attn = Attention(dim=768, num_heads=12, use_rope=True)
out = attn(x, grid_size=(14, 14))

# With QK-Norm for training stability
attn = Attention(dim=768, num_heads=12, use_qk_norm=True)
out = attn(x)

# Causal attention (autoregressive)
N = x.shape[1]
causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
out = attn(x, attn_mask=causal_mask)

# NEPA-style: RoPE + QK-Norm + causal
attn = Attention(dim=768, num_heads=12, use_rope=True, use_qk_norm=True)
causal_mask = torch.triu(
    torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1
)
out = attn(x, attn_mask=causal_mask, grid_size=(14, 14))
Note
When use_rope=True, do NOT add positional embeddings to input tokens. RoPE encodes positions internally via Q/K rotation.
References
RoPE: Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” (2021)
QK-Norm: Henry et al., “Query-Key Normalization for Transformers” (2020)
Flash Attention: Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention” (2022)
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x: Tensor, attn_mask: Tensor | None = None, grid_size: Tuple[int, int] | None = None) Tensor[source]#
Compute multi-head self-attention.
- Parameters:
x – Input tensor of shape [B, N, D] where B is batch size, N is sequence length, and D is embedding dimension.
attn_mask – Optional attention mask. Supported shapes:
- [N, N]: Same mask for all batches and heads
- [B, N, N]: Per-batch mask, broadcast over heads
- [B, H, N, N]: Full per-batch, per-head mask
Values: True = blocked (cannot attend), False = allowed.
Common patterns:
- Causal: torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
- Leave-one-out: torch.eye(N, dtype=torch.bool)
grid_size – Spatial grid dimensions as (height, width). Required when use_rope=True. Used to compute 2D rotary position embeddings. For a 224x224 image with patch_size=16, use grid_size=(14, 14).
- Returns:
Output tensor of shape [B, N, D]
- Raises:
ValueError – If use_rope=True but grid_size is None
- class stable_pretraining.backbone.vit.CrossAttention(dim: int, context_dim: int | None = None, num_heads: int = 8, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0)[source]#
Bases:
Module
Multi-head cross-attention with efficient SDPA backend.
Queries attend to key-value pairs from a separate context sequence. Supports attention masking to block specific query-key interactions.
- Parameters:
dim – Query dimension
context_dim – Context dimension (defaults to dim)
num_heads – Number of attention heads
qkv_bias – Add bias to projections
attn_drop – Attention dropout rate
proj_drop – Output projection dropout rate
Example:
cross_attn = CrossAttention(dim=768, context_dim=1024, num_heads=12)

# Standard cross-attention
out = cross_attn(queries, context)  # [B, N, 768]

# Masked cross-attention: block certain query-context pairs
# mask[i, j] = True means query i cannot attend to context j
mask = torch.zeros(N, M, dtype=torch.bool)
mask[:, :10] = True  # block attention to first 10 context tokens
out = cross_attn(queries, context, attn_mask=mask)
- forward(x: Tensor, context: Tensor, attn_mask: Tensor | None = None) Tensor[source]#
Forward pass.
- Parameters:
x – Query tensor [B, N, D]
context – Key-value tensor [B, M, context_dim]
attn_mask – Cross-attention mask. Can be one of:
- [N, M]: Same mask for all batches and heads
- [B, N, M]: Per-batch mask, broadcast over heads
- [B, H, N, M]: Full per-batch, per-head mask
Mask values: True = blocked (cannot attend), False = allowed.
- Returns:
Output tensor [B, N, D]
- class stable_pretraining.backbone.vit.FlexibleTransformer(input_dim: int = 768, hidden_dim: int = 384, output_dim: int = 768, num_patches: int = 196, depth: int = 4, num_heads: int = 6, mlp_ratio: float = 4.0, self_attn: bool = True, cross_attn: bool = True, use_adaln: bool = True, pos_embed_type: Literal['sincos_1d', 'sincos_2d', 'learned'] = 'sincos_2d', grid_size: int | tuple[int, int] | None = None, drop_path_rate: float = 0.0, attn_drop: float = 0.0, proj_drop: float = 0.0, zero_init_output: bool = True, num_prefix_tokens: int = 1, num_registers: int = 0, add_mask_token: bool = False)[source]#
Bases:
Module
Flexible transformer supporting multiple architectures.
Unified backbone for:
- MAE decoder: self_attn=True, cross_attn=False, use_adaln=False
- IJEPA predictor: self_attn=True, cross_attn=True, use_adaln=False
- DiT / Flow: self_attn=True, cross_attn=True/False, use_adaln=True
- MaskGIT: self_attn=True, cross_attn=False, use_adaln=True, add_mask_token=True
- Lightweight predictor: self_attn=True, cross_attn=False, use_adaln=False, num_registers>0
- Leave-one-out prediction: self_attn=True, cross_attn=False with diagonal attn_mask
- Parameters:
input_dim – Input embedding dimension (from encoder)
hidden_dim – Internal transformer dimension
output_dim – Output dimension
num_patches – Total number of patches (for positional embeddings)
depth – Number of transformer blocks
num_heads – Number of attention heads
mlp_ratio – MLP hidden dim multiplier
self_attn – Enable self-attention in blocks
cross_attn – Enable cross-attention in blocks
use_adaln – Enable AdaLN-Zero conditioning
pos_embed_type – ‘sincos_1d’, ‘sincos_2d’, or ‘learned’
grid_size – Grid size for 2D positional embeddings
drop_path_rate – Stochastic depth rate (linearly increases through layers)
attn_drop – Attention dropout rate
proj_drop – Projection dropout rate
zero_init_output – Zero-initialize output projection
num_prefix_tokens – Number of prefix tokens (e.g., CLS token) expected in input. These are tokens whose content comes from the encoder but need special positional embeddings.
num_registers – Number of learnable register tokens to prepend internally. Unlike prefix tokens, registers are fully learnable (both content and position) and are prepended automatically—callers don’t include them in input.
add_mask_token – Enable learnable [MASK] token for masked prediction. When enabled, use context_mask and/or query_mask in forward() to replace tokens at specified positions with the [MASK] token.
Example:
# MAE decoder
decoder = FlexibleTransformer(
    768, 512, 768, 196,
    depth=8,
    self_attn=True,
    cross_attn=False,
    use_adaln=False,
)
out = decoder(context, queries, context_idx, query_idx)

# IJEPA predictor
predictor = FlexibleTransformer(
    768, 384, 768, 196,
    depth=6,
    self_attn=True,
    cross_attn=False,
    add_mask_token=True,
    use_adaln=False,
)
out = predictor(context, queries, context_idx, query_idx)

# DiT-style flow matching
flow = FlexibleTransformer(
    768, 384, 768, 196,
    depth=12,
    self_attn=True,
    cross_attn=False,
    use_adaln=True,
)
out = flow(context, queries, context_idx, query_idx, t=timesteps)

# MaskGIT-style: variable number of masks per sample
maskgit = FlexibleTransformer(
    768, 512, 768, 196,
    depth=8,
    self_attn=True,
    cross_attn=False,
    use_adaln=True,
    add_mask_token=True,
)
context_mask = torch.rand(B, num_patches) < mask_ratio
out = maskgit(
    context=all_patches,
    queries=all_patches[:, :0],
    context_idx=torch.arange(196).expand(B, -1),
    query_idx=torch.empty(B, 0, dtype=torch.long),
    context_mask=context_mask,
    t=timesteps,
    return_all=True,
)

# Leave-one-out prediction: each token predicted from all others
predictor = FlexibleTransformer(
    768, 384, 768, 196,
    depth=4,
    self_attn=True,
    cross_attn=False,
    use_adaln=False,
)
# Diagonal mask: each token cannot attend to itself
T = x.shape[1]
attn_mask = torch.eye(T, dtype=torch.bool, device=x.device)
out = predictor(
    context=x,
    queries=x[:, :0],  # empty queries
    context_idx=torch.arange(T).expand(B, -1),
    query_idx=torch.empty(B, 0, dtype=torch.long),
    attn_mask=attn_mask,  # [T, T] bool, True = blocked
    return_all=True,
)
# out[:, t] is predicted from x[:, ≠t]

# Lightweight predictor with register tokens
predictor = FlexibleTransformer(
    768, 384, 768, 196,
    depth=4,
    num_heads=6,
    self_attn=True,
    cross_attn=False,
    use_adaln=False,
    num_registers=4,
    num_prefix_tokens=0,
)
out, registers = predictor(
    context=encoder_output,
    queries=encoder_output[:, :0],
    context_idx=ids_keep,
    query_idx=torch.empty(B, 0, dtype=torch.long),
    return_all=True,
    return_registers=True,
)
- forward(context: Tensor, queries: Tensor = None, context_idx: Tensor = None, query_idx: Tensor = None, t: Tensor | None = None, num_prefix: int | None = None, return_all: bool = False, return_registers: bool = False, context_mask: Tensor | None = None, query_mask: Tensor | None = None, attn_mask: Tensor | None = None) Tensor | tuple[Tensor, Tensor][source]#
Forward pass.
- Parameters:
context – Context token embeddings [B, N_ctx, input_dim]
queries – Query token embeddings [B, N_qry, input_dim]
context_idx – Patch indices for context tokens [B, N_ctx]
query_idx – Patch indices for query tokens [B, N_qry]
t – Timesteps for conditioning [B] (required if use_adaln=True)
num_prefix – Override for number of prefix tokens in context
return_all – If True and using joint attention (cross_attn=False), return all tokens unshuffled to original position order. Output shape: [B, N_ctx + N_qry, output_dim]. Ignored for cross-attention modes.
return_registers – If True and num_registers > 0, also return register token outputs as a second tensor. Returns tuple of (main_output, register_output) where register_output is [B, num_registers, output_dim].
context_mask – Boolean mask indicating which context tokens to replace with [MASK] token [B, N_ctx]. True = replace with mask. Each sample can have a different number of True values. Requires add_mask_token=True.
query_mask – Boolean mask indicating which query tokens to replace with [MASK] token [B, N_qry]. True = replace with mask. Each sample can have a different number of True values. Requires add_mask_token=True.
attn_mask – Attention mask for self-attention [T, T] or [B, T, T]. True = blocked (cannot attend), False = allowed. For leave-one-out prediction, use torch.eye(T, dtype=torch.bool). Only applies to joint attention mode (cross_attn=False). The mask is automatically expanded to account for registers.
- Returns:
Output embeddings. Shape depends on mode: - cross_attn=True: [B, N_qry, output_dim] - cross_attn=False, return_all=False: [B, N_qry, output_dim] - cross_attn=False, return_all=True: [B, N_ctx + N_qry, output_dim] If return_registers=True, returns tuple (output, registers) where registers is [B, num_registers, output_dim].
- class stable_pretraining.backbone.vit.MAEDecoder(embed_dim: int = 768, decoder_embed_dim: int = 512, output_dim: int = 768, num_patches: int = 196, depth: int = 4, num_heads: int = 16, mlp_ratio: float = 4.0, pos_embed_type: Literal['sincos_1d', 'sincos_2d', 'learned'] = 'sincos_2d', grid_size: int | None = None, drop_path_rate: float = 0.0)[source]#
Bases:
Module
MAE-style Vision Transformer Decoder using FlexibleTransformer.
Implements the decoder component of Masked Autoencoders (MAE) [1] for self-supervised visual representation learning. The decoder reconstructs masked patches from visible patch embeddings using joint self-attention, where visible tokens and learnable mask tokens attend to each other. The decoder is intentionally lightweight compared to the encoder, as MAE demonstrates that a shallow decoder is sufficient for pixel reconstruction while keeping the encoder focused on learning semantic representations.
Architecture Overview:
1. Input projection: Maps encoder embeddings (embed_dim) to decoder dimension (decoder_embed_dim)
2. Mask token expansion: Learnable mask tokens are placed at masked positions
3. Positional encoding: Adds position information to all tokens
4. Transformer blocks: Joint self-attention over visible + mask tokens
5. Output projection: Maps to output_dim (typically patch_size² × channels)
- Parameters:
embed_dim (int, default=768) – Embedding dimension from the encoder. This is the input dimension of visible tokens passed to the decoder.
decoder_embed_dim (int, default=512) – Internal hidden dimension of the decoder transformer blocks. Typically smaller than embed_dim for efficiency.
output_dim (int, default=768) – Output dimension per token. For pixel reconstruction, this should be patch_size ** 2 * in_channels (e.g., 16×16×3 = 768 for RGB).
num_patches (int, default=196) – Total number of patches T in the image (e.g., 14×14 = 196 for 224×224 images with patch_size=16).
depth (int, default=4) – Number of transformer blocks in the decoder. MAE typically uses fewer blocks than the encoder (e.g., 4-8 vs 12-24).
num_heads (int, default=16) – Number of attention heads in multi-head self-attention.
mlp_ratio (float, default=4.0) – Expansion ratio for the MLP hidden dimension relative to decoder_embed_dim.
pos_embed_type ({'sincos_1d', 'sincos_2d', 'learned'}, default='sincos_2d') – Type of positional embedding: - ‘sincos_2d’: Fixed 2D sinusoidal (recommended for images) - ‘sincos_1d’: Fixed 1D sinusoidal - ‘learned’: Learnable positional embeddings
grid_size (int, optional) – Spatial grid size for 2D positional embeddings. If None, inferred as int(sqrt(num_patches)). Required for non-square grids.
drop_path_rate (float, default=0.0) – Stochastic depth rate for regularization during training.
Attributes:
mask_token (nn.Parameter) – Learnable token of shape (1, 1, embed_dim) used to represent masked positions. Initialized with truncated normal (std=0.02).
transformer (FlexibleTransformer) – Core transformer module handling attention and projections.
Notes:
- The mask convention follows MAE: 0 = visible (kept), 1 = masked.
- The decoder receives visible tokens and reconstructs masked positions.
- For efficiency, the decoder is kept deliberately small relative to the encoder.
References:
[1] He, K., et al. “Masked Autoencoders Are Scalable Vision Learners.” CVPR 2022. https://arxiv.org/abs/2111.06377
Examples:
**Basic Usage with MAE Encoder**
>>> import torch
>>> import torch.nn as nn
>>>
>>> # Configuration matching ViT-Base
>>> B, T = 4, 196  # batch size, num_patches (14x14)
>>> embed_dim = 768  # encoder dimension
>>> mask_ratio = 0.75  # MAE default
>>>
>>> # Initialize decoder
>>> decoder = MAEDecoder(
...     embed_dim=embed_dim,
...     decoder_embed_dim=512,
...     output_dim=16 * 16 * 3,  # patch_size² × channels = 768
...     num_patches=T,
...     depth=4,
...     num_heads=16,
... )
>>>
>>> # Simulate encoder output (visible tokens only)
>>> N_vis = int(T * (1 - mask_ratio))  # 49 visible patches
>>> visible_tokens = torch.randn(B, N_vis, embed_dim)
>>>
>>> # Create random mask (0=visible, 1=masked)
>>> mask = torch.zeros(B, T)
>>> for i in range(B):
...     masked_indices = torch.randperm(T)[: T - N_vis]
...     mask[i, masked_indices] = 1
>>>
>>> # Decode - predict masked patches only
>>> pred_masked = decoder(visible_tokens, mask, output_masked_only=True)
>>> print(pred_masked.shape)  # [B, N_mask, output_dim]
torch.Size([4, 147, 768])
**Full Sequence Reconstruction**
>>> # Get predictions for ALL positions (for visualization)
>>> pred_full = decoder(visible_tokens, mask, output_masked_only=False)
>>> print(pred_full.shape)  # [B, T, output_dim]
torch.Size([4, 196, 768])
**Using Full Sequence Input**
If you have the full sequence with mask tokens already inserted:
>>> full_sequence = torch.randn(B, T, embed_dim)  # [B, 196, 768]
>>> pred = decoder(full_sequence, mask, output_masked_only=True)
>>> print(pred.shape)
torch.Size([4, 147, 768])
**Integration with MAE Training Loop**
>>> # Typical MAE training step (pseudocode)
>>> def mae_forward(encoder, decoder, images, mask_ratio=0.75):
...     # Patchify and mask
...     patches = patchify(images)  # [B, T, patch_dim]
...     mask = random_mask(B, T, mask_ratio)  # [B, T], 0=keep, 1=mask
...
...     # Encode visible patches only
...     visible_patches = patches[~mask.bool()].reshape(B, -1, patch_dim)
...     latent = encoder(visible_patches)  # [B, N_vis, embed_dim]
...
...     # Decode to predict masked patches
...     pred = decoder(latent, mask, output_masked_only=True)  # [B, N_mask, output_dim]
...
...     # Reconstruction loss on masked patches only
...     target = patches[mask.bool()].reshape(B, -1, patch_dim)
...     loss = F.mse_loss(pred, target)
...     return loss
**Custom Configuration for ViT-Large**
>>> decoder_large = MAEDecoder(
...     embed_dim=1024,  # ViT-L encoder dim
...     decoder_embed_dim=512,  # keep decoder lightweight
...     output_dim=768,  # 16×16×3 pixels
...     num_patches=256,  # 16×16 patches for 256×256 images
...     depth=8,  # slightly deeper
...     num_heads=16,
...     pos_embed_type="sincos_2d",
...     drop_path_rate=0.1,  # regularization
... )
See Also:
FlexibleTransformer – Core transformer implementation used internally.
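The visible/masked bookkeeping used throughout these examples reduces to simple index arithmetic, sketched here without torch (the split function and its exact behavior are illustrative):

```python
import random

def split_patches(num_patches, mask_ratio, rng=random):
    """Partition patch indices into visible and masked sets, MAE-style."""
    n_vis = int(num_patches * (1 - mask_ratio))
    idx = list(range(num_patches))
    rng.shuffle(idx)
    # First n_vis shuffled indices are kept (visible); the rest are masked.
    return sorted(idx[:n_vis]), sorted(idx[n_vis:])

visible, masked = split_patches(196, 0.75, random.Random(0))
print(len(visible), len(masked))  # 49 147
```

This matches the shapes in the doctests above: with T=196 and mask_ratio=0.75, the decoder sees 49 visible tokens and predicts 147 masked ones.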
- forward(x: Tensor, mask: Tensor, ids_keep: Tensor | None = None, output_masked_only: bool = False) Tensor[source]#
Forward pass.
- Parameters:
x – Visible tokens [B, N_vis, D] or full sequence [B, T, D]
mask – Binary mask [B, T], 0=kept, 1=masked
ids_keep – Indices of kept (visible) patches (B, N_keep)
output_masked_only – If True, return [B, N_mask, D]. If False, return [B, T, D].
- Returns:
Predictions
- class stable_pretraining.backbone.vit.MaskedEncoder(model_or_model_name: str | Module = 'vit_base_patch16_224', masking: PatchMasking | None = None, pretrained: bool = False, img_size: int | Tuple[int, int] | None = None, patch_size: int | Tuple[int, int] | None = None, dynamic_img_size: bool = False, norm_layer: Module | None = None)[source]#
Bases:
Module
Vision Transformer encoder with optional masking support.
Wraps a timm ViT model and adds flexible masking via PatchMasking. Handles all ViT internals: patch embedding, positional embeddings, prefix tokens (CLS, registers), and transformer blocks.
- Parameters:
model_or_model_name – timm model name string or pre-instantiated nn.Module
masking – PatchMasking instance. If None, no masking is applied.
pretrained – Load pretrained weights (only when model_or_model_name is str)
img_size – Override default image size
patch_size – Override default patch size (will reinitialize patch_embed)
dynamic_img_size – Enable dynamic image size support with pos_embed interpolation
Example:
from spt.backbone import PatchMasking, MaskedEncoder

masking = PatchMasking(mask_ratio=0.75, block_size=4)
encoder = MaskedEncoder(
    model_or_model_name="vit_base_patch16_224",
    masking=masking,
    pretrained=True,
)
images = torch.randn(4, 3, 224, 224)
output = encoder(images)
print(output.encoded.shape)  # (4, 1 + 49, 768) with 75% masking
print(output.mask.shape)  # (4, 196)
print(output.ids_keep.shape)  # (4, 49)
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(images: Tensor) MaskedEncoderOutput[source]#
Encode images with optional masking.
- Parameters:
images – Input images (B, C, H, W)
- Returns:
MaskedEncoderOutput with encoded tokens and mask info
- class stable_pretraining.backbone.vit.MaskedEncoderOutput(encoded: Tensor = None, mask: Tensor = None, ids_keep: Tensor = None, grid_size: Tuple[int, int] = None)[source]#
Bases:
ModelOutput
Output from MaskedEncoder forward pass.
- Variables:
encoded – Encoded token representations (B, num_prefix + N_visible, D)
mask – Binary mask where 1 = masked, 0 = visible (B, N_patches)
ids_keep – Indices of visible patches (B, N_visible)
grid_size – Patch grid dimensions (height, width)
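For decoder-style use (e.g. MAE), the encoded visible tokens can be scattered back into a full-length patch sequence using ids_keep. A minimal sketch, assuming a learned mask_token; the helper name and signature are hypothetical, not part of the library API:

```python
import torch


def unshuffle_to_full_grid(encoded, ids_keep, n_patches, mask_token, num_prefix=1):
    # Hypothetical helper (not library API): place visible patch tokens back
    # into the full patch grid, filling masked positions with mask_token.
    B, _, D = encoded.shape
    visible = encoded[:, num_prefix:]                  # drop prefix (CLS) tokens
    full = mask_token.expand(B, n_patches, D).clone()  # start fully masked
    full.scatter_(1, ids_keep.unsqueeze(-1).expand(-1, -1, D), visible)
    return full
```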
- class stable_pretraining.backbone.vit.PositionalEncoding2D(embed_dim: int, grid_size: Tuple[int, int], pos_type: Literal['learnable', 'sinusoidal', 'rope', 'none'] = 'learnable', num_prefix_tokens: int = 1, learnable: bool | None = None)[source]#
Bases:
Module
Flexible 2D positional encoding for vision transformers.
- class stable_pretraining.backbone.vit.QKNorm(head_dim: int)[source]#
Bases:
Module
Query-Key Normalization for attention stabilization.
Applies LayerNorm (without learnable parameters) independently to query and key tensors before computing attention scores. This simple technique dramatically improves training stability in deep transformers.
Why QK-Norm Works
In deep transformers, attention logits (Q·Kᵀ) can grow unboundedly large, causing:
- Gradient explosion: large logits → extreme softmax → tiny gradients
- Attention collapse: all mass on a single token
- Training instability: requires very small learning rates
QK-Norm bounds the attention logits by normalizing Q and K to unit variance:
- Attention logits become bounded: |q·k| ≤ ||q|| ||k|| = O(√d)
- Gradients remain stable throughout training
- Enables larger learning rates and faster convergence
Implementation Details
- Uses LayerNorm without learnable parameters (γ=1, β=0)
- Normalizes per-head: applied to the [..., head_dim] dimension
- Zero computational overhead in modern frameworks (fused with attention)
- Parameters:
head_dim – Dimension per attention head. Each head is normalized independently to preserve multi-head diversity.
- Example::
# In attention forward pass
qk_norm = QKNorm(head_dim=64)
# q, k shape: [B, num_heads, seq_len, head_dim]
q, k = qk_norm(q, k)
# Now safe to compute attention
attn = (q @ k.transpose(-2, -1)) * scale
- Example integration with Attention::
class Attention(nn.Module):
    def __init__(self, dim, num_heads, use_qk_norm=True):
        ...
        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.qk_norm = QKNorm(dim // num_heads)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if self.use_qk_norm:
            q, k = self.qk_norm(q, k)
        ...
Note
QK-Norm is especially important when combined with:
- SwiGLU: gated activations can amplify hidden states
- LayerScale: small initial residual scale needs stable attention
- Deep networks: logit growth compounds with depth
Without QK-Norm, these combinations often fail to train or require extensive hyperparameter tuning.
References
Henry et al., “Query-Key Normalization for Transformers” (EMNLP 2020)
Dehghani et al., “Scaling Vision Transformers to 22B Parameters” (2023)
Wortsman et al., “Small-scale proxies for large-scale Transformer training” (2023).
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(q: Tensor, k: Tensor) tuple[Tensor, Tensor][source]#
Normalize query and key tensors independently.
- Parameters:
q – Query tensor of shape [B, num_heads, seq_len, head_dim], or any shape with last dimension = head_dim
k – Key tensor of same shape as q
- Returns:
Tuple of (normalized_q, normalized_k) with same shapes
Note
Normalization is applied to the last dimension (head_dim). Each head is normalized independently, preserving multi-head representation diversity.
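The behavior described above fits in a few lines; a minimal re-implementation sketch using a parameter-free functional LayerNorm (QKNormSketch is illustrative, not the library class):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QKNormSketch(nn.Module):
    # Illustrative sketch of QK-Norm: parameter-free LayerNorm applied to
    # the head dimension of Q and K before attention.
    def __init__(self, head_dim: int):
        super().__init__()
        self.head_dim = head_dim

    def forward(self, q, k):
        # Normalize over the last (head_dim) axis to zero mean / unit
        # variance, with no learnable gamma/beta (gamma=1, beta=0).
        q = F.layer_norm(q, (self.head_dim,))
        k = F.layer_norm(k, (self.head_dim,))
        return q, k
```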
- class stable_pretraining.backbone.vit.SwiGLU(in_features: int, hidden_features: int, out_features: int | None = None, bias: bool = False, drop: float = 0.0)[source]#
Bases:
Module
SwiGLU: Gated Linear Unit with Swish activation.
A parameter-efficient gated activation that combines the benefits of gating mechanisms with the smooth, non-monotonic Swish activation. Empirically improves transformer performance over standard GeLU MLPs.
Architecture
Standard MLP:
x → Linear → GeLU → Linear → out
Parameters: 2 * d * h
SwiGLU:
x → Linear(W₁) → SiLU ─┐
                       ├─ element-wise multiply → Linear(W₃) → out
x → Linear(W₂) ────────┘
Parameters: 3 * d * h' (where h' = 2h/3 to match parameter count)
The hidden dimension is scaled to 2/3 * hidden_features so that the total parameter count matches a standard 2-layer MLP: 3 * d * (2h/3) = 2 * d * h.
Performance Benefits
- Better gradient flow: gating provides multiplicative paths
- Smoother optimization: SiLU (Swish) is smooth and non-monotonic
- Quality: consistently outperforms GeLU in language and vision models
- Parameters:
in_features – Input dimension
hidden_features – Nominal hidden dimension. Actual hidden size is int(2 * hidden_features / 3) to maintain parameter parity with standard MLPs.
out_features – Output dimension. Defaults to in_features.
bias – If True, use bias in linear layers. Default False, following the LLaMA/PaLM convention for better training stability.
drop – Dropout probability applied after gating.
- Example::
# Replace standard MLP in transformer
# Old: mlp = Mlp(768, 3072, 768)
# New:
mlp = SwiGLU(768, 3072, 768)
x = torch.randn(4, 196, 768)
out = mlp(x)  # [4, 196, 768]

# Parameter count comparison (bias=False in both)
standard_mlp = nn.Sequential(
    nn.Linear(768, 3072, bias=False), nn.GELU(), nn.Linear(3072, 768, bias=False)
)
swiglu = SwiGLU(768, 3072, 768)  # bias defaults to False
print(sum(p.numel() for p in standard_mlp.parameters()))  # 4,718,592
print(sum(p.numel() for p in swiglu.parameters()))  # 4,718,592 (same!)
Note
For best results, combine SwiGLU with: - LayerScale: Stabilizes residual connections - QK-Norm: Prevents attention explosion - RoPE: Better positional generalization This combination is used in LLaMA, PaLM, and modern vision transformers.
References
Shazeer, “GLU Variants Improve Transformer” (2020)
Touvron et al., “LLaMA: Open and Efficient Foundation Language Models” (2023)
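The parameter-parity claim above can be checked with a minimal re-implementation sketch (SwiGLUSketch is illustrative, not the library class; weight names W₁/W₂/W₃ follow the diagram above):

```python
import torch
import torch.nn as nn


class SwiGLUSketch(nn.Module):
    # Illustrative sketch: gated MLP with SiLU, hidden width scaled to
    # 2/3 of the nominal size to match a standard 2-layer MLP's params.
    def __init__(self, in_features, hidden_features, out_features=None, bias=False):
        super().__init__()
        out_features = out_features or in_features
        hidden = int(2 * hidden_features / 3)
        self.w1 = nn.Linear(in_features, hidden, bias=bias)   # gate branch
        self.w2 = nn.Linear(in_features, hidden, bias=bias)   # value branch
        self.w3 = nn.Linear(hidden, out_features, bias=bias)  # output projection
        self.act = nn.SiLU()

    def forward(self, x):
        return self.w3(self.act(self.w1(x)) * self.w2(x))
```

With bias=False, SwiGLUSketch(768, 3072, 768) uses 3 * 768 * 2048 parameters, exactly matching the 2 * 768 * 3072 of a bias-free standard MLP.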
- class stable_pretraining.backbone.vit.TransformerBlock(dim: int, num_heads: int, mlp_ratio: float = 4.0, self_attn: bool = True, cross_attn: bool = False, use_adaln: bool = False, use_rope: bool = False, use_qk_norm: bool = False, mlp_type: Literal['gelu', 'swiglu'] = 'gelu', use_layer_scale: bool = False, layer_scale_init: float = 1e-05, drop_path: float = 0.0, attn_drop: float = 0.0, proj_drop: float = 0.0, max_grid_size: int = 32, act_layer: type = <class 'torch.nn.modules.activation.GELU'>)[source]#
Bases:
Module
Unified transformer block supporting multiple architectures.
Configurable for various modern transformer designs:
Architecture      RoPE  QK-Norm  MLP     LayerScale  AdaLN
Standard ViT      ✗     ✗        gelu    ✗           ✗
DINOv2 / Modern   ✓     ✓        swiglu  ✓           ✗
NEPA              ✓     ✓        swiglu  ✓           ✗
DiT / Flow        ✗     ✗        gelu    ✗           ✓
Attention Modes#
- Mode 1: Self-Attention Only (self_attn=True, cross_attn=False)
Standard encoder block. Used for NEPA, ViT encoder, etc.
- Mode 2: Cross-Attention Only (self_attn=False, cross_attn=True)
Queries attend to context only. Lightweight decoder.
- Mode 3: Full Decoder (self_attn=True, cross_attn=True)
Self-attention on queries, then cross-attention to context.
Modern Components#
- RoPE (use_rope=True): 2D Rotary Position Embedding. Encodes positions via Q/K rotation. Requires grid_size in forward(). Don't use additive pos_embed.
- QK-Norm (use_qk_norm=True): Normalizes Q and K before attention. Stabilizes deep networks.
- SwiGLU (mlp_type='swiglu'): Gated MLP with SiLU activation. Better than GeLU empirically.
- LayerScale (use_layer_scale=True): Learnable per-channel scaling on residuals. Stabilizes training. Initialize near zero (e.g., 1e-5) for identity-like initialization.
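LayerScale itself is tiny; a minimal sketch of the idea (the class name is hypothetical and not the block's internal implementation):

```python
import torch
import torch.nn as nn


class LayerScaleSketch(nn.Module):
    # Illustrative sketch: learnable per-channel scale on a residual branch,
    # initialized near zero so the block starts close to identity.
    def __init__(self, dim: int, init: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init * torch.ones(dim))

    def forward(self, x):
        # Broadcast over batch and sequence dims: [B, N, D] * [D]
        return self.gamma * x
```

In a block, the residual update becomes x = x + layer_scale(attn(norm(x))), so at init the branch contributes almost nothing.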
- param dim:
Hidden dimension
- param num_heads:
Number of attention heads
- param mlp_ratio:
MLP hidden dim = dim * mlp_ratio
- param self_attn:
Enable self-attention
- param cross_attn:
Enable cross-attention
- param use_adaln:
Enable AdaLN-Zero conditioning (for diffusion/flow)
- param use_rope:
Enable 2D Rotary Position Embedding in attention
- param use_qk_norm:
Enable Query-Key normalization in attention
- param mlp_type:
MLP activation type: ‘gelu’ or ‘swiglu’
- param use_layer_scale:
Enable LayerScale on residual connections
- param layer_scale_init:
Initial value for LayerScale (default: 1e-5)
- param drop_path:
Stochastic depth rate
- param attn_drop:
Attention dropout rate
- param proj_drop:
Projection dropout rate
- param max_grid_size:
Maximum grid size for RoPE cache
Example:
# Standard ViT block
block = TransformerBlock(dim=768, num_heads=12)

# Modern ViT (DINOv2-style)
block = TransformerBlock(
    dim=768,
    num_heads=12,
    use_rope=True,
    use_qk_norm=True,
    mlp_type="swiglu",
    use_layer_scale=True,
)
out = block(x, grid_size=(14, 14))

# NEPA block (modern + causal)
causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
out = block(x, grid_size=(14, 14), attn_mask=causal_mask)

# DiT block (with conditioning)
block = TransformerBlock(dim=768, num_heads=12, use_adaln=True)
out = block(x, cond=time_emb)
- forward(x: Tensor, context: Tensor | None = None, cond: Tensor | None = None, attn_mask: Tensor | None = None, cross_attn_mask: Tensor | None = None, grid_size: Tuple[int, int] | None = None) Tensor[source]#
Forward pass.
- Parameters:
x – Input tensor [B, N, D]
context – Context for cross-attention [B, M, D]
cond – Conditioning tensor [B, D] (required if use_adaln=True)
attn_mask – Self-attention mask. True = blocked.
cross_attn_mask – Cross-attention mask. True = blocked.
grid_size – (grid_h, grid_w) for RoPE. Required if use_rope=True.
- Returns:
Output tensor [B, N, D]
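Both mask arguments use the True = blocked convention; minimal sketches of two typical constructions (the helper names are hypothetical):

```python
import torch


def causal_mask(n: int) -> torch.Tensor:
    # Boolean self-attention mask, True = blocked:
    # position i may attend only to positions <= i.
    return torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)


def padding_cross_mask(lengths, max_len: int) -> torch.Tensor:
    # Cross-attention mask blocking padded context tokens.
    # lengths: per-sample counts of valid context tokens; returns (B, M).
    ar = torch.arange(max_len)
    return ar.unsqueeze(0) >= torch.as_tensor(lengths).unsqueeze(1)
```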