
stable_pretraining.backbone package#

Submodules#

stable_pretraining.backbone.aggregator module#

Modular tensor aggregation module for feeding multi-scale/multi-layer features to MLPs.

Commonly used for:
  • SSL linear probes using multiple transformer layers

  • Multi-scale feature fusion

  • Combining features from different network stages

class stable_pretraining.backbone.aggregator.TensorAggregator(input_spec: str | List[str] | Dict[str, str], adaptive_pool_size: int = 1)[source]#

Bases: Module

Aggregates multi-dimensional tensors into 2D format for MLP input.

Pure aggregation module with NO trainable parameters. Handles various input formats and aggregation strategies.

Parameters:
  • input_spec – Specification of input format and aggregation modes:
    - str: Single aggregation mode for all tensors (e.g., “mean”)
    - List[str]: Per-tensor aggregation modes for list inputs
    - Dict[str, str]: Per-key aggregation modes for dict inputs

  • adaptive_pool_size – Output size for adaptive pooling (default: 1)

Aggregation Modes:
  • “mean”: Spatial/temporal mean pooling

  • “max”: Spatial/temporal max pooling

  • “cls”: Take first token (for transformers with [CLS] token)

  • “flatten”: Flatten all dimensions after batch

  • “adaptive”: Adaptive average pooling to fixed size

Examples

>>> # Single tensor with mean pooling
>>> agg = TensorAggregator("mean")
>>> x = torch.randn(4, 768, 14, 14)
>>> out = agg(x)  # Shape: (4, 768)
>>> # SSL: Last 4 transformer layers with CLS token
>>> agg = TensorAggregator(["cls", "cls", "cls", "cls"])
>>> layers = [torch.randn(4, 197, 768) for _ in range(4)]
>>> out = agg(layers)  # Shape: (4, 3072)  # 768 * 4
>>> # Multi-scale features
>>> agg = TensorAggregator({"layer1": "cls", "layer2": "mean", "conv": "mean"})
>>> out = agg(
...     {
...         "layer1": torch.randn(4, 197, 768),
...         "layer2": torch.randn(4, 197, 768),
...         "conv": torch.randn(4, 512, 14, 14),
...     }
... )  # Shape: (4, 2048)
compute_output_dim(input_shapes: tuple | List[tuple] | Dict[str, tuple]) int[source]#

Compute the output dimension given input shapes.

Parameters:

input_shapes – Shape(s) of input tensor(s) (excluding batch dim)

Returns:

Total output features

Examples

>>> agg = TensorAggregator(["cls", "mean"])
>>> agg.compute_output_dim([(197, 768), (197, 768)])
1536
>>> agg = TensorAggregator({"l1": "cls", "conv": "mean"})
>>> agg.compute_output_dim({"l1": (197, 768), "conv": (512, 14, 14)})
1280
forward(x: Tensor | List[Tensor] | Dict[str, Tensor]) Tensor[source]#

Aggregate input tensor(s) to 2D format.

Parameters:

x – Input tensor, list of tensors, or dict of tensors

Returns:

Aggregated 2D tensor of shape (B, total_features)

stable_pretraining.backbone.convmixer module#

class stable_pretraining.backbone.convmixer.ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6, kernel_size=9, patch_size=7)[source]#

Bases: Module

ConvMixer model.

A simple and efficient convolutional architecture that operates directly on patches.

Parameters:
  • in_channels (int, optional) – Number of input channels. Defaults to 3.

  • num_classes (int, optional) – Number of output classes. Defaults to 10.

  • dim (int, optional) – Hidden dimension size. Defaults to 64.

  • depth (int, optional) – Number of ConvMixer blocks. Defaults to 6.

  • kernel_size (int, optional) – Kernel size for depthwise convolution. Defaults to 9.

  • patch_size (int, optional) – Patch embedding size. Defaults to 7.

Note

Introduced in [Trockman and Kolter, 2022].

forward(xb)[source]#

Forward pass through the ConvMixer model.

Parameters:

xb (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).

Returns:

Output logits of shape (batch_size, num_classes).

Return type:

torch.Tensor
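
The block structure can be sketched in a few lines of plain torch. This is an illustrative sketch following Trockman & Kolter (2022), not the library's exact implementation; layer ordering and names are assumptions.

```python
import torch
import torch.nn as nn

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

def conv_mixer(in_channels=3, num_classes=10, dim=64, depth=6, kernel_size=9, patch_size=7):
    return nn.Sequential(
        # patch embedding: a single strided convolution
        nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        # depth x (depthwise conv with residual, then pointwise conv)
        *[nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
        ) for _ in range(depth)],
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(dim, num_classes),
    )

model = conv_mixer()
logits = model(torch.randn(2, 3, 32, 32))  # (2, 10)
```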

stable_pretraining.backbone.mlp module#

class stable_pretraining.backbone.mlp.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#

Bases: Sequential

This block implements the multi-layer perceptron (MLP) module.

Parameters:
  • in_channels (int) – Number of channels of the input

  • hidden_channels (List[int]) – List of the hidden channel dimensions

  • norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None this layer won’t be used. Default: None

  • activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None this layer won’t be used. Default: torch.nn.ReLU

  • inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.

  • bias (bool) – Whether to use bias in the linear layer. Default True

  • dropout (float) – The probability for the dropout layer. Default: 0.0
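
The layer stack this class builds can be reproduced by hand as follows. This is a sketch of the torchvision-style ordering implied by the parameter docs above (linear, then norm, then activation, then dropout, with a plain linear + dropout at the end); the exact ordering in the library is an assumption.

```python
import torch
import torch.nn as nn

def build_mlp(in_channels, hidden_channels, norm_layer=None,
              activation_layer=nn.ReLU, bias=True, dropout=0.0):
    layers, dim = [], in_channels
    for h in hidden_channels[:-1]:
        layers.append(nn.Linear(dim, h, bias=bias))
        if norm_layer is not None:
            layers.append(norm_layer(h))   # norm stacked on top of the linear layer
        layers.append(activation_layer())  # then the activation
        layers.append(nn.Dropout(dropout))
        dim = h
    # final projection: linear + dropout only
    layers += [nn.Linear(dim, hidden_channels[-1], bias=bias), nn.Dropout(dropout)]
    return nn.Sequential(*layers)

mlp = build_mlp(128, [256, 64], norm_layer=nn.BatchNorm1d)
out = mlp(torch.randn(8, 128))  # (8, 64)
```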

stable_pretraining.backbone.patch_masking module#

Patch masking strategies for masked image modeling.

class stable_pretraining.backbone.patch_masking.IJEPAMaskOutput(context_idx: Tensor = None, target_idx: Tensor = None, target_block_masks: List[Tensor] = None, mask: Tensor = None)[source]#

Bases: ModelOutput

Output from I-JEPA masking operation.

Variables:
  • context_idx – Indices of context (visible) patches [B, N_ctx]

  • target_idx – Combined indices of all target patches [B, N_tgt]

  • target_block_masks – Per-block boolean masks [M x [B, N]], True = in this block

  • mask – Full mask where 1 = target, 0 = context [B, N]

context_idx: Tensor = None#
mask: Tensor = None#
target_block_masks: List[Tensor] = None#
target_idx: Tensor = None#
class stable_pretraining.backbone.patch_masking.IJEPAMasking(num_targets: int = 4, target_scale: Tuple[float, float] = (0.15, 0.2), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5), context_scale: Tuple[float, float] = (0.85, 1.0), allow_target_overlap: bool = False)[source]#

Bases: Module

I-JEPA multi-block masking for joint-embedding predictive architecture.

Samples M non-overlapping target blocks and a context region that excludes all targets. This is the key masking strategy from I-JEPA (Assran et al., 2023). Strategy:

  1. Sample M target blocks with specified scale and aspect ratio

  2. Context = all patches NOT in any target block

  3. Optionally subsample context to specified ratio

Parameters:
  • num_targets – Number of target blocks to sample (default: 4)

  • target_scale – (min, max) fraction of patches per target block

  • target_aspect_ratio – (min, max) aspect ratio of target blocks

  • context_scale – (min, max) fraction of non-target patches to keep as context

  • allow_target_overlap – Allow target blocks to overlap (default: False)

Example:

masking = IJEPAMasking(
    num_targets=4,
    target_scale=(0.15, 0.2),
    target_aspect_ratio=(0.75, 1.5),
    context_scale=(0.85, 1.0),
)

# x: patch embeddings [B, N, D]
output = masking(x, grid_h=14, grid_w=14)

context_patches = x.gather(
    1, output.context_idx.unsqueeze(-1).expand(-1, -1, D)
)
target_patches = x.gather(1, output.target_idx.unsqueeze(-1).expand(-1, -1, D))


extra_repr() str[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x: Tensor, grid_h: int, grid_w: int) IJEPAMaskOutput[source]#

Apply I-JEPA masking.

Parameters:
  • x – Patch embeddings [B, N, D] where N = grid_h * grid_w

  • grid_h – Height of patch grid

  • grid_w – Width of patch grid

Returns:

IJEPAMaskOutput with context/target information

Note

Always returns exactly num_targets block masks. If overlap prevention makes it impossible to fit all blocks, some masks will be empty (all False). The combined target_idx only includes patches from non-empty blocks.

class stable_pretraining.backbone.patch_masking.MaskingOutput(visible: Tensor = None, mask: Tensor = None, ids_restore: Tensor = None, ids_keep: Tensor = None)[source]#

Bases: ModelOutput

Output from patch masking operation.

Variables:
  • visible – Visible patch embeddings (B, N_keep, D)

  • mask – Binary mask where 1 = masked, 0 = visible (B, N)

  • ids_restore – Indices to restore original order (B, N)

  • ids_keep – Indices of kept (visible) patches (B, N_keep)

ids_keep: Tensor = None#
ids_restore: Tensor = None#
mask: Tensor = None#
visible: Tensor = None#
class stable_pretraining.backbone.patch_masking.MultiBlockMasking(num_targets: int = 4, context_scale: Tuple[float, float] = (0.85, 1.0), target_scale: Tuple[float, float] = (0.15, 0.2), context_aspect_ratio: Tuple[float, float] = (1.0, 1.0), target_aspect_ratio: Tuple[float, float] = (0.75, 1.5))[source]#

Bases: Module

Multi-block masking for SALT Stage 1 (VPixel).

Generates one large context block and M target blocks using multi_block_mask(), then makes context disjoint from targets. Returns MaskingOutput compatible with MaskedEncoder.

Parameters:
  • num_targets – Number of target blocks (default: 4)

  • context_scale – (min, max) scale for context block (default: (0.85, 1.0))

  • target_scale – (min, max) scale for each target block (default: (0.15, 0.2))

  • context_aspect_ratio – (min, max) aspect ratio for context (default: (1.0, 1.0))

  • target_aspect_ratio – (min, max) aspect ratio for targets (default: (0.75, 1.5))

Example:

masking = MultiBlockMasking(num_targets=4)
output = masking(patch_embeddings, grid_h=14, grid_w=14)

visible_patches = output.visible  # (B, N_keep, D)
mask = output.mask  # (B, N), 1=masked, 0=visible
extra_repr() str[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x: Tensor, grid_h: int, grid_w: int) MaskingOutput[source]#

Apply multi-block masking to patch embeddings.

Parameters:
  • x – Patch embeddings [B, N, D] where N = grid_h * grid_w

  • grid_h – Height of patch grid

  • grid_w – Width of patch grid

Returns:

MaskingOutput with visible patches and mask info

class stable_pretraining.backbone.patch_masking.PatchMasking(mask_ratio: float = 0.75, block_size: int = 1, crop_ratio: float = 0.0, crop_aspect_ratio: tuple[float, float] = (0.75, 1.33))[source]#

Bases: Module

Flexible patch masking module for masked image modeling.

Supports three masking strategies that are selected stochastically:

  • Random: Uniformly random patch selection (when block_size=1)

  • Block: Square blocks of adjacent patches (when block_size > 1)

  • Crop: Rectangular crop region, remaining patches masked (when crop_ratio > 0)

Strategy selection per sample:

  1. With probability crop_ratio, use crop masking

  2. Otherwise, if block_size > 1, use block masking

  3. Otherwise, use random masking
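
The per-sample selection above amounts to the following standalone sketch (the function name is illustrative, not part of the API):

```python
import random

def pick_strategy(crop_ratio: float, block_size: int) -> str:
    if random.random() < crop_ratio:  # step 1: crop with probability crop_ratio
        return "crop"
    if block_size > 1:                # step 2: block masking
        return "block"
    return "random"                   # step 3: fall back to random masking
```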

Parameters:
  • mask_ratio – Fraction of patches to mask, in [0, 1)

  • block_size – Size of square blocks for block masking (1 = random masking)

  • crop_ratio – Probability of using crop masking vs block/random

  • crop_aspect_ratio – (min, max) aspect ratio range for crop regions

Example:

masking = PatchMasking(mask_ratio=0.75, block_size=4)
output = masking(patch_embeddings, grid_h=14, grid_w=14)

visible_patches = output.visible  # (B, N_keep, D)
mask = output.mask  # (B, N), 1=masked, 0=visible
ids_keep = output.ids_keep  # (B, N_keep)
extra_repr() str[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x: Tensor, grid_h: int, grid_w: int) MaskingOutput[source]#

Apply masking to patch embeddings.

Parameters:
  • x – Patch embeddings of shape (B, N, D) where N = grid_h * grid_w

  • grid_h – Height of the patch grid

  • grid_w – Width of the patch grid

Returns:

MaskingOutput containing visible patches and mask information

Raises:
  • ValueError – If x.shape[1] != grid_h * grid_w

  • ValueError – If input tensor has wrong number of dimensions

stable_pretraining.backbone.pos_embed module#

Positional embedding utilities for vision transformers.

stable_pretraining.backbone.pos_embed.get_1d_sincos_pos_embed(embed_dim: int, length: int, cls_token: bool = False) Tensor[source]#

Generate 1D sinusoidal positional embeddings.

Parameters:
  • embed_dim – Embedding dimension

  • length – Sequence length (number of positions)

  • cls_token – If True, prepend a zero embedding for CLS token

Returns:

Positional embeddings of shape (length, embed_dim) or (length + 1, embed_dim) if cls_token=True

stable_pretraining.backbone.pos_embed.get_2d_sincos_pos_embed(embed_dim: int, grid_size: int | tuple[int, int], cls_token: bool = False) Tensor[source]#

Generate 2D sinusoidal positional embeddings for image patches.

Parameters:
  • embed_dim – Embedding dimension (must be divisible by 4)

  • grid_size – Grid height/width as int (square) or (height, width) tuple

  • cls_token – If True, prepend a zero embedding for CLS token

Returns:

Positional embeddings of shape (H*W, embed_dim) or (H*W + 1, embed_dim) if cls_token=True
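
The 2D embedding can be sketched in pure Python as the concatenation of two 1D sin/cos embeddings, one for the row index and one for the column index (hence the divisible-by-4 requirement). This sketch assumes the MAE-style layout of a sine half followed by a cosine half; the library's exact layout may differ.

```python
import math

def sincos_1d(embed_dim, positions):
    half = embed_dim // 2
    # geometric frequency ladder, one frequency per sin/cos channel pair
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    return [
        [math.sin(p * w) for w in omega] + [math.cos(p * w) for w in omega]
        for p in positions
    ]

def sincos_2d(embed_dim, grid_h, grid_w):
    # half the channels encode the row index, half the column index
    rows = sincos_1d(embed_dim // 2, [i for i in range(grid_h) for _ in range(grid_w)])
    cols = sincos_1d(embed_dim // 2, [j for _ in range(grid_h) for j in range(grid_w)])
    return [r + c for r, c in zip(rows, cols)]

emb = sincos_2d(16, 14, 14)  # 196 positions, 16 channels each
```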

stable_pretraining.backbone.pos_embed.get_sincos_pos_embed(embed_dim: int, num_patches: int, mode: Literal['1d', '2d'] = '1d', grid_size: int | tuple[int, int] | None = None, cls_token: bool = False) Tensor[source]#

Unified interface for generating sinusoidal positional embeddings.

Parameters:
  • embed_dim – Embedding dimension

  • num_patches – Total number of patches (used for 1d mode)

  • mode – Embedding type - ‘1d’ for sequence, ‘2d’ for image grid

  • grid_size – Required for ‘2d’ mode

  • cls_token – If True, prepend a zero embedding for CLS token

Returns:

Positional embeddings tensor

stable_pretraining.backbone.pos_embed.get_timestep_embed(t: Tensor, dim: int, max_period: int = 10000) Tensor[source]#

Generate sinusoidal embeddings for continuous timesteps.

Unlike positional embeddings for sequences, this embeds scalar timestep values. Used for diffusion/flow matching time conditioning.

Parameters:
  • t – Timestep values (B,) or (B, 1), typically in [0, 1]

  • dim – Embedding dimension

  • max_period – Maximum period for frequency scaling

Returns:

Timestep embeddings of shape (B, dim)
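
A pure-Python sketch of the construction. The cosine-then-sine ordering below follows common diffusion codebases and is an assumption; the library's channel layout may differ.

```python
import math

def timestep_embed(t, dim, max_period=10000):
    half = dim // 2
    # frequencies decay geometrically from 1 down to 1/max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [
        [math.cos(x * f) for f in freqs] + [math.sin(x * f) for f in freqs]
        for x in t
    ]

emb = timestep_embed([0.0, 0.5, 1.0], 8)  # 3 timesteps, 8 channels each
```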

stable_pretraining.backbone.pos_embed.interpolate_pos_embed(pos_embed: Tensor, src_size: tuple[int, int], tgt_size: tuple[int, int], num_prefix_tokens: int = 0, mode: str = 'bicubic') Tensor[source]#

Interpolate positional embeddings to a new grid size.

Parameters:
  • pos_embed – Original positional embeddings of shape (1, num_prefix + src_h*src_w, embed_dim) or (num_prefix + src_h*src_w, embed_dim)

  • src_size – Source grid size as (height, width)

  • tgt_size – Target grid size as (height, width)

  • num_prefix_tokens – Number of prefix tokens (CLS, registers) to preserve

  • mode – Interpolation mode (‘nearest’, ‘bilinear’, ‘bicubic’, ‘area’)

Returns:

Interpolated positional embeddings

Example:

old_pos = model.pos_embed  # (1, 197, 768) = 1 + 14*14
new_pos = interpolate_pos_embed(
    old_pos, src_size=(14, 14), tgt_size=(16, 16), num_prefix_tokens=1
)  # (1, 257, 768) = 1 + 16*16
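
What the interpolation amounts to, sketched by hand with torch.nn.functional.interpolate. This mirrors the example above under the assumption of a (1, N, D) layout with one CLS prefix token; it is not the library's exact implementation.

```python
import torch
import torch.nn.functional as F

pos = torch.randn(1, 1 + 14 * 14, 768)          # CLS token + 14x14 grid
prefix, grid = pos[:, :1], pos[:, 1:]

grid = grid.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)   # (1, D, H, W)
grid = F.interpolate(grid, size=(16, 16), mode="bicubic", align_corners=False)
grid = grid.permute(0, 2, 3, 1).reshape(1, 16 * 16, 768)  # back to (1, H*W, D)

new_pos = torch.cat([prefix, grid], dim=1)      # prefix token is preserved untouched
```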

stable_pretraining.backbone.probe module#

class stable_pretraining.backbone.probe.AutoLinearClassifier(name, embedding_dim, num_classes, pooling=None, weight_decay=[0], lr_scaling=[1], normalization=['none', 'norm', 'bn'], dropout=[0, 0.5], label_smoothing=[0, 1])[source]#

Bases: Module

Linear classifier using either the CLS token or mean pooling, with a configurable normalization layer.

Parameters:
  • embedding_dim (int) – Dimensionality of the input embeddings.

  • num_classes (int) – Number of output classes.

  • pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.

  • norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.

norm#

Instantiated normalization layer, or None.

Type:

nn.Module or None

fc#

Linear layer mapping pooled representation to class logits.

Type:

nn.Linear

Forward Args:
x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D).

If 3D, pooling and normalization are applied. If 2D, input is used directly (no pooling or normalization).

Returns:

Output logits of shape (N, num_classes).

Return type:

torch.Tensor

Example

>>> probe = AutoLinearClassifier(
...     name="probe",
...     embedding_dim=128,
...     num_classes=10,
... )
>>> x = torch.randn(32, 20, 128)
>>> out = probe(x)
>>> x2 = torch.randn(32, 128)
>>> out2 = probe(x2)
forward(x, y=None, pl_module=None)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class stable_pretraining.backbone.probe.AutoTuneMLP(in_features: int, out_features: int, hidden_features: List[int] | List[List[int]], name: str, loss_fn: Callable, additional_weight_decay: float | List[float] = [0], lr_scaling: float | List[float] = [1], normalization: str | List[str] = ['none'], dropout: float | List[float] = [0], activation: str | List[str] = ['relu'])[source]#

Bases: Module

Automatically creates multiple MLP variants with different hyperparameter combinations.

This module creates a grid of MLPs with different configurations (dropout, normalization, learning rates, architectures, etc.) to enable parallel hyperparameter tuning.

Parameters:
  • in_features – Number of input features

  • out_features – Number of output features

  • hidden_features – Architecture specification. Can be:
    - List[int]: Single architecture, e.g., [256, 128]
    - List[List[int]]: Multiple architectures, e.g., [[256, 128], [512, 256, 128]]
    - []: Empty list for linear model (no hidden layers)

  • name – Base name for this AutoTuneMLP instance

  • loss_fn – Loss function to compute loss

  • additional_weight_decay – List of weight decay values to try

  • lr_scaling – List of learning rate scaling factors to try

  • normalization – List of normalization types [‘none’, ‘norm’, ‘bn’]

  • dropout – List of dropout rates to try

  • activation – List of activation functions [‘relu’, ‘leaky_relu’, ‘tanh’]

Examples

>>> # Single architecture
>>> model = AutoTuneMLP(128, 10, [256, 128], "clf", nn.CrossEntropyLoss())
>>> # Multiple architectures
>>> model = AutoTuneMLP(
...     128, 10, [[256], [256, 128], [512, 256]], "clf", nn.CrossEntropyLoss()
... )
>>> # Linear model (no hidden layers)
>>> model = AutoTuneMLP(128, 10, [], "linear_clf", nn.CrossEntropyLoss())
forward(x: Tensor, y: Tensor | None = None) Dict[str, Tensor][source]#

Forward pass through all MLP variants.

Parameters:
  • x – Input tensor of shape (batch_size, in_features)

  • y – Optional target tensor for loss computation

Returns:

Dictionary with predictions and losses for each variant. Format: {‘pred/{variant_id}’: tensor, ‘loss/{variant_id}’: tensor}

get_best_variant(metric_dict: Dict[str, float], lower_is_better: bool = True) str[source]#

Get the best performing variant based on metrics.

Parameters:
  • metric_dict – Dictionary mapping variant_id to metric values

  • lower_is_better – If True, lower metric is better (e.g., loss). If False, higher is better (e.g., accuracy)

Returns:

ID of the best performing variant
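
The selection logic is equivalent to a min/max over the metric dict; a sketch with builtins (the variant IDs here are hypothetical):

```python
# hypothetical metric dict mapping variant IDs to validation losses
metrics = {"clf_arch0": 0.41, "clf_arch1": 0.37}

best_by_loss = min(metrics, key=metrics.get)  # lower_is_better=True
best_by_acc = max(metrics, key=metrics.get)   # lower_is_better=False
```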

get_variant(key: str) Module[source]#

Get a specific MLP variant by key.

Parameters:

key – Variant ID

Returns:

The MLP module

Raises:

KeyError – If key doesn’t exist

keys() List[str][source]#

Get list of all MLP variant names.

Returns:

List of variant IDs (strings)

Example

>>> model = AutoTuneMLP(
...     128, 10, [[256], [512]], "clf", nn.CrossEntropyLoss()
... )
>>> model.keys()
['clf_arch0_256_none_relu_drop0_lr1_wd0', 'clf_arch1_512_none_relu_drop0_lr1_wd0']
num_variants() int[source]#

Get the number of MLP variants.

class stable_pretraining.backbone.probe.LinearProbe(embedding_dim, num_classes, pooling='cls', norm_layer=None)[source]#

Bases: Module

Linear classifier using either the CLS token or mean pooling, with a configurable normalization layer.

Parameters:
  • embedding_dim (int) – Dimensionality of the input embeddings.

  • num_classes (int) – Number of output classes.

  • pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.

  • norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.

norm#

Instantiated normalization layer, or None.

Type:

nn.Module or None

fc#

Linear layer mapping pooled representation to class logits.

Type:

nn.Linear

Forward Args:
x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D).

If 3D, pooling and normalization are applied. If 2D, input is used directly (no pooling or normalization).

Returns:

Output logits of shape (N, num_classes).

Return type:

torch.Tensor

Example

>>> probe = LinearProbe(
...     embedding_dim=128,
...     num_classes=10,
...     pooling="mean",
...     norm_layer=torch.nn.LayerNorm,
... )
>>> x = torch.randn(32, 20, 128)
>>> logits = probe(x)  # shape: (32, 10)
>>> x2 = torch.randn(32, 128)
>>> logits2 = probe(x2)  # shape: (32, 10)
forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class stable_pretraining.backbone.probe.MultiHeadAttentiveProbe(embedding_dim: int, num_classes: int, num_heads: int = 4)[source]#

Bases: Module

A multi-head attentive probe for sequence representations.

This module applies multiple attention heads to a sequence of embeddings, pools the sequence into a fixed-size representation per head, concatenates the results, and projects to a set of output classes.

Parameters:
  • embedding_dim (int) – Dimensionality of the input embeddings.

  • num_classes (int) – Number of output classes.

  • num_heads (int, optional) – Number of attention heads. Default is 4.

ln#

Layer normalization applied to the input.

Type:

torch.nn.LayerNorm

attn_vectors#

Learnable attention vectors for each head, shape (num_heads, embedding_dim).

Type:

torch.nn.Parameter

fc#

Final linear layer mapping concatenated head outputs to class logits.

Type:

torch.nn.Linear

Forward Args:
x (torch.Tensor): Input tensor of shape (N, T, D), where

N = batch size, T = sequence length, D = embedding_dim.

Returns:

Output logits of shape (N, num_classes).

Return type:

torch.Tensor

Example

>>> probe = MultiHeadAttentiveProbe(
...     embedding_dim=128, num_classes=10, num_heads=4
... )
>>> x = torch.randn(32, 20, 128)  # batch of 32, sequence length 20
>>> logits = probe(x)  # shape: (32, 10)
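
The per-head attention pooling described above can be sketched by hand. Names are illustrative and the real module may differ in details such as score scaling; only the final linear projection is omitted here.

```python
import torch
import torch.nn.functional as F

N, T, D, H = 32, 20, 128, 4                  # batch, seq len, dim, heads
x = torch.randn(N, T, D)
attn_vectors = torch.randn(H, D)             # one learnable vector per head

scores = torch.einsum("ntd,hd->nht", x, attn_vectors)  # score per token per head
weights = F.softmax(scores, dim=-1)                    # attention over the sequence
pooled = torch.einsum("nht,ntd->nhd", weights, x)      # weighted sum per head
features = pooled.reshape(N, H * D)                    # concatenated head outputs
```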
forward(x: Tensor)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

stable_pretraining.backbone.resnet9 module#

class stable_pretraining.backbone.resnet9.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#

Bases: Sequential

This block implements the multi-layer perceptron (MLP) module.

Parameters:
  • in_channels (int) – Number of channels of the input

  • hidden_channels (List[int]) – List of the hidden channel dimensions

  • norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None this layer won’t be used. Default: None

  • activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None this layer won’t be used. Default: torch.nn.ReLU

  • inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.

  • bias (bool) – Whether to use bias in the linear layer. Default True

  • dropout (float) – The probability for the dropout layer. Default: 0.0

class stable_pretraining.backbone.resnet9.ResidualBlock(in_channels, out_channels, kernel_size, padding, stride)[source]#

Bases: Module

A residual block as defined by He et al.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class stable_pretraining.backbone.resnet9.Resnet9(num_classes, num_channels, *args, **kwargs)[source]#

Bases: Module

A Residual network.

forward(x)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

stable_pretraining.backbone.utils module#

class stable_pretraining.backbone.utils.EfficientMaskedTimmViT(vit: Module)[source]#

Bases: Module

Optimized Vision Transformer wrapper that efficiently handles NaN patches.

This module is designed to work with timm ViT models and provides:
  • Per-sample NaN masking (different NaN patterns per image in batch)

  • Fast path for same masking pattern across batch

  • Support for class tokens (cls_token), distillation tokens (dist_token), and register tokens

  • Compatibility with various timm ViT architectures (vit_*, deit_*, beit_*, etc.)

  • Minimal overhead when no masking is present

Key optimizations:
  • Early exit when no NaN patches detected

  • Simpler indexing for same masking patterns

  • Cached batch indices for repeated operations

  • Zero-copy operations where possible

Parameters:

vit – A timm Vision Transformer model instance

Example

>>> import timm
>>> vit = timm.create_model(
...     "vit_base_patch16_224", pretrained=False, reg_tokens=4
... )
>>> masked_vit = EfficientMaskedTimmViT(vit)
>>>
>>> # Create input with some NaN patches
>>> x = torch.randn(4, 3, 224, 224)
>>> output = masked_vit(x)
Performance:
  • Same pattern masking: ~0-5% overhead vs different patterns

  • No masking: <2% overhead vs original model

  • 50% masking: ~1.5x speedup

  • 90% masking: ~2.5-3x speedup

Note

All samples in a batch must have the same NUMBER of NaN patches, but the LOCATION of NaN patches can differ per sample.

Register tokens (DINOv2 style) do NOT receive positional embeddings.
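
Inputs are prepared by filling the pixels of masked patches with NaN. A sketch, assuming a 16-pixel patch size (matching vit_base_patch16_224):

```python
import torch

x = torch.randn(4, 3, 224, 224)
# mask one 16x16 pixel patch per sample, at a different grid location each time
for b, (i, j) in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
    x[b, :, i * 16 : (i + 1) * 16, j * 16 : (j + 1) * 16] = float("nan")
# every sample now has exactly one NaN patch, satisfying the same-count rule
```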

clear_cache()[source]#

Clear the cached batch indices.

Useful if you want to free memory after processing different batch sizes. The cache will be rebuilt as needed during forward passes.

forward(x: Tensor) Tensor[source]#

Forward pass through the masked ViT.

This method implements an optimized forward pass with the following features:
  • Early exit for inputs without NaN patches (fast path)

  • Optimized indexing for same masking patterns across batch

  • Per-sample masking support with advanced indexing

  • Automatic NaN replacement for partial NaN patches

  • Support for register tokens (DINOv2 style)

Parameters:
  • x – Input tensor, either:
    - Raw images: shape (B, C, H, W)
    - Pre-patchified: shape (B, N, D) where N is number of patches

Returns:

Model output (logits if head exists, features otherwise)

Return type:

torch.Tensor

Raises:
  • ValueError – If samples have different numbers of NaN patches

  • ValueError – If all patches are NaN

Performance Notes:
  • No NaN patches: Uses fast path with <2% overhead

  • Same pattern: Optimized indexing, ~0-5% overhead vs different patterns

  • Different patterns: Uses advanced indexing, ~10-35% slower at high masking

class stable_pretraining.backbone.utils.EmbeddingOutput(last_hidden_state: Any = None, hidden_states: dict[str, Tensor] = None)[source]#

Bases: ModelOutput

HuggingFace-style output container for model embeddings.

last_hidden_state#

The final output from the backbone model.

Type:

Any

hidden_states#

Dictionary mapping layer names to their intermediate outputs.

Type:

dict[str, torch.Tensor]

hidden_states: dict[str, Tensor] = None#
last_hidden_state: Any = None#
class stable_pretraining.backbone.utils.EvalOnly(backbone: Module)[source]#

Bases: Module

Wrapper that forces a module to remain in evaluation mode.

forward(*args, **kwargs)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

train(mode)[source]#

Set the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g. Dropout, BatchNorm, etc.

Parameters:

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

self

Return type:

Module

class stable_pretraining.backbone.utils.FeaturesConcat(agg: callable, names: str | Iterable[str] = None)[source]#

Bases: Module

Aggregates and concatenates features from a dictionary or iterable input.

Parameters:

names (List[str]) – Keys to extract from the input dictionary. If not given, every entry of the dict/list is aggregated.

forward(inputs: dict | Iterable)[source]#

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

static get_output_shape(agg: callable, shapes: list[str] | Dict[str, Iterable[int]])[source]#

Given a list of shapes (tuples), returns the expected concatenated shape.

Assumes all shapes have the same batch size (shapes[0][0]).

Parameters:
  • shapes (List[Tuple[int]]) – List of shapes after aggregation.

  • agg (callable) – How to aggregate, can be None.

Returns:

The concatenated shape.

Return type:

Tuple[int]
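
What "aggregate then concatenate" amounts to for a dict input; a sketch where the aggregation callable is a global average pool (the real module takes its aggregation from the `agg` argument, and the feature names here are hypothetical):

```python
import torch

feats = {
    "stage3": torch.randn(8, 256, 28, 28),
    "stage4": torch.randn(8, 512, 14, 14),
}

def agg(t):
    return t.flatten(2).mean(-1)  # global average pool -> (B, C)

# aggregate each tensor to 2D, then concatenate along the feature axis
out = torch.cat([agg(v) for v in feats.values()], dim=1)  # (8, 256 + 512)
```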

class stable_pretraining.backbone.utils.HiddenStateExtractor(backbone: Module, module_names: list[str])[source]#

Bases: Module

Wrapper that captures intermediate embeddings from specified layers.

Returns outputs in HuggingFace Transformers style with last_hidden_state for the final backbone output and hidden_states for intermediate layers.

Parameters:
  • backbone – The neural network module to wrap.

  • module_names – List of module names to capture (e.g., [‘layer1’, ‘encoder.block1’]). Supports nested modules using dot notation.

Returns:

  • last_hidden_state: Final backbone output

  • hidden_states: Dict mapping module names to their outputs

Return type:

EmbeddingOutput

Example

>>> model = HiddenStateExtractor(
...     torchvision.models.swin_v2_s(),
...     ["features.0", "features.2", "features.4"],
... )
>>> output = model(images)
>>> output.last_hidden_state  # final output
>>> output.hidden_states["features.2"]  # intermediate layer
Raises:

ValueError – If any module name is not found in the backbone.

forward(*args, **kwargs) EmbeddingOutput[source]#

Run forward pass and return embeddings in HuggingFace style.

remove_hooks() None[source]#

Remove all registered hooks to free resources.

class stable_pretraining.backbone.utils.TeacherStudentWrapper(student: Module, warm_init: bool = True, base_ema_coefficient: float = 0.996, final_ema_coefficient: float = 1)[source]#

Bases: Module

Backbone wrapper that implements teacher-student distillation via EMA.

This is a wrapper for backbones that creates a teacher model as an exponential moving average (EMA) of the student model. It should be passed as the backbone to stable_pretraining.Module and accessed via forward_student() and forward_teacher() methods in your custom forward function.

The teacher model is updated by taking a running average of the student’s parameters and buffers. When ema_coefficient == 0.0, the teacher and student are literally the same object; this saves memory, but forward passes through the teacher will not produce any gradients.

Usage example:

backbone = ResNet18()
wrapped_backbone = TeacherStudentWrapper(backbone)
module = ssl.Module(
    backbone=wrapped_backbone,
    projector=projector,
    forward=forward_with_teacher_student,
    ...
)

Parameters:
  • student (torch.nn.Module) – The student model whose parameters will be tracked.

  • warm_init (bool, optional) – If True, performs an initialization step to match the student’s parameters immediately. Default is True.

  • base_ema_coefficient (float, optional) – EMA decay factor at the start of training. This value will be updated following a cosine schedule. Should be in [0, 1]. A value of 0.0 means the teacher is fully updated to the student’s parameters on every step, while a value of 1.0 means the teacher remains unchanged. Default is 0.996.

  • final_ema_coefficient (float, optional) – EMA decay factor at the end of training. Default is 1.

forward(*args, **kwargs)[source]#

Forward pass through either the student or teacher network.

You can choose which model to run in the default forward. Commonly the teacher is evaluated, so we default to that.

forward_student(*args, **kwargs)[source]#

Forward pass through the student network. Gradients will flow normally.

forward_teacher(*args, **kwargs)[source]#

Forward pass through the teacher network.

By default, the teacher network does not require grad. If ema_coefficient == 0, then teacher==student, so we wrap in torch.no_grad() to ensure no gradients flow.

update_ema_coefficient(epoch: int, total_epochs: int)[source]#

Update the EMA coefficient following a cosine schedule.

The EMA coefficient is updated following a cosine schedule:

ema_coefficient = final_ema_coefficient - 0.5 * (final_ema_coefficient - base_ema_coefficient) * (1 + cos(epoch / total_epochs * pi))

Parameters:
  • epoch (int) – Current epoch in the training loop.

  • total_epochs (int) – Total number of epochs in the training loop.

update_teacher()[source]#

Perform one EMA update step on the teacher’s parameters.

The update rule is:

teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param

This is done in a no_grad context to ensure the teacher’s parameters do not accumulate gradients, but the student remains fully trainable.

Everything is updated, including buffers (e.g. batch norm running averages).
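A minimal pure-Python sketch of the two update rules above (illustrative only; the library applies them to full parameter and buffer tensors):

```python
import math

def ema_coefficient(epoch, total_epochs, base=0.996, final=1.0):
    # Cosine schedule from update_ema_coefficient: yields `base` at
    # epoch 0 and `final` at epoch == total_epochs.
    return final - 0.5 * (final - base) * (1 + math.cos(epoch / total_epochs * math.pi))

# One EMA step on a scalar "parameter" (update_teacher's rule)
ema = ema_coefficient(0, 100)  # 0.996 at the start of training
teacher_param, student_param = 1.0, 0.0
teacher_param = ema * teacher_param + (1 - ema) * student_param
# teacher_param has moved slightly toward the student
```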

stable_pretraining.backbone.utils.from_huggingface(model_name, pretrained, attn_implementation='sdpa', **kwargs)[source]#

Loads a Hugging Face Transformers base model, optionally with pretrained weights, and returns the backbone model.

This function wraps the Hugging Face transformers library to load a model specified by model_name. It supports loading either pretrained weights or initializing from configuration only. The returned object is the model’s backbone (model.base_model), which is useful for extracting the core architecture without task-specific heads.

Parameters:
  • model_name (str) – The Hugging Face model repository identifier or local path. Examples include “bert-base-uncased”, “facebook/opt-1.3b”, or a local directory containing model files.

  • pretrained (bool) – If True, loads pretrained weights via AutoModel.from_pretrained. If False, initializes the model from configuration only via AutoConfig.from_pretrained and AutoModel.from_config.

  • attn_implementation (str, optional) – The attention backend to use. Supported values include “sdpa” (default), “eager”, “flash_attention_2”, etc., as supported by the installed version of transformers and your hardware. This is forwarded to the underlying model constructor.

  • **kwargs – Additional keyword arguments forwarded to AutoModel.from_pretrained or AutoConfig.from_pretrained. Common options include: - revision (str): Model version or branch to use. - cache_dir (str): Directory to cache downloaded models. - trust_remote_code (bool): Allow loading custom code from model repo. - torch_dtype (str or torch.dtype): Data type for model weights. - device_map (str or dict): Device placement for model parameters. - And others supported by Hugging Face Transformers.

Returns:

The base (backbone) model instance, typically accessible via model.base_model. For some architectures, this may be the model itself.

Return type:

transformers.PreTrainedModel

Raises:
  • ImportError – If the transformers library is not installed.

  • OSError – If the model or configuration cannot be found or downloaded.

  • ValueError – If invalid arguments are provided.

  • Exception – Propagates any other exceptions raised by Hugging Face Transformers.

Notes

  • The returned base_model may differ depending on the architecture. For some models, base_model is the same as the full model.

  • The availability of certain attention implementations (e.g., “flash_attention_2”) depends on your hardware, installed libraries, and the version of transformers.

  • Ensure that your environment meets the requirements for the selected attention backend.

Examples

>>> # Load a pretrained BERT model with default attention
>>> model = from_huggingface("bert-base-uncased", pretrained=True)
>>> # Initialize a model from config only, specifying a revision and device
>>> model = from_huggingface(
...     "facebook/opt-1.3b",
...     pretrained=False,
...     revision="main",
...     device_map="auto",
... )
>>> # Load a pretrained model using flash attention (if supported)
>>> model = from_huggingface(
...     "meta-llama/Llama-2-7b-hf",
...     pretrained=True,
...     attn_implementation="flash_attention_2",
... )
stable_pretraining.backbone.utils.from_timm(model_name, low_resolution=False, **kwargs)[source]#
stable_pretraining.backbone.utils.from_torchvision(model_name, low_resolution=False, **kwargs)[source]#

Load a backbone model.

If num_classes is provided, the last layer is replaced by a linear layer of output size num_classes. Otherwise, the last layer is replaced by an identity layer.

Parameters:
  • model_name (str) – Name of the backbone model. Supported models: any model from torchvision.models, “Resnet9”, or “ConvMixer”.

  • low_resolution (bool, optional) – Whether to adapt the resolution of the model (for CIFAR typically). By default False.

  • **kwargs – Additional keyword arguments for the model. Special handling: in_channels (int) – number of input channels; if provided for ResNet models, the first conv layer is modified to accept this many channels. Default is 3.

Returns:

The neural network model.

Return type:

torch.nn.Module

stable_pretraining.backbone.utils.get_children_modules(model: Module, parent_name: str, L: int = 1, partial_match: bool = False) List[str][source]#

Extracts unique module names matching a given parent_name and L submodules.

Parameters:
  • model – The root nn.Module.

  • parent_name – The string or path component to match (e.g., ‘blocks’).

  • L – Number of levels after the parent_name to include in the result.

  • partial_match – If True, match when parent_name is contained in the path component (in); if False, require exact equality (==).

Returns:

Sorted list of unique qualified module names at depth L after the parent_name.

stable_pretraining.backbone.utils.get_output_shape(model: Module, *inputs, **kwargs) Any[source]#

Infers the output shapes of a PyTorch nn.Module by forwarding fake inputs on the ‘meta’ device using FakeTensorMode.

Handles arbitrary nested output structures (lists, dicts, tuples, sets, namedtuples, dataclasses), preserving their structure but replacing torch.Tensor objects with their .shape. This function temporarily replaces the model’s parameters and buffers with fake tensors on the ‘meta’ device, converts all tensor inputs and keyword arguments to ‘meta’, and runs the forward pass under FakeTensorMode. After execution, the original parameters and buffers are restored. No real computation or memory allocation occurs.

Parameters:
  • model (torch.nn.Module) – The PyTorch module to evaluate. Must be on a real device (e.g., CPU).

  • *inputs – Positional arguments to pass to the model’s forward method. All torch.Tensor inputs are converted to ‘meta’.

  • **kwargs – Keyword arguments to pass to the model’s forward method. All torch.Tensor values are converted to ‘meta’.

Returns:

The output structure from the model’s forward pass, with all torch.Tensor objects replaced by their .shape.

Non-tensor objects are left unchanged.

Return type:

Any

Notes

  • Supports nested output structures: dict, list, tuple, set, namedtuple, and dataclasses.

  • No real memory is allocated; all tensors are on the ‘meta’ device.

  • Not thread-safe: concurrent calls may interfere with parameter/buffer swapping.

  • Requires PyTorch 1.11+ for FakeTensorMode.

  • If the model contains custom buffers or state, ensure they are handled appropriately.

  • Raises exceptions if model forward fails or if parameters/buffers cannot be swapped.

  • Non-tensor outputs are returned unchanged.

Example

>>> shapes = get_output_shape(model, input1, input2, key1=kwarg1)
>>> # `shapes` mirrors the structure of the model's output, with torch.Size objects in place of tensors.

stable_pretraining.backbone.utils.register_lr_scale_hook(module, lr_scale, weight_decay=0.0)[source]#

Registers a hook that scales gradients and applies weight decay during backward pass.

Parameters:
  • module – PyTorch module/layer

  • lr_scale – Scaling factor for the learning rate (scales gradients)

  • weight_decay – L2 penalty coefficient (default: 0.0)

Returns:

The same module (for chaining)

Return type:

module
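The effect can be sketched in pure Python (a hypothetical illustration of the math, not the library's hook; the actual ordering of scaling and decay in the implementation may differ):

```python
def scaled_grad(grad, param, lr_scale, weight_decay=0.0):
    # Add the L2 penalty's gradient (weight_decay * param),
    # then scale the result by lr_scale.
    return lr_scale * (grad + weight_decay * param)

# With weight_decay=0, the gradient is simply scaled:
print(scaled_grad(0.5, 2.0, lr_scale=0.1))  # → 0.05
```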

stable_pretraining.backbone.utils.set_embedding_dim(module, dim, bias=True, expected_input_shape: tuple | list | None = None, expected_output_shape: tuple | list | None = None)[source]#
stable_pretraining.backbone.utils.vit_hf(size: str = 'tiny', patch_size: int = 16, image_size: int = 224, pretrained: bool = False, use_mask_token: bool = True, **kwargs) Module[source]#

Create a Vision Transformer using HuggingFace transformers.

This provides a clean, well-maintained ViT implementation with native support for:

  • Masking via the bool_masked_pos parameter

  • A learnable mask token

  • Easy access to CLS and patch tokens

Parameters:
  • size – Model size - “tiny”, “small”, “base”, or “large”

  • patch_size – Patch size (default: 16)

  • image_size – Input image size (default: 224)

  • pretrained – Load pretrained weights from HuggingFace Hub

  • use_mask_token – Whether to include learnable mask token (needed for iBOT)

  • **kwargs – Additional ViTConfig parameters

Returns:

HuggingFace ViTModel

Example

>>> backbone = vit_hf("tiny", use_mask_token=True)
>>> x = torch.randn(2, 3, 224, 224)
>>>
>>> # Without masking
>>> output = backbone(x)
>>> cls_token = output.last_hidden_state[:, 0, :]
>>> patch_tokens = output.last_hidden_state[:, 1:, :]
>>>
>>> # With masking (for iBOT student)
>>> masks = torch.zeros(2, 196, dtype=torch.bool)
>>> masks[:, :59] = True  # Mask 30%
>>> output = backbone(x, bool_masked_pos=masks)

stable_pretraining.backbone.vit module#

class stable_pretraining.backbone.vit.Attention(dim: int, num_heads: int = 8, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, use_rope: bool = False, use_qk_norm: bool = False, max_grid_size: int = 32)[source]#

Bases: Module

Multi-head self-attention with efficient SDPA backend.

Supports modern transformer features including Rotary Position Embeddings (RoPE) and Query-Key Normalization (QK-Norm) for improved training stability and positional generalization.

Uses F.scaled_dot_product_attention which automatically selects the optimal backend:

  • Flash Attention (when available, fastest)

  • Memory-efficient attention (xformers-style)

  • Math fallback

Architecture Features#

RoPE (Rotary Position Embedding):

Encodes relative 2D positions via complex rotations applied to Q and K. Unlike additive positional embeddings, RoPE:

  • Naturally captures relative positions

  • Generalizes to unseen sequence lengths

  • Requires no extra parameters

Enable with use_rope=True. Requires grid_size in forward().

QK-Norm (Query-Key Normalization):

Applies LayerNorm (without learnable params) to Q and K before computing attention scores. Benefits:

  • Prevents attention logit explosion in deep networks

  • Stabilizes training without extra hyperparameter tuning

  • Essential when combined with SwiGLU/LayerScale

Enable with use_qk_norm=True.

Attention Masking#

Supports flexible attention patterns via attn_mask:

  • Causal (autoregressive): torch.triu(ones, diagonal=1)

  • Bidirectional: attn_mask=None

  • Block sparse: Custom boolean masks

  • Leave-one-out: torch.eye(N) (each token ignores itself)

Mask convention: True = blocked (cannot attend), False = allowed.
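As a pure-Python sketch of the convention above (True = blocked, False = allowed), the causal and leave-one-out patterns look like:

```python
N = 4

# Causal: position i may attend only to j <= i, so j > i is blocked.
causal = [[j > i for j in range(N)] for i in range(N)]

# Leave-one-out: each token attends to everything except itself.
leave_one_out = [[i == j for j in range(N)] for i in range(N)]

print(causal[0])  # → [False, True, True, True]
```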

param dim:

Input/output embedding dimension

param num_heads:

Number of parallel attention heads. Must divide dim.

param qkv_bias:

If True, add learnable bias to Q, K, V projections. Default True following ViT convention.

param attn_drop:

Dropout probability on attention weights. Applied only during training.

param proj_drop:

Dropout probability on output projection.

param use_rope:

Enable 2D Rotary Position Embedding. When True, position information is encoded via rotation in attention rather than additive embeddings. Requires grid_size parameter in forward().

param use_qk_norm:

Enable Query-Key normalization. Applies LayerNorm (without learnable parameters) to Q and K tensors before attention. Recommended for deep networks or when using SwiGLU.

param max_grid_size:

Maximum spatial grid size for RoPE frequency cache. Only used when use_rope=True. Set to largest expected grid dimension.

Example:

# Standard attention
attn = Attention(dim=768, num_heads=12)
out = attn(x)  # [B, N, 768]

# With RoPE for vision (requires grid_size)
attn = Attention(dim=768, num_heads=12, use_rope=True)
out = attn(x, grid_size=(14, 14))

# With QK-Norm for training stability
attn = Attention(dim=768, num_heads=12, use_qk_norm=True)
out = attn(x)

# Causal attention (autoregressive)
N = x.shape[1]
causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
out = attn(x, attn_mask=causal_mask)

# NEPA-style: RoPE + QK-Norm + causal
attn = Attention(dim=768, num_heads=12, use_rope=True, use_qk_norm=True)
causal_mask = torch.triu(
    torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1
)
out = attn(x, attn_mask=causal_mask, grid_size=(14, 14))

Note

When use_rope=True, do NOT add positional embeddings to input tokens. RoPE encodes positions internally via Q/K rotation.

References

  • RoPE: Su et al., “RoFormer: Enhanced Transformer with Rotary Position Embedding” (2021)

  • QK-Norm: Henry et al., “Query-Key Normalization for Transformers” (2020)

  • Flash Attention: Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention” (2022)

extra_repr() str[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x: Tensor, attn_mask: Tensor | None = None, grid_size: Tuple[int, int] | None = None) Tensor[source]#

Compute multi-head self-attention.

Parameters:
  • x – Input tensor of shape [B, N, D] where B is batch size, N is sequence length, and D is embedding dimension.

  • attn_mask

    Optional attention mask. Supported shapes:

    • [N, N]: Same mask for all batches and heads

    • [B, N, N]: Per-batch mask, broadcast over heads

    • [B, H, N, N]: Full per-batch, per-head mask

    Values: True = blocked (cannot attend), False = allowed.

    Common patterns:

    • Causal: torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)

    • Leave-one-out: torch.eye(N, dtype=torch.bool)

  • grid_size – Spatial grid dimensions as (height, width). Required when use_rope=True. Used to compute 2D rotary position embeddings. For a 224x224 image with patch_size=16, use grid_size=(14, 14).

Returns:

Output tensor of shape [B, N, D]

Raises:

ValueError – If use_rope=True but grid_size is None

class stable_pretraining.backbone.vit.CrossAttention(dim: int, context_dim: int | None = None, num_heads: int = 8, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0)[source]#

Bases: Module

Multi-head cross-attention with efficient SDPA backend.

Queries attend to key-value pairs from a separate context sequence. Supports attention masking to block specific query-key interactions.

Parameters:
  • dim – Query dimension

  • context_dim – Context dimension (defaults to dim)

  • num_heads – Number of attention heads

  • qkv_bias – Add bias to projections

  • attn_drop – Attention dropout rate

  • proj_drop – Output projection dropout rate

Example:

cross_attn = CrossAttention(dim=768, context_dim=1024, num_heads=12)

# Standard cross-attention
out = cross_attn(queries, context)  # [B, N, 768]

# Masked cross-attention: block certain query-context pairs
# mask[i, j] = True means query i cannot attend to context j
mask = torch.zeros(N, M, dtype=torch.bool)
mask[:, :10] = True  # block attention to first 10 context tokens
out = cross_attn(queries, context, attn_mask=mask)
forward(x: Tensor, context: Tensor, attn_mask: Tensor | None = None) Tensor[source]#

Forward pass.

Parameters:
  • x – Query tensor [B, N, D]

  • context – Key-value tensor [B, M, context_dim]

  • attn_mask – Cross-attention mask. Can be one of: - [N, M]: Same mask for all batches and heads - [B, N, M]: Per-batch mask, broadcast over heads - [B, H, N, M]: Full per-batch, per-head mask Mask values: True = blocked (cannot attend), False = allowed.

Returns:

Output tensor [B, N, D]

class stable_pretraining.backbone.vit.FlexibleTransformer(input_dim: int = 768, hidden_dim: int = 384, output_dim: int = 768, num_patches: int = 196, depth: int = 4, num_heads: int = 6, mlp_ratio: float = 4.0, self_attn: bool = True, cross_attn: bool = True, use_adaln: bool = True, pos_embed_type: Literal['sincos_1d', 'sincos_2d', 'learned'] = 'sincos_2d', grid_size: int | tuple[int, int] | None = None, drop_path_rate: float = 0.0, attn_drop: float = 0.0, proj_drop: float = 0.0, zero_init_output: bool = True, num_prefix_tokens: int = 1, num_registers: int = 0, add_mask_token: bool = False)[source]#

Bases: Module

Flexible transformer supporting multiple architectures.

Unified backbone for:

  • MAE decoder: self_attn=True, cross_attn=False, use_adaln=False

  • IJEPA predictor: self_attn=True, cross_attn=True, use_adaln=False

  • DiT / Flow: self_attn=True, cross_attn=True/False, use_adaln=True

  • MaskGIT: self_attn=True, cross_attn=False, use_adaln=True, add_mask_token=True

  • Lightweight predictor: self_attn=True, cross_attn=False, use_adaln=False, num_registers>0

  • Leave-one-out prediction: self_attn=True, cross_attn=False with diagonal attn_mask

Parameters:
  • input_dim – Input embedding dimension (from encoder)

  • hidden_dim – Internal transformer dimension

  • output_dim – Output dimension

  • num_patches – Total number of patches (for positional embeddings)

  • depth – Number of transformer blocks

  • num_heads – Number of attention heads

  • mlp_ratio – MLP hidden dim multiplier

  • self_attn – Enable self-attention in blocks

  • cross_attn – Enable cross-attention in blocks

  • use_adaln – Enable AdaLN-Zero conditioning

  • pos_embed_type – ‘sincos_1d’, ‘sincos_2d’, or ‘learned’

  • grid_size – Grid size for 2D positional embeddings

  • drop_path_rate – Stochastic depth rate (linearly increases through layers)

  • attn_drop – Attention dropout rate

  • proj_drop – Projection dropout rate

  • zero_init_output – Zero-initialize output projection

  • num_prefix_tokens – Number of prefix tokens (e.g., CLS token) expected in input. These are tokens whose content comes from the encoder but need special positional embeddings.

  • num_registers – Number of learnable register tokens to prepend internally. Unlike prefix tokens, registers are fully learnable (both content and position) and are prepended automatically—callers don’t include them in input.

  • add_mask_token – Enable learnable [MASK] token for masked prediction. When enabled, use context_mask and/or query_mask in forward() to replace tokens at specified positions with the [MASK] token.

Example:

# MAE decoder
decoder = FlexibleTransformer(
    768,
    512,
    768,
    196,
    depth=8,
    self_attn=True,
    cross_attn=False,
    use_adaln=False,
)
out = decoder(context, queries, context_idx, query_idx)

# IJEPA predictor
predictor = FlexibleTransformer(
    768,
    384,
    768,
    196,
    depth=6,
    self_attn=True,
    cross_attn=False,
    add_mask_token=True,
    use_adaln=False,
)
out = predictor(context, queries, context_idx, query_idx)

# DiT-style flow matching
flow = FlexibleTransformer(
    768,
    384,
    768,
    196,
    depth=12,
    self_attn=True,
    cross_attn=False,
    use_adaln=True,
)
out = flow(context, queries, context_idx, query_idx, t=timesteps)

# MaskGIT-style: variable number of masks per sample
maskgit = FlexibleTransformer(
    768,
    512,
    768,
    196,
    depth=8,
    self_attn=True,
    cross_attn=False,
    use_adaln=True,
    add_mask_token=True,
)
context_mask = torch.rand(B, num_patches) < mask_ratio
out = maskgit(
    context=all_patches,
    queries=all_patches[:, :0],
    context_idx=torch.arange(196).expand(B, -1),
    query_idx=torch.empty(B, 0, dtype=torch.long),
    context_mask=context_mask,
    t=timesteps,
    return_all=True,
)

# Leave-one-out prediction: each token predicted from all others
predictor = FlexibleTransformer(
    768,
    384,
    768,
    196,
    depth=4,
    self_attn=True,
    cross_attn=False,
    use_adaln=False,
)
# Diagonal mask: each token cannot attend to itself
T = x.shape[1]
attn_mask = torch.eye(T, dtype=torch.bool, device=x.device)
out = predictor(
    context=x,
    queries=x[:, :0],  # empty queries
    context_idx=torch.arange(T).expand(B, -1),
    query_idx=torch.empty(B, 0, dtype=torch.long),
    attn_mask=attn_mask,  # [T, T] bool, True = blocked
    return_all=True,
)  # out[:, t] is predicted from x[:, ≠t]

# Lightweight predictor with register tokens
predictor = FlexibleTransformer(
    768,
    384,
    768,
    196,
    depth=4,
    num_heads=6,
    self_attn=True,
    cross_attn=False,
    use_adaln=False,
    num_registers=4,
    num_prefix_tokens=0,
)
out, registers = predictor(
    context=encoder_output,
    queries=encoder_output[:, :0],
    context_idx=ids_keep,
    query_idx=torch.empty(B, 0, dtype=torch.long),
    return_all=True,
    return_registers=True,
)
forward(context: Tensor, queries: Tensor = None, context_idx: Tensor = None, query_idx: Tensor = None, t: Tensor | None = None, num_prefix: int | None = None, return_all: bool = False, return_registers: bool = False, context_mask: Tensor | None = None, query_mask: Tensor | None = None, attn_mask: Tensor | None = None) Tensor | tuple[Tensor, Tensor][source]#

Forward pass.

Parameters:
  • context – Context token embeddings [B, N_ctx, input_dim]

  • queries – Query token embeddings [B, N_qry, input_dim]

  • context_idx – Patch indices for context tokens [B, N_ctx]

  • query_idx – Patch indices for query tokens [B, N_qry]

  • t – Timesteps for conditioning [B] (required if use_adaln=True)

  • num_prefix – Override for number of prefix tokens in context

  • return_all – If True and using joint attention (cross_attn=False), return all tokens unshuffled to original position order. Output shape: [B, N_ctx + N_qry, output_dim]. Ignored for cross-attention modes.

  • return_registers – If True and num_registers > 0, also return register token outputs as a second tensor. Returns tuple of (main_output, register_output) where register_output is [B, num_registers, output_dim].

  • context_mask – Boolean mask indicating which context tokens to replace with [MASK] token [B, N_ctx]. True = replace with mask. Each sample can have a different number of True values. Requires add_mask_token=True.

  • query_mask – Boolean mask indicating which query tokens to replace with [MASK] token [B, N_qry]. True = replace with mask. Each sample can have a different number of True values. Requires add_mask_token=True.

  • attn_mask – Attention mask for self-attention [T, T] or [B, T, T]. True = blocked (cannot attend), False = allowed. For leave-one-out prediction, use torch.eye(T, dtype=torch.bool). Only applies to joint attention mode (cross_attn=False). The mask is automatically expanded to account for registers.

Returns:

Output embeddings. Shape depends on mode: - cross_attn=True: [B, N_qry, output_dim] - cross_attn=False, return_all=False: [B, N_qry, output_dim] - cross_attn=False, return_all=True: [B, N_ctx + N_qry, output_dim] If return_registers=True, returns tuple (output, registers) where registers is [B, num_registers, output_dim].

class stable_pretraining.backbone.vit.MAEDecoder(embed_dim: int = 768, decoder_embed_dim: int = 512, output_dim: int = 768, num_patches: int = 196, depth: int = 4, num_heads: int = 16, mlp_ratio: float = 4.0, pos_embed_type: Literal['sincos_1d', 'sincos_2d', 'learned'] = 'sincos_2d', grid_size: int | None = None, drop_path_rate: float = 0.0)[source]#

Bases: Module

MAE-style Vision Transformer Decoder using FlexibleTransformer.

Implements the decoder component of Masked Autoencoders (MAE) [1] for self-supervised visual representation learning. The decoder reconstructs masked patches from visible patch embeddings using joint self-attention, where visible tokens and learnable mask tokens attend to each other. The decoder is intentionally lightweight compared to the encoder, as MAE demonstrates that a shallow decoder is sufficient for pixel reconstruction while keeping the encoder focused on learning semantic representations.

Architecture Overview:

  1. Input projection: Maps encoder embeddings (embed_dim) to the decoder dimension (decoder_embed_dim)

  2. Mask token expansion: Learnable mask tokens are placed at masked positions

  3. Positional encoding: Adds position information to all tokens

  4. Transformer blocks: Joint self-attention over visible + mask tokens

  5. Output projection: Maps to output_dim (typically patch_size² × channels)

Parameters:
  • embed_dim (int, default=768) – Embedding dimension from the encoder. This is the input dimension of visible tokens passed to the decoder.

  • decoder_embed_dim (int, default=512) – Internal hidden dimension of the decoder transformer blocks. Typically smaller than embed_dim for efficiency.

  • output_dim (int, default=768) – Output dimension per token. For pixel reconstruction, this should be patch_size ** 2 * in_channels (e.g., 16×16×3 = 768 for RGB).

  • num_patches (int, default=196) – Total number of patches T in the image (e.g., 14×14 = 196 for 224×224 images with patch_size=16).

  • depth (int, default=4) – Number of transformer blocks in the decoder. MAE typically uses fewer blocks than the encoder (e.g., 4-8 vs 12-24).

  • num_heads (int, default=16) – Number of attention heads in multi-head self-attention.

  • mlp_ratio (float, default=4.0) – Expansion ratio for the MLP hidden dimension relative to decoder_embed_dim.

  • pos_embed_type ({'sincos_1d', 'sincos_2d', 'learned'}, default='sincos_2d') – Type of positional embedding: - ‘sincos_2d’: Fixed 2D sinusoidal (recommended for images) - ‘sincos_1d’: Fixed 1D sinusoidal - ‘learned’: Learnable positional embeddings

  • grid_size (int, optional) – Spatial grid size for 2D positional embeddings. If None, inferred as int(sqrt(num_patches)). Required for non-square grids.

  • drop_path_rate (float, default=0.0) – Stochastic depth rate for regularization during training.

Attributes:
  • mask_token (nn.Parameter) – Learnable token of shape (1, 1, embed_dim) used to represent masked positions. Initialized with truncated normal (std=0.02).

  • transformer (FlexibleTransformer) – Core transformer module handling attention and projections.

Notes

  • The mask convention follows MAE: 0 = visible, 1 = masked.

  • The decoder receives visible tokens and reconstructs masked positions.

  • For efficiency, the decoder is deliberately lighter than the encoder.

References

[1] He, K. et al., “Masked Autoencoders Are Scalable Vision Learners.” CVPR 2022. https://arxiv.org/abs/2111.06377

Examples

Basic Usage with MAE Encoder

>>> import torch
>>> import torch.nn as nn
>>>
>>> # Configuration matching ViT-Base
>>> B, T = 4, 196  # batch size, num_patches (14x14)
>>> embed_dim = 768  # encoder dimension
>>> mask_ratio = 0.75  # MAE default
>>>
>>> # Initialize decoder
>>> decoder = MAEDecoder(
...     embed_dim=embed_dim,
...     decoder_embed_dim=512,
...     output_dim=16 * 16 * 3,  # patch_size² × channels = 768
...     num_patches=T,
...     depth=4,
...     num_heads=16,
... )
>>>
>>> # Simulate encoder output (visible tokens only)
>>> N_vis = int(T * (1 - mask_ratio))  # 49 visible patches
>>> visible_tokens = torch.randn(B, N_vis, embed_dim)
>>>
>>> # Create random mask (0=visible, 1=masked)
>>> mask = torch.zeros(B, T)
>>> for i in range(B):
...     masked_indices = torch.randperm(T)[: T - N_vis]
...     mask[i, masked_indices] = 1
>>>
>>> # Decode - predict masked patches only
>>> pred_masked = decoder(visible_tokens, mask, output_masked_only=True)
>>> print(pred_masked.shape)  # [B, N_mask, output_dim]
torch.Size([4, 147, 768])

Full Sequence Reconstruction

>>> # Get predictions for ALL positions (for visualization)
>>> pred_full = decoder(visible_tokens, mask, output_masked_only=False)
>>> print(pred_full.shape)  # [B, T, output_dim]
torch.Size([4, 196, 768])

Using Full Sequence Input

If you have the full sequence with mask tokens already inserted:

>>> full_sequence = torch.randn(B, T, embed_dim)  # [B, 196, 768]
>>> pred = decoder(full_sequence, mask, output_masked_only=True)
>>> print(pred.shape)
torch.Size([4, 147, 768])

Integration with MAE Training Loop

>>> # Typical MAE training step (pseudocode)
>>> def mae_forward(encoder, decoder, images, mask_ratio=0.75):
...     # Patchify and mask
...     patches = patchify(images)  # [B, T, patch_dim]
...     mask = random_mask(B, T, mask_ratio)  # [B, T], 0=keep, 1=mask
...
...     # Encode visible patches only
...     visible_patches = patches[~mask.bool()].reshape(B, -1, patch_dim)
...     latent = encoder(visible_patches)  # [B, N_vis, embed_dim]
...
...     # Decode to predict masked patches
...     pred = decoder(latent, mask, output_masked_only=True)  # [B, N_mask, output_dim]
...
...     # Reconstruction loss on masked patches only
...     target = patches[mask.bool()].reshape(B, -1, patch_dim)
...     loss = F.mse_loss(pred, target)
...     return loss

Custom Configuration for ViT-Large

>>> decoder_large = MAEDecoder(
...     embed_dim=1024,  # ViT-L encoder dim
...     decoder_embed_dim=512,  # keep decoder lightweight
...     output_dim=768,  # 16×16×3 pixels
...     num_patches=256,  # 16×16 patches for 256×256 images
...     depth=8,  # slightly deeper
...     num_heads=16,
...     pos_embed_type="sincos_2d",
...     drop_path_rate=0.1,  # regularization
... )

See Also

FlexibleTransformer – Core transformer implementation used internally.

forward(x: Tensor, mask: Tensor, ids_keep: Tensor | None = None, output_masked_only: bool = False) Tensor[source]#

Forward pass.

Parameters:
  • x – Visible tokens [B, N_vis, D] or full sequence [B, T, D]

  • mask – Binary mask [B, T], 0=kept, 1=masked

  • ids_keep – Indices of kept (visible) patches (B, N_keep)

  • output_masked_only – If True, return [B, N_mask, D]. If False, return [B, T, D].

Returns:

Predictions
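The mask / ids_keep convention above (0 = kept, 1 = masked) can be produced with a small helper. The following is a hypothetical sketch, not a function exported by the package:

```python
import torch


def random_mask(B: int, T: int, mask_ratio: float = 0.75):
    """Return (mask, ids_keep): mask is [B, T] with 0=kept, 1=masked."""
    noise = torch.rand(B, T)
    ids_shuffle = noise.argsort(dim=1)  # random permutation per sample
    n_keep = int(T * (1 - mask_ratio))
    mask = torch.ones(B, T)
    mask.scatter_(1, ids_shuffle[:, :n_keep], 0.0)  # first n_keep stay visible
    return mask, ids_shuffle[:, :n_keep]


mask, ids_keep = random_mask(4, 196)
print(mask.sum(dim=1))  # 147 masked patches per sample
print(ids_keep.shape)   # torch.Size([4, 49])
```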

class stable_pretraining.backbone.vit.MaskedEncoder(model_or_model_name: str | Module = 'vit_base_patch16_224', masking: PatchMasking | None = None, pretrained: bool = False, img_size: int | Tuple[int, int] | None = None, patch_size: int | Tuple[int, int] | None = None, dynamic_img_size: bool = False, norm_layer: Module | None = None)[source]#

Bases: Module

Vision Transformer encoder with optional masking support.

Wraps a timm ViT model and adds flexible masking via PatchMasking. Handles all ViT internals: patch embedding, positional embeddings, prefix tokens (CLS, registers), and transformer blocks.

Parameters:
  • model_or_model_name – timm model name string or pre-instantiated nn.Module

  • masking – PatchMasking instance. If None, no masking is applied.

  • pretrained – Load pretrained weights (only when model_or_model_name is a str)

  • img_size – Override default image size

  • patch_size – Override default patch size (will reinitialize patch_embed)

  • dynamic_img_size – Enable dynamic image size support with pos_embed interpolation

Example:

from spt.backbone import PatchMasking, MaskedEncoder

masking = PatchMasking(mask_ratio=0.75, block_size=4)
encoder = MaskedEncoder(
    model_or_model_name="vit_base_patch16_224",
    masking=masking,
    pretrained=True,
)
images = torch.randn(4, 3, 224, 224)
output = encoder(images)
print(output.encoded.shape)  # (4, 1 + 49, 768) with 75% masking
print(output.mask.shape)  # (4, 196)
print(output.ids_keep.shape)  # (4, 49)
extra_repr() str[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(images: Tensor) MaskedEncoderOutput[source]#

Encode images with optional masking.

Parameters:

images – Input images (B, C, H, W)

Returns:

MaskedEncoderOutput with encoded tokens and mask info

forward_features(images: Tensor) Tensor[source]#

Encode without masking (for inference).

class stable_pretraining.backbone.vit.MaskedEncoderOutput(encoded: Tensor = None, mask: Tensor = None, ids_keep: Tensor = None, grid_size: Tuple[int, int] = None)[source]#

Bases: ModelOutput

Output from MaskedEncoder forward pass.

Variables:
  • encoded – Encoded token representations (B, num_prefix + N_visible, D)

  • mask – Binary mask where 1 = masked, 0 = visible (B, N_patches)

  • ids_keep – Indices of visible patches (B, N_visible)

  • grid_size – Patch grid dimensions (height, width)

encoded: Tensor = None#
grid_size: Tuple[int, int] = None#
ids_keep: Tensor = None#
mask: Tensor = None#
class stable_pretraining.backbone.vit.PositionalEncoding2D(embed_dim: int, grid_size: Tuple[int, int], pos_type: Literal['learnable', 'sinusoidal', 'rope', 'none'] = 'learnable', num_prefix_tokens: int = 1, learnable: bool | None = None)[source]#

Bases: Module

Flexible 2D positional encoding for vision transformers.

forward(x: Tensor, grid_size: Tuple[int, int] | None = None) Tensor[source]#

Apply positional encoding.

Parameters:
  • x – [B, num_prefix + num_patches, D]

  • grid_size – (H, W) if different from default (for dynamic size)

Returns:

x with positional encoding applied
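As a rough illustration of the "learnable" mode, here is a self-contained sketch (a hypothetical standalone module assuming prefix tokens also receive learned positions; the actual class supports more modes and dynamic grid sizes):

```python
import torch
import torch.nn as nn


class LearnablePos2D(nn.Module):
    """Minimal sketch of a learnable 2D positional embedding."""

    def __init__(self, embed_dim: int, grid_size: tuple, num_prefix_tokens: int = 1):
        super().__init__()
        num_tokens = num_prefix_tokens + grid_size[0] * grid_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Additive positional encoding, broadcast over the batch dimension
        return x + self.pos_embed


pos = LearnablePos2D(embed_dim=768, grid_size=(14, 14))
x = torch.randn(4, 197, 768)  # [CLS] + 14x14 patch tokens
print(pos(x).shape)  # torch.Size([4, 197, 768])
```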

class stable_pretraining.backbone.vit.QKNorm(head_dim: int)[source]#

Bases: Module

Query-Key Normalization for attention stabilization.

Applies LayerNorm (without learnable parameters) independently to query and key tensors before computing attention scores. This simple technique dramatically improves training stability in deep transformers.

Why QK-Norm Works
-----------------
In deep transformers, attention logits (Q·Kᵀ) can grow unboundedly large, causing:
- Gradient explosion: large logits → extreme softmax → tiny gradients
- Attention collapse: all mass on a single token
- Training instability: requires very small learning rates

QK-Norm bounds the attention logits by normalizing Q and K to unit variance:
- Attention logits become bounded: |q·k| ≤ ‖q‖ ‖k‖ = O(√d)
- Gradients remain stable throughout training
- Enables larger learning rates and faster convergence

Implementation Details
----------------------
- Uses LayerNorm without learnable parameters (γ=1, β=0)
- Normalizes per head: applied to the [..., head_dim] dimension
- Zero computational overhead in modern frameworks (fused with attention)

Parameters:

head_dim – Dimension per attention head. Each head is normalized independently to preserve multi-head diversity.

Example::

    # In attention forward pass
    qk_norm = QKNorm(head_dim=64)
    # q, k shape: [B, num_heads, seq_len, head_dim]
    q, k = qk_norm(q, k)
    # Now safe to compute attention
    attn = (q @ k.transpose(-2, -1)) * scale

Example integration with Attention::

    class Attention(nn.Module):
        def __init__(self, dim, num_heads, use_qk_norm=True):
            ...
            if use_qk_norm:
                self.qk_norm = QKNorm(dim // num_heads)

        def forward(self, x):
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            if self.use_qk_norm:
                q, k = self.qk_norm(q, k)
            ...

Note

QK-Norm is especially important when combined with:
- SwiGLU: gated activations can amplify hidden states
- LayerScale: small initial residual scale needs stable attention
- Deep networks: logit growth compounds with depth

Without QK-Norm, these combinations often fail to train or require extensive hyperparameter tuning.

References

  • Henry et al., “Query-Key Normalization for Transformers” (EMNLP 2020)

  • Dehghani et al., “Scaling Vision Transformers to 22B Parameters” (2023)

  • Wortsman et al., “Small-scale proxies for large-scale Transformer training” (2023)

extra_repr() str[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(q: Tensor, k: Tensor) tuple[Tensor, Tensor][source]#

Normalize query and key tensors independently.

Parameters:
  • q – Query tensor of shape [B, num_heads, seq_len, head_dim] or any shape with last dimension = head_dim

  • k – Key tensor of same shape as q

Returns:

Tuple of (normalized_q, normalized_k) with same shapes

Note

Normalization is applied to the last dimension (head_dim). Each head is normalized independently, preserving multi-head representation diversity.
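Functionally, the normalization is equivalent to a parameter-free layer norm over the head dimension. A minimal functional sketch using `F.layer_norm` (an assumption about the implementation, based on the docstring above):

```python
import torch
import torch.nn.functional as F


def qk_norm(q: torch.Tensor, k: torch.Tensor):
    """LayerNorm (no learnable params) over the last (head_dim) axis."""
    d = q.shape[-1]
    return F.layer_norm(q, (d,)), F.layer_norm(k, (d,))


q = torch.randn(2, 8, 16, 64)  # [B, num_heads, seq_len, head_dim]
k = torch.randn(2, 8, 16, 64)
qn, kn = qk_norm(q, k)
print(qn.shape)  # torch.Size([2, 8, 16, 64]) -- shapes unchanged
```

After this step, each head's query/key vectors have zero mean and (approximately) unit variance, so the attention logits stay bounded regardless of depth.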

class stable_pretraining.backbone.vit.SwiGLU(in_features: int, hidden_features: int, out_features: int | None = None, bias: bool = False, drop: float = 0.0)[source]#

Bases: Module

SwiGLU: Gated Linear Unit with Swish activation.

A parameter-efficient gated activation that combines the benefits of gating mechanisms with the smooth, non-monotonic Swish activation. Empirically improves transformer performance over standard GeLU MLPs.

Architecture
------------
Standard MLP::

    x → Linear → GeLU → Linear → out
    Parameters: 2 * d * h

SwiGLU::

    x → Linear(W₁) → SiLU ─┐
                           ├─ element-wise multiply → Linear(W₃) → out
    x → Linear(W₂) ────────┘
    Parameters: 3 * d * h' (where h' = 2h/3 to match param count)

The hidden dimension is scaled to 2/3 * hidden_features so that total parameter count matches a standard 2-layer MLP: 3 * d * (2h/3) = 2 * d * h

Performance Benefits
--------------------
- Better gradient flow: gating provides multiplicative paths
- Smoother optimization: SiLU (Swish) is smooth and non-monotonic
- Quality: consistently outperforms GeLU in language and vision models

Parameters:
  • in_features – Input dimension

  • hidden_features – Nominal hidden dimension. Actual hidden size is int(2 * hidden_features / 3) to maintain parameter parity with standard MLPs.

Parameters:
  • out_features – Output dimension. Defaults to in_features.

  • bias – If True, use bias in linear layers. Default False following LLaMA/PaLM convention for better training stability.

  • drop – Dropout probability applied after gating.

Example::

    # Replace standard MLP in transformer
    # Old: mlp = Mlp(768, 3072, 768)
    # New: mlp = SwiGLU(768, 3072, 768)
    x = torch.randn(4, 196, 768)
    out = mlp(x)  # [4, 196, 768]

    # Parameter count comparison
    standard_mlp = nn.Sequential(
        nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768)
    )
    swiglu = SwiGLU(768, 3072, 768)
    print(sum(p.numel() for p in standard_mlp.parameters()))  # 4,722,432
    print(sum(p.numel() for p in swiglu.parameters()))  # 4,722,432 (same!)

Note

For best results, combine SwiGLU with:
- LayerScale: stabilizes residual connections
- QK-Norm: prevents attention explosion
- RoPE: better positional generalization

This combination is used in LLaMA, PaLM, and modern vision transformers.

References

  • Shazeer, “GLU Variants Improve Transformer” (2020)

  • Touvron et al., “LLaMA: Open and Efficient Foundation Language Models” (2023)

extra_repr() str[source]#

Return the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x: Tensor) Tensor[source]#

Apply SwiGLU transformation.

class stable_pretraining.backbone.vit.TransformerBlock(dim: int, num_heads: int, mlp_ratio: float = 4.0, self_attn: bool = True, cross_attn: bool = False, use_adaln: bool = False, use_rope: bool = False, use_qk_norm: bool = False, mlp_type: Literal['gelu', 'swiglu'] = 'gelu', use_layer_scale: bool = False, layer_scale_init: float = 1e-05, drop_path: float = 0.0, attn_drop: float = 0.0, proj_drop: float = 0.0, max_grid_size: int = 32, act_layer: type = <class 'torch.nn.modules.activation.GELU'>)[source]#

Bases: Module

Unified transformer block supporting multiple architectures.

Configurable for various modern transformer designs:

    Architecture       RoPE   QK-Norm   MLP      LayerScale   AdaLN
    ----------------------------------------------------------------
    Standard ViT       ✗      ✗         gelu     ✗            ✗
    DINOv2 / Modern    ✓      ✓         swiglu   ✓            ✗
    NEPA               ✓      ✓         swiglu   ✓            ✗
    DiT / Flow         ✗      ✗         gelu     ✗            ✓

Attention Modes#

Mode 1: Self-Attention Only (self_attn=True, cross_attn=False)

Standard encoder block. Used for NEPA, ViT encoder, etc.

Mode 2: Cross-Attention Only (self_attn=False, cross_attn=True)

Queries attend to context only. Lightweight decoder.

Mode 3: Full Decoder (self_attn=True, cross_attn=True)

Self-attention on queries, then cross-attention to context.

Modern Components#

RoPE (use_rope=True):

2D Rotary Position Embedding. Encodes positions via Q/K rotation. Requires grid_size in forward(). Don’t use additive pos_embed.

QK-Norm (use_qk_norm=True):

Normalizes Q and K before attention. Stabilizes deep networks.

SwiGLU (mlp_type='swiglu'):

Gated MLP with SiLU activation. Better than GeLU empirically.

LayerScale (use_layer_scale=True):

Learnable per-channel scaling on residuals. Stabilizes training. Initialize near zero (e.g., 1e-5) for identity-like initialization.
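LayerScale itself is tiny; a minimal sketch of the idea (a hypothetical standalone module mirroring the description above, not the package's internal class):

```python
import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """Learnable per-channel scale for a residual branch.

    Initialized near zero so the block starts close to identity.
    """

    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x  # broadcast over batch and token dims


ls = LayerScale(768)
x = torch.randn(4, 196, 768)
out = ls(x)  # initially ~1e-5 * x, so the residual branch barely perturbs
print(out.shape)  # torch.Size([4, 196, 768])
```

In a transformer block this wraps each residual branch, e.g. `x = x + drop_path(ls(attn(norm(x))))`.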

Parameters:
  • dim – Hidden dimension

  • num_heads – Number of attention heads

  • mlp_ratio – MLP hidden dim = dim * mlp_ratio

  • self_attn – Enable self-attention

  • cross_attn – Enable cross-attention

  • use_adaln – Enable AdaLN-Zero conditioning (for diffusion/flow)

  • use_rope – Enable 2D Rotary Position Embedding in attention

  • use_qk_norm – Enable Query-Key normalization in attention

  • mlp_type – MLP activation type: ‘gelu’ or ‘swiglu’

  • use_layer_scale – Enable LayerScale on residual connections

  • layer_scale_init – Initial value for LayerScale (default: 1e-5)

  • drop_path – Stochastic depth rate

  • attn_drop – Attention dropout rate

  • proj_drop – Projection dropout rate

  • max_grid_size – Maximum grid size for RoPE cache

Example:

# Standard ViT block
block = TransformerBlock(dim=768, num_heads=12)

# Modern ViT (DINOv2-style)
block = TransformerBlock(
    dim=768,
    num_heads=12,
    use_rope=True,
    use_qk_norm=True,
    mlp_type="swiglu",
    use_layer_scale=True,
)
out = block(x, grid_size=(14, 14))

# NEPA block (modern + causal)
causal_mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
out = block(x, grid_size=(14, 14), attn_mask=causal_mask)

# DiT block (with conditioning)
block = TransformerBlock(dim=768, num_heads=12, use_adaln=True)
out = block(x, cond=time_emb)
forward(x: Tensor, context: Tensor | None = None, cond: Tensor | None = None, attn_mask: Tensor | None = None, cross_attn_mask: Tensor | None = None, grid_size: Tuple[int, int] | None = None) Tensor[source]#

Forward pass.

Parameters:
  • x – Input tensor [B, N, D]

  • context – Context for cross-attention [B, M, D]

  • cond – Conditioning tensor [B, D] (required if use_adaln=True)

  • attn_mask – Self-attention mask. True = blocked.

  • cross_attn_mask – Cross-attention mask. True = blocked.

  • grid_size – (grid_h, grid_w) for RoPE. Required if use_rope=True.

Returns:

Output tensor [B, N, D]

stable_pretraining.backbone.vit.modulate(x: Tensor, shift: Tensor, scale: Tensor) Tensor[source]#

Apply AdaLN modulation: x * (1 + scale) + shift.
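The formula is simple broadcasting; a hedged sketch, assuming shift and scale are per-sample [B, D] vectors broadcast over the token dimension:

```python
import torch


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: [B, N, D]; shift, scale: [B, D] broadcast over the N (token) axis
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


x = torch.randn(4, 196, 768)
zero = torch.zeros(4, 768)
out = modulate(x, zero, zero)
print(torch.equal(out, x))  # True -- zero shift/scale is the identity
```

The `1 +` term means a zero-initialized conditioning MLP (as in AdaLN-Zero) leaves activations untouched at the start of training.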

Module contents#