stable_pretraining.backbone package#
Submodules#
stable_pretraining.backbone.aggregator module#
Modular tensor aggregation module for feeding multi-scale/multi-layer features to MLPs.
Commonly used for:
- SSL linear probes using multiple transformer layers
- Multi-scale feature fusion
- Combining features from different network stages
- class stable_pretraining.backbone.aggregator.TensorAggregator(input_spec: str | List[str] | Dict[str, str], adaptive_pool_size: int = 1)[source]#
Bases: Module
Aggregates multi-dimensional tensors into 2D format for MLP input.
Pure aggregation module with NO trainable parameters. Handles various input formats and aggregation strategies.
- Parameters:
input_spec – Specification of input format and aggregation modes:
- str: Single aggregation mode for all tensors (e.g., "mean")
- List[str]: Per-tensor aggregation modes for list inputs
- Dict[str, str]: Per-key aggregation modes for dict inputs
adaptive_pool_size – Output size for adaptive pooling (default: 1)
- Aggregation Modes:
“mean”: Spatial/temporal mean pooling
“max”: Spatial/temporal max pooling
“cls”: Take first token (for transformers with [CLS] token)
“flatten”: Flatten all dimensions after batch
“adaptive”: Adaptive average pooling to fixed size
Examples
>>> # Single tensor with mean pooling
>>> agg = TensorAggregator("mean")
>>> x = torch.randn(4, 768, 14, 14)
>>> out = agg(x)  # Shape: (4, 768)
>>> # SSL: Last 4 transformer layers with CLS token
>>> agg = TensorAggregator(["cls", "cls", "cls", "cls"])
>>> layers = [torch.randn(4, 197, 768) for _ in range(4)]
>>> out = agg(layers)  # Shape: (4, 3072) = 768 * 4
>>> # Multi-scale features
>>> agg = TensorAggregator({"layer1": "cls", "layer2": "mean", "conv": "mean"})
>>> out = agg(
...     {
...         "layer1": torch.randn(4, 197, 768),
...         "layer2": torch.randn(4, 197, 768),
...         "conv": torch.randn(4, 512, 14, 14),
...     }
... )  # Shape: (4, 2048)
- compute_output_dim(input_shapes: tuple | List[tuple] | Dict[str, tuple]) int[source]#
Compute the output dimension given input shapes.
- Parameters:
input_shapes – Shape(s) of input tensor(s) (excluding batch dim)
- Returns:
Total output features
Examples
>>> agg = TensorAggregator(["cls", "mean"])
>>> agg.compute_output_dim([(197, 768), (197, 768)])
1536
>>> agg = TensorAggregator({"l1": "cls", "conv": "mean"})
>>> agg.compute_output_dim({"l1": (197, 768), "conv": (512, 14, 14)})
1280
stable_pretraining.backbone.convmixer module#
- class stable_pretraining.backbone.convmixer.ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6, kernel_size=9, patch_size=7)[source]#
Bases: Module
ConvMixer model.
A simple and efficient convolutional architecture that operates directly on patches.
- Parameters:
in_channels (int, optional) – Number of input channels. Defaults to 3.
num_classes (int, optional) – Number of output classes. Defaults to 10.
dim (int, optional) – Hidden dimension size. Defaults to 64.
depth (int, optional) – Number of ConvMixer blocks. Defaults to 6.
kernel_size (int, optional) – Kernel size for depthwise convolution. Defaults to 9.
patch_size (int, optional) – Patch embedding size. Defaults to 7.
Note
Introduced in [Trockman and Kolter, 2022].
- forward(xb)[source]#
Forward pass through the ConvMixer model.
- Parameters:
xb (torch.Tensor) – Input tensor of shape (batch_size, in_channels, height, width).
- Returns:
Output logits of shape (batch_size, num_classes).
- Return type:
torch.Tensor
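A minimal usage sketch (not from the original docstring), relying only on the constructor defaults and the forward() contract documented above; the CIFAR-sized input is illustrative.
>>> import torch
>>> from stable_pretraining.backbone.convmixer import ConvMixer
>>> model = ConvMixer(in_channels=3, num_classes=10, dim=64, depth=6)
>>> x = torch.randn(8, 3, 32, 32)  # CIFAR-sized batch
>>> logits = model(x)  # (8, 10) logits, per forward() above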
stable_pretraining.backbone.mae module#
- class stable_pretraining.backbone.mae.MaskedAutoencoderViT(img_size=224, patch_size=16, in_chans=3, embed_dim=1024, depth=24, num_heads=16, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4.0, norm_layer=<class 'torch.nn.modules.normalization.LayerNorm'>, norm_pix_loss=False)[source]#
Bases: Module
Masked Autoencoder with Vision Transformer backbone.
- forward(imgs, mask_ratio=0.75)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- patchify(imgs)[source]#
Convert images to patches.
- Parameters:
imgs – (N, 3, H, W)
- Returns:
(N, L, patch_size**2 * 3)
- Return type:
x
- random_masking(x, mask_ratio)[source]#
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
- Parameters:
x – [N, L, D], sequence
mask_ratio – ratio of patches to mask
- Returns:
x_masked – masked sequence
mask – binary mask
ids_restore – indices to restore the original order
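A short sketch (not part of the original docstring) exercising the helpers documented above; the reduced constructor values only keep the example light, and the return structure of forward() itself is not asserted here.
>>> import torch
>>> from stable_pretraining.backbone.mae import MaskedAutoencoderViT
>>> mae = MaskedAutoencoderViT(
...     embed_dim=256, depth=2, num_heads=8,
...     decoder_embed_dim=128, decoder_depth=1, decoder_num_heads=8,
... )
>>> imgs = torch.randn(2, 3, 224, 224)
>>> patches = mae.patchify(imgs)  # (2, 196, 16*16*3)
>>> x_masked, mask, ids_restore = mae.random_masking(patches, mask_ratio=0.75)
>>> # x_masked: masked sequence, mask: (2, 196) binary mask, ids_restore: (2, 196)
>>> out = mae(imgs, mask_ratio=0.75)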
- stable_pretraining.backbone.mae.get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#
Get 1D sinusoidal positional embedding from grid.
- Parameters:
embed_dim – output dimension for each position
pos – a list of positions to be encoded: size (M,)
- Returns:
(M, D)
- Return type:
out
- stable_pretraining.backbone.mae.get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False)[source]#
Get 2D sinusoidal positional embedding.
- Parameters:
embed_dim – embedding dimension
grid_size – int of the grid height and width
cls_token – whether to include class token
- Returns:
[grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
- Return type:
pos_embed
- stable_pretraining.backbone.mae.vit_base_patch16(**kwargs)#
- stable_pretraining.backbone.mae.vit_huge_patch14(**kwargs)#
- stable_pretraining.backbone.mae.vit_large_patch16(**kwargs)#
stable_pretraining.backbone.mlp module#
- class stable_pretraining.backbone.mlp.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#
Bases: Sequential
This block implements the multi-layer perceptron (MLP) module.
- Parameters:
in_channels (int) – Number of channels of the input
hidden_channels (List[int]) – List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None this layer won't be used. Default: None
activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None this layer won't be used. Default: torch.nn.ReLU
inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.
bias (bool) – Whether to use bias in the linear layer. Default: True
dropout (float) – The probability for the dropout layer. Default: 0.0
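A short sketch (not in the original docstring), assuming, as in the torchvision MLP this mirrors, that the output width is the last entry of hidden_channels.
>>> import torch
>>> from stable_pretraining.backbone.mlp import MLP
>>> mlp = MLP(in_channels=768, hidden_channels=[512, 256, 10], dropout=0.1)
>>> x = torch.randn(32, 768)
>>> mlp(x).shape
torch.Size([32, 10])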
stable_pretraining.backbone.patch_masking module#
Patch masking strategies for masked image modeling.
- class stable_pretraining.backbone.patch_masking.MaskingOutput(visible: Tensor, mask: Tensor, ids_restore: Tensor, ids_keep: Tensor)[source]#
Bases: object
Output from patch masking operation.
- Variables:
visible – Visible patch embeddings (B, N_keep, D)
mask – Binary mask where 1 = masked, 0 = visible (B, N)
ids_restore – Indices to restore original order (B, N)
ids_keep – Indices of kept (visible) patches (B, N_keep)
- class stable_pretraining.backbone.patch_masking.PatchMasking(mask_ratio: float = 0.75, block_size: int = 1, crop_ratio: float = 0.0, crop_aspect_ratio: tuple[float, float] = (0.75, 1.33))[source]#
Bases: Module
Flexible patch masking module for masked image modeling.
Supports three masking strategies that are selected stochastically:
Random: Uniformly random patch selection (when block_size=1)
Block: Square blocks of adjacent patches (when block_size > 1)
Crop: Rectangular crop region, remaining patches masked (when crop_ratio > 0)
Strategy selection per sample:
With probability crop_ratio, use crop masking
Otherwise, if block_size > 1, use block masking
Otherwise, use random masking
- Parameters:
mask_ratio – Fraction of patches to mask, in [0, 1)
block_size – Size of square blocks for block masking (1 = random masking)
crop_ratio – Probability of using crop masking vs block/random
crop_aspect_ratio – (min, max) aspect ratio range for crop regions
Example:
masking = PatchMasking(mask_ratio=0.75, block_size=4)
output = masking(patch_embeddings, grid_h=14, grid_w=14)
visible_patches = output.visible  # (B, N_keep, D)
mask = output.mask  # (B, N), 1=masked, 0=visible
ids_keep = output.ids_keep  # (B, N_keep)
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(x: Tensor, grid_h: int, grid_w: int) MaskingOutput[source]#
Apply masking to patch embeddings.
- Parameters:
x – Patch embeddings of shape (B, N, D) where N = grid_h * grid_w
grid_h – Height of the patch grid
grid_w – Width of the patch grid
- Returns:
MaskingOutput containing visible patches and mask information
- Raises:
ValueError – If x.shape[1] != grid_h * grid_w
ValueError – If input tensor has wrong number of dimensions
stable_pretraining.backbone.pos_embed module#
Positional embedding utilities for vision transformers.
- stable_pretraining.backbone.pos_embed.get_1d_sincos_pos_embed(embed_dim: int, length: int, cls_token: bool = False) Tensor[source]#
Generate 1D sinusoidal positional embeddings.
- Parameters:
embed_dim – Embedding dimension
length – Sequence length (number of positions)
cls_token – If True, prepend a zero embedding for CLS token
- Returns:
Positional embeddings of shape (length, embed_dim) or (length + 1, embed_dim) if cls_token=True
- stable_pretraining.backbone.pos_embed.get_2d_sincos_pos_embed(embed_dim: int, grid_size: int | tuple[int, int], cls_token: bool = False) Tensor[source]#
Generate 2D sinusoidal positional embeddings for image patches.
- Parameters:
embed_dim – Embedding dimension (must be divisible by 4)
grid_size – Grid height/width as int (square) or (height, width) tuple
cls_token – If True, prepend a zero embedding for CLS token
- Returns:
Positional embeddings of shape (H*W, embed_dim) or (H*W + 1, embed_dim) if cls_token=True
- stable_pretraining.backbone.pos_embed.get_sincos_pos_embed(embed_dim: int, num_patches: int, mode: Literal['1d', '2d'] = '1d', grid_size: int | tuple[int, int] | None = None, cls_token: bool = False) Tensor[source]#
Unified interface for generating sinusoidal positional embeddings.
- Parameters:
embed_dim – Embedding dimension
num_patches – Total number of patches (used for 1d mode)
mode – Embedding type - ‘1d’ for sequence, ‘2d’ for image grid
grid_size – Required for ‘2d’ mode
cls_token – If True, prepend a zero embedding for CLS token
- Returns:
Positional embeddings tensor
- stable_pretraining.backbone.pos_embed.get_timestep_embed(t: Tensor, dim: int, max_period: int = 10000) Tensor[source]#
Generate sinusoidal embeddings for continuous timesteps.
Unlike positional embeddings for sequences, this embeds scalar timestep values. Used for diffusion/flow matching time conditioning.
- Parameters:
t – Timestep values (B,) or (B, 1), typically in [0, 1]
dim – Embedding dimension
max_period – Maximum period for frequency scaling
- Returns:
Timestep embeddings of shape (B, dim)
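A minimal sketch (not from the docstring), using only the signature above.
>>> import torch
>>> from stable_pretraining.backbone.pos_embed import get_timestep_embed
>>> t = torch.rand(16)  # 16 diffusion/flow timesteps in [0, 1]
>>> emb = get_timestep_embed(t, dim=256)
>>> emb.shape  # (B, dim)
torch.Size([16, 256])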
- stable_pretraining.backbone.pos_embed.interpolate_pos_embed(pos_embed: Tensor, src_size: tuple[int, int], tgt_size: tuple[int, int], num_prefix_tokens: int = 0, mode: str = 'bicubic') Tensor[source]#
Interpolate positional embeddings to a new grid size.
- Parameters:
pos_embed – Original positional embeddings of shape (1, num_prefix + src_h*src_w, embed_dim) or (num_prefix + src_h*src_w, embed_dim)
src_size – Source grid size as (height, width)
tgt_size – Target grid size as (height, width)
num_prefix_tokens – Number of prefix tokens (CLS, registers) to preserve
mode – Interpolation mode (‘nearest’, ‘bilinear’, ‘bicubic’, ‘area’)
- Returns:
Interpolated positional embeddings
Example:
old_pos = model.pos_embed  # (1, 197, 768) = 1 + 14*14
new_pos = interpolate_pos_embed(
    old_pos, src_size=(14, 14), tgt_size=(16, 16), num_prefix_tokens=1
)  # (1, 257, 768) = 1 + 16*16
stable_pretraining.backbone.probe module#
- class stable_pretraining.backbone.probe.AutoLinearClassifier(name, embedding_dim, num_classes, pooling=None, weight_decay=[0], lr_scaling=[1], normalization=['none', 'norm', 'bn'], dropout=[0, 0.5], label_smoothing=[0, 1])[source]#
Bases: Module
Linear classifier using either the CLS token or mean pooling, with a configurable normalization layer.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.
norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.
- norm#
Instantiated normalization layer, or None.
- Type:
nn.Module or None
- fc#
Linear layer mapping pooled representation to class logits.
- Type:
nn.Linear
- Forward Args:
x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D). If 3D, pooling and normalization are applied. If 2D, the input is used directly (no pooling or normalization).
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = LinearProbe(
...     embedding_dim=128,
...     num_classes=10,
...     pooling="mean",
...     norm_layer=torch.nn.LayerNorm,
... )
>>> x = torch.randn(32, 20, 128)
>>> logits = probe(x)  # shape: (32, 10)
>>> x2 = torch.randn(32, 128)
>>> logits2 = probe(x2)  # shape: (32, 10)
- forward(x, y=None, pl_module=None)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.probe.AutoTuneMLP(in_features: int, out_features: int, hidden_features: List[int] | List[List[int]], name: str, loss_fn: Callable, additional_weight_decay: float | List[float] = [0], lr_scaling: float | List[float] = [1], normalization: str | List[str] = ['none'], dropout: float | List[float] = [0], activation: str | List[str] = ['relu'])[source]#
Bases: Module
Automatically creates multiple MLP variants with different hyperparameter combinations.
This module creates a grid of MLPs with different configurations (dropout, normalization, learning rates, architectures, etc.) to enable parallel hyperparameter tuning.
- Parameters:
in_features – Number of input features
out_features – Number of output features
hidden_features – Architecture specification. Can be:
- List[int]: Single architecture, e.g., [256, 128]
- List[List[int]]: Multiple architectures, e.g., [[256, 128], [512, 256, 128]]
- []: Empty list for a linear model (no hidden layers)
name – Base name for this AutoTuneMLP instance
loss_fn – Loss function to compute loss
additional_weight_decay – List of weight decay values to try
lr_scaling – List of learning rate scaling factors to try
normalization – List of normalization types [‘none’, ‘norm’, ‘bn’]
dropout – List of dropout rates to try
activation – List of activation functions [‘relu’, ‘leaky_relu’, ‘tanh’]
Examples
>>> # Single architecture
>>> model = AutoTuneMLP(128, 10, [256, 128], "clf", nn.CrossEntropyLoss())
>>> # Multiple architectures
>>> model = AutoTuneMLP(
...     128, 10, [[256], [256, 128], [512, 256]], "clf", nn.CrossEntropyLoss()
... )
>>> # Linear model (no hidden layers)
>>> model = AutoTuneMLP(128, 10, [], "linear_clf", nn.CrossEntropyLoss())
- forward(x: Tensor, y: Tensor | None = None) Dict[str, Tensor][source]#
Forward pass through all MLP variants.
- Parameters:
x – Input tensor of shape (batch_size, in_features)
y – Optional target tensor for loss computation
- Returns:
Dictionary with predictions and losses for each variant. Format: {'pred/{variant_id}': tensor, 'loss/{variant_id}': tensor}
- get_best_variant(metric_dict: Dict[str, float], lower_is_better: bool = True) str[source]#
Get the best performing variant based on metrics.
- Parameters:
metric_dict – Dictionary mapping variant_id to metric values
lower_is_better – If True, lower metric is better (e.g., loss). If False, higher is better (e.g., accuracy)
- Returns:
ID of the best performing variant
- get_variant(key: str) Module[source]#
Get a specific MLP variant by key.
- Parameters:
key – Variant ID
- Returns:
The MLP module
- Raises:
KeyError – If key doesn’t exist
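A sketch of the intended select-best workflow (not from the original docstring), combining forward(), get_best_variant(), and get_variant() as documented above; the key parsing assumes the 'loss/{variant_id}' format stated in forward().
>>> import torch
>>> from torch import nn
>>> from stable_pretraining.backbone.probe import AutoTuneMLP
>>> model = AutoTuneMLP(128, 10, [[256], [256, 128]], "clf", nn.CrossEntropyLoss())
>>> x = torch.randn(32, 128)
>>> y = torch.randint(0, 10, (32,))
>>> outputs = model(x, y)  # {'pred/<id>': ..., 'loss/<id>': ...}
>>> losses = {k.split("/", 1)[1]: v.item() for k, v in outputs.items() if k.startswith("loss/")}
>>> best_id = model.get_best_variant(losses, lower_is_better=True)
>>> best_mlp = model.get_variant(best_id)  # a standalone nn.Module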
- class stable_pretraining.backbone.probe.LinearProbe(embedding_dim, num_classes, pooling='cls', norm_layer=None)[source]#
Bases: Module
Linear probe using either the CLS token or mean pooling, with a configurable normalization layer.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
pooling (str) – Pooling strategy, either ‘cls’ or ‘mean’.
norm_layer (callable or None) – Normalization layer class (e.g., torch.nn.LayerNorm, torch.nn.BatchNorm1d), or None for no normalization. Should accept a single argument: normalized_shape or num_features.
- norm#
Instantiated normalization layer, or None.
- Type:
nn.Module or None
- fc#
Linear layer mapping pooled representation to class logits.
- Type:
nn.Linear
- Forward Args:
x (torch.Tensor): Input tensor of shape (N, T, D) or (N, D). If 3D, pooling and normalization are applied. If 2D, the input is used directly (no pooling or normalization).
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = LinearProbe(
...     embedding_dim=128,
...     num_classes=10,
...     pooling="mean",
...     norm_layer=torch.nn.LayerNorm,
... )
>>> x = torch.randn(32, 20, 128)
>>> logits = probe(x)  # shape: (32, 10)
>>> x2 = torch.randn(32, 128)
>>> logits2 = probe(x2)  # shape: (32, 10)
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.probe.MultiHeadAttentiveProbe(embedding_dim: int, num_classes: int, num_heads: int = 4)[source]#
Bases: Module
A multi-head attentive probe for sequence representations.
This module applies multiple attention heads to a sequence of embeddings, pools the sequence into a fixed-size representation per head, concatenates the results, and projects to a set of output classes.
- Parameters:
embedding_dim (int) – Dimensionality of the input embeddings.
num_classes (int) – Number of output classes.
num_heads (int) – Number of attention heads. Default: 4.
- ln#
Layer normalization applied to the input.
- Type:
torch.nn.LayerNorm
- attn_vectors#
Learnable attention vectors for each head, shape (num_heads, embedding_dim).
- Type:
torch.nn.Parameter
- fc#
Final linear layer mapping concatenated head outputs to class logits.
- Type:
torch.nn.Linear
- Forward Args:
x (torch.Tensor): Input tensor of shape (N, T, D), where N = batch size, T = sequence length, D = embedding_dim.
- Returns:
Output logits of shape (N, num_classes).
- Return type:
torch.Tensor
Example
>>> probe = MultiHeadAttentiveProbe(
...     embedding_dim=128, num_classes=10, num_heads=4
... )
>>> x = torch.randn(32, 20, 128)  # batch of 32, sequence length 20
>>> logits = probe(x)  # shape: (32, 10)
- forward(x: Tensor)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
stable_pretraining.backbone.resnet9 module#
- class stable_pretraining.backbone.resnet9.MLP(in_channels: int, hidden_channels: list[int], norm_layer: str = None, activation_layer=<class 'torch.nn.modules.activation.ReLU'>, inplace: bool = None, bias: bool = True, dropout: float = 0.0)[source]#
Bases: Sequential
This block implements the multi-layer perceptron (MLP) module.
- Parameters:
in_channels (int) – Number of channels of the input
hidden_channels (List[int]) – List of the hidden channel dimensions
norm_layer (Callable[..., torch.nn.Module], optional) – Norm layer that will be stacked on top of the linear layer. If None this layer won't be used. Default: None
activation_layer (Callable[..., torch.nn.Module], optional) – Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the linear layer. If None this layer won't be used. Default: torch.nn.ReLU
inplace (bool, optional) – Parameter for the activation layer, which can optionally do the operation in-place. Default is None, which uses the respective default values of the activation_layer and Dropout layer.
bias (bool) – Whether to use bias in the linear layer. Default: True
dropout (float) – The probability for the dropout layer. Default: 0.0
- class stable_pretraining.backbone.resnet9.ResidualBlock(in_channels, out_channels, kernel_size, padding, stride)[source]#
Bases: Module
A residual block as defined by He et al.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.resnet9.Resnet9(num_classes, num_channels, *args, **kwargs)[source]#
Bases: Module
A residual network.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
stable_pretraining.backbone.utils module#
- class stable_pretraining.backbone.utils.EfficientMaskedTimmViT(vit: Module)[source]#
Bases: Module
Optimized Vision Transformer wrapper that efficiently handles NaN patches.
This module is designed to work with timm ViT models and provides:
- Per-sample NaN masking (different NaN patterns per image in batch)
- Fast path for same masking pattern across batch
- Support for class tokens (cls_token), distillation tokens (dist_token), and register tokens
- Compatibility with various timm ViT architectures (vit_*, deit_*, beit_*, etc.)
- Minimal overhead when no masking is present
Key Optimizations:
- Early exit when no NaN patches detected
- Simpler indexing for same masking patterns
- Cached batch indices for repeated operations
- Zero-copy operations where possible
- Parameters:
vit – A timm Vision Transformer model instance
- Raises:
ValueError – If samples have different numbers of NaN patches
ValueError – If all patches are NaN
RuntimeError – If the model structure is incompatible
Example
>>> import timm
>>> vit = timm.create_model(
...     "vit_base_patch16_224", pretrained=False, reg_tokens=4
... )
>>> masked_vit = EfficientMaskedTimmViT(vit)
>>>
>>> # Create input with some NaN patches
>>> x = torch.randn(4, 3, 224, 224)
>>> output = masked_vit(x)
- Performance:
Same pattern masking: ~0-5% overhead vs different patterns
No masking: <2% overhead vs original model
50% masking: ~1.5x speedup
90% masking: ~2.5-3x speedup
Note
All samples in a batch must have the same NUMBER of NaN patches, but the LOCATION of NaN patches can differ per sample.
Register tokens (DINOv2 style) do NOT receive positional embeddings.
- clear_cache()[source]#
Clear the cached batch indices.
Useful if you want to free memory after processing different batch sizes. The cache will be rebuilt as needed during forward passes.
- forward(x: Tensor) Tensor[source]#
Forward pass through the masked ViT.
This method implements an optimized forward pass with the following features:
- Early exit for inputs without NaN patches (fast path)
- Optimized indexing for same masking patterns across batch
- Per-sample masking support with advanced indexing
- Automatic NaN replacement for partial NaN patches
- Support for register tokens (DINOv2 style)
- Parameters:
x – Input tensor, either:
- Raw images: shape (B, C, H, W)
- Pre-patchified: shape (B, N, D) where N is the number of patches
- Returns:
Model output (logits if head exists, features otherwise)
- Return type:
Tensor
- Raises:
ValueError – If samples have different numbers of NaN patches
ValueError – If all patches are NaN
- Performance Notes:
No NaN patches: Uses fast path with <2% overhead
Same pattern: Optimized indexing, ~0-5% overhead vs different patterns
Different patterns: Uses advanced indexing, ~10-35% slower at high masking
- class stable_pretraining.backbone.utils.EmbeddingOutput(last_hidden_state: Any, hidden_states: dict[str, Tensor])[source]#
Bases: object
HuggingFace-style output container for model embeddings.
- last_hidden_state#
The final output from the backbone model.
- Type:
Any
- hidden_states#
Dictionary mapping layer names to their intermediate outputs.
- Type:
dict[str, Tensor]
- class stable_pretraining.backbone.utils.EvalOnly(backbone: Module)[source]#
Bases: Module
Wrapper that forces a module to remain in evaluation mode.
- forward(*args, **kwargs)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.utils.FeaturesConcat(agg: callable, names: str | Iterable[str] = None)[source]#
Bases: Module
Aggregates and concatenates features from a dictionary input, then classifies.
- Parameters:
names (List[str]) – Keys to extract from the input dictionary. If not given, all entries of the dict/list are aggregated.
- forward(inputs: dict | Iterable)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class stable_pretraining.backbone.utils.HiddenStateExtractor(backbone: Module, module_names: list[str])[source]#
Bases: Module
Wrapper that captures intermediate embeddings from specified layers.
Returns outputs in HuggingFace Transformers style with last_hidden_state for the final backbone output and hidden_states for intermediate layers.
- Parameters:
backbone – The neural network module to wrap.
module_names – List of module names to capture (e.g., [‘layer1’, ‘encoder.block1’]). Supports nested modules using dot notation.
- Returns:
EmbeddingOutput with:
last_hidden_state – Final backbone output
hidden_states – Dict mapping module names to their outputs
- Return type:
EmbeddingOutput
Example
>>> model = HiddenStateExtractor(
...     torchvision.models.swin_v2_s(),
...     ["features.0", "features.2", "features.4"],
... )
>>> output = model(images)
>>> output.last_hidden_state  # final output
>>> output.hidden_states["features.2"]  # intermediate layer
- Raises:
ValueError – If any module name is not found in the backbone.
- forward(*args, **kwargs) EmbeddingOutput[source]#
Run forward pass and return embeddings in HuggingFace style.
- class stable_pretraining.backbone.utils.TeacherStudentWrapper(student: Module, warm_init: bool = True, base_ema_coefficient: float = 0.996, final_ema_coefficient: float = 1)[source]#
Bases: Module
Backbone wrapper that implements teacher-student distillation via EMA.
This is a wrapper for backbones that creates a teacher model as an exponential moving average (EMA) of the student model. It should be passed as the backbone to stable_pretraining.Module and accessed via forward_student() and forward_teacher() methods in your custom forward function.
The teacher model is updated by taking a running average of the student’s parameters and buffers. When ema_coefficient == 0.0, the teacher and student are literally the same object, saving memory but forward passes through the teacher will not produce any gradients.
- Usage example:
backbone = ResNet18()
wrapped_backbone = TeacherStudentWrapper(backbone)
module = ssl.Module(
    backbone=wrapped_backbone,
    projector=projector,
    forward=forward_with_teacher_student,
    ...
)
- Parameters:
student (torch.nn.Module) – The student model whose parameters will be tracked.
warm_init (bool, optional) – If True, performs an initialization step to match the student’s parameters immediately. Default is True.
base_ema_coefficient (float, optional) – EMA decay factor at the start of training. This value will be updated following a cosine schedule. Should be in [0, 1]. A value of 0.0 means the teacher is fully updated to the student’s parameters on every step, while a value of 1.0 means the teacher remains unchanged. Default is 0.996.
final_ema_coefficient (float, optional) – EMA decay factor at the end of training. Default is 1.
- forward(*args, **kwargs)[source]#
Forward pass through either the student or teacher network.
You can choose which model to run in the default forward. Commonly the teacher is evaluated, so we default to that.
- forward_student(*args, **kwargs)[source]#
Forward pass through the student network. Gradients will flow normally.
- forward_teacher(*args, **kwargs)[source]#
Forward pass through the teacher network.
By default, the teacher network does not require grad. If ema_coefficient == 0, then teacher==student, so we wrap in torch.no_grad() to ensure no gradients flow.
- update_ema_coefficient(epoch: int, total_epochs: int)[source]#
Update the EMA coefficient following a cosine schedule.
- The EMA coefficient is updated following a cosine schedule:
ema_coefficient = final_ema_coefficient - 0.5 * (final_ema_coefficient - base_ema_coefficient) * (1 + cos(epoch / total_epochs * pi))
- update_teacher()[source]#
Perform one EMA update step on the teacher’s parameters.
- The update rule is:
teacher_param = ema_coefficient * teacher_param + (1 - ema_coefficient) * student_param
This is done in a no_grad context to ensure the teacher’s parameters do not accumulate gradients, but the student remains fully trainable.
Everything is updated, including buffers (e.g. batch norm running averages).
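A minimal training-loop sketch (not from the original docstring), using only the methods documented above; the Linear backbone merely stands in for a real encoder.
>>> import torch
>>> from stable_pretraining.backbone.utils import TeacherStudentWrapper
>>> wrapped = TeacherStudentWrapper(torch.nn.Linear(32, 16), base_ema_coefficient=0.996)
>>> x = torch.randn(8, 32)
>>> student_out = wrapped.forward_student(x)  # gradients flow
>>> teacher_out = wrapped.forward_teacher(x)  # EMA target, no gradients
>>> # after each optimizer step / at epoch boundaries:
>>> wrapped.update_teacher()
>>> wrapped.update_ema_coefficient(epoch=0, total_epochs=100)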
- stable_pretraining.backbone.utils.from_huggingface(model_name, pretrained, attn_implementation='sdpa', **kwargs)[source]#
Loads a Hugging Face Transformers base model, optionally with pretrained weights, and returns the backbone model.
This function wraps the Hugging Face transformers library to load a model specified by model_name. It supports loading either pretrained weights or initializing from configuration only. The returned object is the model’s backbone (model.base_model), which is useful for extracting the core architecture without task-specific heads.
- Parameters:
model_name (str) – The Hugging Face model repository identifier or local path. Examples include “bert-base-uncased”, “facebook/opt-1.3b”, or a local directory containing model files.
pretrained (bool) – If True, loads pretrained weights via AutoModel.from_pretrained. If False, initializes the model from configuration only via AutoConfig.from_pretrained and AutoModel.from_config.
attn_implementation (str, optional) – The attention backend to use. Supported values include “sdpa” (default), “eager”, “flash_attention_2”, etc., as supported by the installed version of transformers and your hardware. This is forwarded to the underlying model constructor.
**kwargs – Additional keyword arguments forwarded to AutoModel.from_pretrained or AutoConfig.from_pretrained. Common options include:
- revision (str): Model version or branch to use.
- cache_dir (str): Directory to cache downloaded models.
- trust_remote_code (bool): Allow loading custom code from model repo.
- torch_dtype (str or torch.dtype): Data type for model weights.
- device_map (str or dict): Device placement for model parameters.
- And others supported by Hugging Face Transformers.
- Returns:
The base (backbone) model instance, typically accessible via model.base_model. For some architectures, this may be the model itself.
- Return type:
transformers.PreTrainedModel
- Raises:
ImportError – If the transformers library is not installed.
OSError – If the model or configuration cannot be found or downloaded.
ValueError – If invalid arguments are provided.
Exception – Propagates any other exceptions raised by Hugging Face Transformers.
Notes
The returned base_model may differ depending on the architecture. For some models, base_model is the same as the full model.
The availability of certain attention implementations (e.g., “flash_attention_2”) depends on your hardware, installed libraries, and the version of transformers.
Ensure that your environment meets the requirements for the selected attention backend.
Examples
>>> # Load a pretrained BERT model with default attention
>>> model = from_huggingface("bert-base-uncased", pretrained=True)
>>> # Initialize a model from config only, specifying a revision and device
>>> model = from_huggingface(
...     "facebook/opt-1.3b",
...     pretrained=False,
...     revision="main",
...     device_map="auto",
... )
>>> # Load a pretrained model using flash attention (if supported)
>>> model = from_huggingface(
...     "meta-llama/Llama-2-7b-hf",
...     pretrained=True,
...     attn_implementation="flash_attention_2",
... )
- stable_pretraining.backbone.utils.from_torchvision(model_name, low_resolution=False, **kwargs)[source]#
Load a backbone model.
If num_classes is provided, the last layer is replaced by a linear layer of output size num_classes. Otherwise, the last layer is replaced by an identity layer.
- Parameters:
model_name (str) – Name of the backbone model. Supported models are:
- Any model from torchvision.models
- "Resnet9"
- "ConvMixer"
low_resolution (bool, optional) – Whether to adapt the resolution of the model (for CIFAR typically). By default False.
**kwargs – Additional keyword arguments for the model. Special handling:
- in_channels (int): Number of input channels. If provided for ResNet models, the first conv layer will be modified to accept this many channels. Default is 3.
- Returns:
The neural network model.
- Return type:
torch.nn.Module
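A short sketch (not from the original docstring); the model name is resolved through torchvision.models, and since num_classes is omitted the final layer becomes an identity, so the output should be the backbone feature vector (512-dim for ResNet-18).
>>> import torch
>>> from stable_pretraining.backbone.utils import from_torchvision
>>> backbone = from_torchvision("resnet18", low_resolution=True)  # CIFAR-style stem
>>> feats = backbone(torch.randn(2, 3, 32, 32))  # last layer replaced by identity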
- stable_pretraining.backbone.utils.get_children_modules(model: Module, parent_name: str, L: int = 1, partial_match: bool = False) List[str][source]#
Extracts unique qualified module names under a given parent_name, descending L levels of submodules below it.
- Parameters:
model – The root nn.Module.
parent_name – The string or path component to match (e.g., ‘blocks’).
L – Number of levels after the parent_name to include in the result.
partial_match – Whether to match names by equality (==) or containment (in)
- Returns:
Sorted list of unique qualified module names at depth L after the parent_name.
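An illustrative sketch (not from the original docstring); the expected names in the comment assume a timm ViT whose transformer blocks live under "blocks".
>>> import timm
>>> from stable_pretraining.backbone.utils import get_children_modules
>>> vit = timm.create_model("vit_tiny_patch16_224", pretrained=False)
>>> names = get_children_modules(vit, parent_name="blocks", L=1)
>>> # e.g. ['blocks.0', 'blocks.1', ...] - one qualified name per transformer block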
- stable_pretraining.backbone.utils.get_output_shape(model: Module, *inputs, **kwargs) Any[source]#
Infers the output shapes of a PyTorch nn.Module by forwarding fake inputs on the ‘meta’ device using FakeTensorMode.
Handles arbitrary nested output structures (lists, dicts, tuples, sets, namedtuples, dataclasses), preserving their structure but replacing torch.Tensor objects with their .shape. This function temporarily replaces the model’s parameters and buffers with fake tensors on the ‘meta’ device, converts all tensor inputs and keyword arguments to ‘meta’, and runs the forward pass under FakeTensorMode. After execution, the original parameters and buffers are restored. No real computation or memory allocation occurs.
- Parameters:
model (torch.nn.Module) – The PyTorch module to evaluate. Must be on a real device (e.g., CPU).
*inputs – Positional arguments to pass to the model’s forward method. All torch.Tensor inputs are converted to ‘meta’.
**kwargs – Keyword arguments to pass to the model’s forward method. All torch.Tensor values are converted to ‘meta’.
- Returns:
The output structure from the model's forward pass, with all torch.Tensor objects replaced by their .shape. Non-tensor objects are left unchanged.
- Return type:
Any
Notes
Supports nested output structures: dict, list, tuple, set, namedtuple, and dataclasses.
No real memory is allocated; all tensors are on the ‘meta’ device.
Not thread-safe: concurrent calls may interfere with parameter/buffer swapping.
Requires PyTorch 1.11+ for FakeTensorMode.
If the model contains custom buffers or state, ensure they are handled appropriately.
Raises exceptions if model forward fails or if parameters/buffers cannot be swapped.
Non-tensor outputs are returned unchanged.
Example
shapes = get_output_shape(model, input1, input2, key1=kwarg1)
# shapes has the same structure as the model's output, but with torch.Size in place of tensors.
- stable_pretraining.backbone.utils.register_lr_scale_hook(module, lr_scale, weight_decay=0.0)[source]#
Registers a hook that scales gradients and applies weight decay during backward pass.
- Parameters:
module – PyTorch module/layer
lr_scale – Scaling factor for the learning rate (scales gradients)
weight_decay – L2 penalty coefficient (default: 0.0)
- Returns:
The same module (for chaining)
- Return type:
torch.nn.Module
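A minimal sketch (not from the original docstring), using the documented signature; a typical use is down-scaling the effective learning rate of a pretrained layer.
>>> import torch
>>> from stable_pretraining.backbone.utils import register_lr_scale_hook
>>> head = torch.nn.Linear(512, 10)
>>> head = register_lr_scale_hook(head, lr_scale=0.1, weight_decay=1e-4)
>>> # gradients of head's parameters are now scaled by 0.1 (plus the L2 penalty) during backward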
- stable_pretraining.backbone.utils.set_embedding_dim(module, dim, bias=True, expected_input_shape: tuple | list | None = None, expected_output_shape: tuple | list | None = None)[source]#
- stable_pretraining.backbone.utils.vit_hf(size: str = 'tiny', patch_size: int = 16, image_size: int = 224, pretrained: bool = False, use_mask_token: bool = True, **kwargs) Module[source]#
Create a Vision Transformer using HuggingFace transformers.
This provides a clean, well-maintained ViT implementation with native support for:
- Masking via the bool_masked_pos parameter
- Learnable mask token
- Easy access to CLS and patch tokens
- Parameters:
size – Model size - “tiny”, “small”, “base”, or “large”
patch_size – Patch size (default: 16)
image_size – Input image size (default: 224)
pretrained – Load pretrained weights from HuggingFace Hub
use_mask_token – Whether to include learnable mask token (needed for iBOT)
**kwargs – Additional ViTConfig parameters
- Returns:
HuggingFace ViTModel
Example
>>> backbone = vit_hf("tiny", use_mask_token=True)
>>> x = torch.randn(2, 3, 224, 224)
>>>
>>> # Without masking
>>> output = backbone(x)
>>> cls_token = output.last_hidden_state[:, 0, :]
>>> patch_tokens = output.last_hidden_state[:, 1:, :]
>>>
>>> # With masking (for iBOT student)
>>> masks = torch.zeros(2, 196, dtype=torch.bool)
>>> masks[:, :59] = True  # Mask 30%
>>> output = backbone(x, bool_masked_pos=masks)
stable_pretraining.backbone.vit module#
- class stable_pretraining.backbone.vit.Attention(dim: int, num_heads: int = 8, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0)[source]#
Bases: Module
Multi-head self-attention with efficient SDPA backend.
Uses F.scaled_dot_product_attention, which automatically selects:
- Flash Attention (when available, fastest)
- Memory-efficient attention (xformers-style)
- Math fallback
- Parameters:
dim – Input dimension
num_heads – Number of attention heads
qkv_bias – Add bias to QKV projection
attn_drop – Attention dropout rate
proj_drop – Output projection dropout rate
- class stable_pretraining.backbone.vit.CrossAttention(dim: int, context_dim: int | None = None, num_heads: int = 8, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0)[source]#
Bases: Module
Multi-head cross-attention with efficient SDPA backend.
Queries attend to key-value pairs from a separate context sequence.
- Parameters:
dim – Query dimension
context_dim – Context dimension (defaults to dim)
num_heads – Number of attention heads
qkv_bias – Add bias to projections
attn_drop – Attention dropout rate
proj_drop – Output projection dropout rate
- class stable_pretraining.backbone.vit.FlexibleTransformer(input_dim: int = 768, hidden_dim: int = 384, output_dim: int = 768, num_patches: int = 196, depth: int = 4, num_heads: int = 6, mlp_ratio: float = 4.0, self_attn: bool = True, cross_attn: bool = True, use_adaln: bool = True, pos_embed_type: Literal['sincos_1d', 'sincos_2d', 'learned'] = 'sincos_2d', grid_size: int | tuple[int, int] | None = None, drop_path_rate: float = 0.0, attn_drop: float = 0.0, proj_drop: float = 0.0, zero_init_output: bool = True, num_prefix_tokens: int = 1, add_mask_token: bool = False)[source]#
Bases: Module
Flexible transformer supporting multiple architectures.
Unified backbone for:
- MAE decoder: self_attn=True, cross_attn=False, use_adaln=False
- IJEPA predictor: self_attn=True, cross_attn=True, use_adaln=False
- DiT / Flow: self_attn=True, cross_attn=True/False, use_adaln=True
- MaskGIT: self_attn=True, cross_attn=False, use_adaln=True, add_mask_token=True
- Parameters:
input_dim – Input embedding dimension (from encoder)
hidden_dim – Internal transformer dimension
output_dim – Output dimension
num_patches – Total number of patches (for positional embeddings)
depth – Number of transformer blocks
num_heads – Number of attention heads
mlp_ratio – MLP hidden dim multiplier
self_attn – Enable self-attention in blocks
cross_attn – Enable cross-attention in blocks
use_adaln – Enable AdaLN-Zero conditioning
pos_embed_type – 'sincos_1d', 'sincos_2d', or 'learned'
grid_size – Grid size for 2D positional embeddings
drop_path_rate – Stochastic depth rate (linearly increases through layers)
attn_drop – Attention dropout rate
proj_drop – Projection dropout rate
zero_init_output – Zero-initialize output projection
num_prefix_tokens – Number of prefix tokens (e.g., CLS token)
add_mask_token – Enable learnable [MASK] token for masked prediction. When enabled, use context_mask and/or query_mask in forward() to replace tokens at specified positions with the [MASK] token.
- Example::
# MAE decoder
decoder = FlexibleTransformer(
    768, 512, 768, 196, depth=8,
    self_attn=True, cross_attn=False, use_adaln=False,
)
out = decoder(context, queries, context_idx, query_idx)

# IJEPA predictor
predictor = FlexibleTransformer(
    768, 384, 768, 196, depth=6,
    self_attn=True, cross_attn=True, use_adaln=False,
)
out = predictor(context, queries, context_idx, query_idx)

# DiT-style flow matching
flow = FlexibleTransformer(
    768, 384, 768, 196, depth=12,
    self_attn=True, cross_attn=False, use_adaln=True,
)
out = flow(context, queries, context_idx, query_idx, t=timesteps)

# MaskGIT-style: variable number of masks per sample
maskgit = FlexibleTransformer(
    768, 512, 768, 196, depth=8,
    self_attn=True, cross_attn=False, use_adaln=True, add_mask_token=True,
)
# Each sample can have a different number of masked positions
# context_mask[b, i] = True means replace context[b, i] with [MASK]
context_mask = torch.rand(B, num_patches) < mask_ratio  # Variable per sample!
out = maskgit(
    context=all_patches,                          # [B, 196, D]
    queries=all_patches[:, :0],                   # [B, 0, D] empty
    context_idx=torch.arange(196).expand(B, -1),  # [B, 196]
    query_idx=torch.empty(B, 0, dtype=torch.long),
    context_mask=context_mask,                    # [B, 196] bool, variable True count
    t=timesteps,
    return_all=True,
)  # Returns [B, 196, output_dim]

# BERT-style MLM: mask random tokens in sequence
bert = FlexibleTransformer(
    768, 768, 768, 512, depth=12,
    self_attn=True, cross_attn=False, use_adaln=False, add_mask_token=True,
)
# Random 15% masking, different positions per sample
context_mask = torch.rand(B, seq_len) < 0.15
out = bert(
    context=token_embeddings,
    queries=token_embeddings[:, :0],
    context_idx=position_ids,
    query_idx=torch.empty(B, 0, dtype=torch.long),
    context_mask=context_mask,
    return_all=True,
)
- forward(context: Tensor, queries: Tensor, context_idx: Tensor, query_idx: Tensor, t: Tensor | None = None, num_prefix: int | None = None, return_all: bool = False, context_mask: Tensor | None = None, query_mask: Tensor | None = None) Tensor[source]#
Forward pass.
- Parameters:
context – Context token embeddings [B, N_ctx, input_dim]
queries – Query token embeddings [B, N_qry, input_dim]
context_idx – Patch indices for context tokens [B, N_ctx]
query_idx – Patch indices for query tokens [B, N_qry]
t – Timesteps for conditioning [B] (required if use_adaln=True)
num_prefix – Override for number of prefix tokens in context
return_all – If True and using joint attention (cross_attn=False), return all tokens unshuffled to original position order. Output shape: [B, N_ctx + N_qry, output_dim]. Ignored for cross-attention modes.
context_mask – Boolean mask indicating which context tokens to replace with [MASK] token [B, N_ctx]. True = replace with mask. Each sample can have a different number of True values. Requires add_mask_token=True.
query_mask – Boolean mask indicating which query tokens to replace with [MASK] token [B, N_qry]. True = replace with mask. Each sample can have a different number of True values. Requires add_mask_token=True.
- Returns:
Output embeddings. Shape depends on mode:
- cross_attn=True: [B, N_qry, output_dim]
- cross_attn=False, return_all=False: [B, N_qry, output_dim]
- cross_attn=False, return_all=True: [B, N_ctx + N_qry, output_dim]
- class stable_pretraining.backbone.vit.MAEDecoder(embed_dim: int = 768, decoder_embed_dim: int = 512, output_dim: int = 768, num_patches: int = 196, depth: int = 4, num_heads: int = 16, mlp_ratio: float = 4.0, pos_embed_type: Literal['sincos_1d', 'sincos_2d', 'learned'] = 'sincos_2d', grid_size: int | None = None, drop_path_rate: float = 0.0)[source]#
Bases: Module
MAE-style Vision Transformer Decoder using FlexibleTransformer.
Implements the decoder component of Masked Autoencoders (MAE) [1] for self-supervised visual representation learning. The decoder reconstructs masked patches from visible patch embeddings using joint self-attention, where visible tokens and learnable mask tokens attend to each other. The decoder is intentionally lightweight compared to the encoder, as MAE demonstrates that a shallow decoder is sufficient for pixel reconstruction while keeping the encoder focused on learning semantic representations.
Architecture Overview:
1. Input projection: Maps encoder embeddings (embed_dim) to decoder dimension (decoder_embed_dim)
2. Mask token expansion: Learnable mask tokens are placed at masked positions
3. Positional encoding: Adds position information to all tokens
4. Transformer blocks: Joint self-attention over visible + mask tokens
5. Output projection: Maps to output_dim (typically patch_size² × channels)
- Parameters:
embed_dim (int, default=768) – Embedding dimension from the encoder. This is the input dimension of visible tokens passed to the decoder.
decoder_embed_dim (int, default=512) – Internal hidden dimension of the decoder transformer blocks. Typically smaller than embed_dim for efficiency.
output_dim (int, default=768) – Output dimension per token. For pixel reconstruction, this should be patch_size ** 2 * in_channels (e.g., 16×16×3 = 768 for RGB).
num_patches (int, default=196) – Total number of patches T in the image (e.g., 14×14 = 196 for 224×224 images with patch_size=16).
depth (int, default=4) – Number of transformer blocks in the decoder. MAE typically uses fewer blocks than the encoder (e.g., 4-8 vs 12-24).
num_heads (int, default=16) – Number of attention heads in multi-head self-attention.
mlp_ratio (float, default=4.0) – Expansion ratio for the MLP hidden dimension relative to decoder_embed_dim.
pos_embed_type ({'sincos_1d', 'sincos_2d', 'learned'}, default='sincos_2d') – Type of positional embedding: - ‘sincos_2d’: Fixed 2D sinusoidal (recommended for images) - ‘sincos_1d’: Fixed 1D sinusoidal - ‘learned’: Learnable positional embeddings
grid_size (int, optional) – Spatial grid size for 2D positional embeddings. If None, inferred as int(sqrt(num_patches)). Required for non-square grids.
drop_path_rate (float, default=0.0) – Stochastic depth rate for regularization during training.
- mask_token#
Learnable token of shape (1, 1, embed_dim) used to represent masked positions. Initialized with truncated normal (std=0.02).
- Type:
nn.Parameter
- transformer#
Core transformer module handling attention and projections.
- Type:
FlexibleTransformer
Notes
The mask convention follows MAE: 0 = visible (kept), 1 = masked.
The decoder receives visible tokens and reconstructs the masked positions.
For efficiency, predictions can be restricted to the masked positions via output_masked_only=True in forward().
References
[1] He, K., et al. "Masked Autoencoders Are Scalable Vision Learners." CVPR 2022. https://arxiv.org/abs/2111.06377
Examples
**Basic Usage with MAE Encoder**
>>> import torch
>>> import torch.nn as nn
>>>
>>> # Configuration matching ViT-Base
>>> B = 4  # batch size
>>> T = 196  # num_patches (14x14)
>>> embed_dim = 768  # encoder dimension
>>> mask_ratio = 0.75  # MAE default
>>>
>>> # Initialize decoder
>>> decoder = MAEDecoder(
...     embed_dim=embed_dim,
...     decoder_embed_dim=512,
...     output_dim=16 * 16 * 3,  # patch_size² × channels = 768
...     num_patches=T,
...     depth=4,
...     num_heads=16,
... )
>>>
>>> # Simulate encoder output (visible tokens only)
>>> N_vis = int(T * (1 - mask_ratio))  # 49 visible patches
>>> visible_tokens = torch.randn(B, N_vis, embed_dim)
>>>
>>> # Create random mask (0=visible, 1=masked)
>>> mask = torch.zeros(B, T)
>>> for i in range(B):
...     masked_indices = torch.randperm(T)[: T - N_vis]
...     mask[i, masked_indices] = 1
>>>
>>> # Decode - predict masked patches only
>>> pred_masked = decoder(visible_tokens, mask, output_masked_only=True)
>>> print(pred_masked.shape)  # [B, N_mask, output_dim]
torch.Size([4, 147, 768])
**Full Sequence Reconstruction**
>>> # Get predictions for ALL positions (for visualization)
>>> pred_full = decoder(visible_tokens, mask, output_masked_only=False)
>>> print(pred_full.shape)  # [B, T, output_dim]
torch.Size([4, 196, 768])
**Using Full Sequence Input**
If you have the full sequence with mask tokens already inserted:
>>> full_sequence = torch.randn(B, T, embed_dim)  # [B, 196, 768]
>>> pred = decoder(full_sequence, mask, output_masked_only=True)
>>> print(pred.shape)
torch.Size([4, 147, 768])
**Integration with MAE Training Loop**
>>> # Typical MAE training step (pseudocode)
>>> def mae_forward(encoder, decoder, images, mask_ratio=0.75):
...     # Patchify and mask
...     patches = patchify(images)  # [B, T, patch_dim]
...     mask = random_mask(B, T, mask_ratio)  # [B, T], 0=keep, 1=mask
...
...     # Encode visible patches only
...     visible_patches = patches[~mask.bool()].reshape(B, -1, patch_dim)
...     latent = encoder(visible_patches)  # [B, N_vis, embed_dim]
...
...     # Decode to predict masked patches
...     pred = decoder(
...         latent, mask, output_masked_only=True
...     )  # [B, N_mask, output_dim]
...
...     # Reconstruction loss on masked patches only
...     target = patches[mask.bool()].reshape(B, -1, patch_dim)
...     loss = F.mse_loss(pred, target)
...     return loss
**Custom Configuration for ViT-Large**
>>> decoder_large = MAEDecoder(
...     embed_dim=1024,  # ViT-L encoder dim
...     decoder_embed_dim=512,  # Keep decoder lightweight
...     output_dim=768,  # 16×16×3 pixels
...     num_patches=256,  # 16×16 patches for 256×256 images
...     depth=8,  # Slightly deeper
...     num_heads=16,
...     pos_embed_type="sincos_2d",
...     drop_path_rate=0.1,  # Regularization
... )
See Also
FlexibleTransformer: Core transformer implementation used internally.
- forward(x: Tensor, mask: Tensor, output_masked_only: bool = False) Tensor[source]#
Forward pass.
- Parameters:
x – Visible tokens [B, N_vis, D] or full sequence [B, T, D]
mask – Binary mask [B, T], 0=kept, 1=masked
output_masked_only – If True, return [B, N_mask, D]. If False, return [B, T, D].
- Returns:
Predictions
- class stable_pretraining.backbone.vit.MaskedEncoder(model_or_model_name: str | Module = 'vit_base_patch16_224', masking: PatchMasking | None = None, pretrained: bool = False, img_size: int | Tuple[int, int] | None = None, patch_size: int | Tuple[int, int] | None = None, dynamic_img_size: bool = False)[source]#
Bases: Module
Vision Transformer encoder with optional masking support.
Wraps a timm ViT model and adds flexible masking via PatchMasking. Handles all ViT internals: patch embedding, positional embeddings, prefix tokens (CLS, registers), and transformer blocks.
- Parameters:
model_or_model_name – timm model name string or pre-instantiated nn.Module
masking – PatchMasking instance. If None, no masking is applied.
pretrained – Load pretrained weights (only when model_or_model_name is str)
img_size – Override default image size
patch_size – Override default patch size (will reinitialize patch_embed)
dynamic_img_size – Enable dynamic image size support with pos_embed interpolation
Example:
from spt.backbone import PatchMasking, MaskedEncoder

masking = PatchMasking(mask_ratio=0.75, block_size=4)
encoder = MaskedEncoder(
    model_or_model_name="vit_base_patch16_224",
    masking=masking,
    pretrained=True,
)
images = torch.randn(4, 3, 224, 224)
output = encoder(images)
print(output.encoded.shape)   # (4, 1 + 49, 768) with 75% masking
print(output.mask.shape)      # (4, 196)
print(output.ids_keep.shape)  # (4, 49)
- extra_repr() str[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(images: Tensor) MaskedEncoderOutput[source]#
Encode images with optional masking.
- Parameters:
images – Input images (B, C, H, W)
- Returns:
MaskedEncoderOutput with encoded tokens and mask info
- class stable_pretraining.backbone.vit.MaskedEncoderOutput(encoded: Tensor, mask: Tensor, ids_keep: Tensor, grid_size: Tuple[int, int])[source]#
Bases: object
Output from MaskedEncoder forward pass.
- Variables:
encoded – Encoded token representations (B, num_prefix + N_visible, D)
mask – Binary mask where 1 = masked, 0 = visible (B, N_patches)
ids_keep – Indices of visible patches (B, N_visible)
grid_size – Patch grid dimensions (height, width)
- class stable_pretraining.backbone.vit.PositionalEncoding2D(embed_dim: int, grid_size: Tuple[int, int], pos_type: Literal['learnable', 'sinusoidal', 'rope', 'none'] = 'learnable', num_prefix_tokens: int = 1, learnable: bool | None = None)[source]#
Bases: Module
Flexible 2D positional encoding for vision transformers.
- class stable_pretraining.backbone.vit.TransformerBlock(dim: int, num_heads: int, mlp_ratio: float = 4.0, self_attn: bool = True, cross_attn: bool = True, use_adaln: bool = True, drop_path: float = 0.0, attn_drop: float = 0.0, proj_drop: float = 0.0, act_layer: type = <class 'torch.nn.modules.activation.GELU'>)[source]#
Bases: Module
Unified transformer block with optional AdaLN-Zero conditioning.
Supports three attention configurations:
- Mode 1: Pure Cross-Attention (self_attn=False, cross_attn=True)
Queries attend to context but not to each other
Use case: Lightweight decoder
- Mode 2: Decoder-Style (self_attn=True, cross_attn=True)
Self-attention on queries, then cross-attention to context
Use case: Standard decoder (IJEPA predictor, etc.)
- Mode 3: Joint Attention (self_attn=True, cross_attn=False)
All tokens attend to all tokens (caller concatenates context + queries)
Use case: Full bidirectional flow (DiT, high masking ratio)
- Conditioning:
use_adaln=True: AdaLN-Zero modulation (scale, shift, gate per operation)
use_adaln=False: Standard pre-norm transformer
- Parameters:
dim – Hidden dimension
num_heads – Number of attention heads
mlp_ratio – MLP hidden dim = dim * mlp_ratio
self_attn – Enable self-attention
cross_attn – Enable cross-attention
use_adaln – Enable AdaLN-Zero conditioning
drop_path – Stochastic depth rate
attn_drop – Attention dropout rate
proj_drop – Projection dropout rate
act_layer – Activation layer for MLP
- forward(x: Tensor, context: Tensor | None = None, cond: Tensor | None = None) Tensor[source]#
Forward pass.
- Parameters:
x – Input tensor [B, N, D]
context – Context for cross-attention [B, M, D] (required if cross_attn=True)
cond – Conditioning tensor [B, D] (required if use_adaln=True)
- Returns:
Output tensor [B, N, D]
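A shape-level sketch (not from the original docstring) of two of the configurations listed above, using only the documented constructor and forward arguments.
>>> import torch
>>> from stable_pretraining.backbone.vit import TransformerBlock
>>> x = torch.randn(2, 16, 256)    # queries [B, N, D]
>>> ctx = torch.randn(2, 49, 256)  # context [B, M, D]
>>> cond = torch.randn(2, 256)     # conditioning [B, D]
>>> # Mode 2: decoder-style with AdaLN-Zero conditioning
>>> block = TransformerBlock(dim=256, num_heads=8, self_attn=True, cross_attn=True, use_adaln=True)
>>> out = block(x, context=ctx, cond=cond)  # [B, N, D]
>>> # Mode 3: joint attention, plain pre-norm (caller concatenates context + queries)
>>> joint = TransformerBlock(dim=256, num_heads=8, self_attn=True, cross_attn=False, use_adaln=False)
>>> out = joint(torch.cat([ctx, x], dim=1))  # [B, M + N, D]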
- class stable_pretraining.backbone.vit.TransformerPredictor(input_dim: int, hidden_dim: int, output_dim: int, depth: int, num_heads: int = 6, num_registers: int = 0, mlp_ratio: float = 4.0, drop_path_rate: float = 0.0, pos_embed_type: Literal['sincos_1d', 'sincos_2d', 'learned'] | None = None, max_seq_len: int | None = None)[source]#
Bases: Module
Lightweight transformer predictor using TransformerBlock.
A flexible predictor module commonly used in masked image modeling (e.g., MAE, I-JEPA). Processes context tokens and optionally includes learnable register/query tokens for aggregation.
- Parameters:
input_dim – Dimension of input context tokens
hidden_dim – Internal dimension of transformer layers
output_dim – Dimension of output tokens
depth – Number of transformer layers
num_heads – Number of attention heads
num_registers – Number of learnable register/query tokens to prepend
mlp_ratio – MLP hidden dimension multiplier
drop_path_rate – Stochastic depth rate
pos_embed_type – Type of positional embedding (None, 'sincos_1d', 'sincos_2d', 'learned')
max_seq_len – Maximum sequence length (required if pos_embed_type='learned')
- forward(context: Tensor, pos_embed: Tensor | None = None, ids_keep: Tensor | None = None, grid_size: tuple[int, int] | None = None) Tensor[source]#
Forward pass.
- Parameters:
context – Context tokens [B, N, input_dim]
pos_embed – External positional embeddings [B, N, input_dim] (when pos_embed_type=None)
ids_keep – Indices of kept positions [B, N] (when pos_embed_type is not None)
grid_size – Grid size (H, W) for sincos_2d
- Returns:
Output tokens [B, num_registers + N, output_dim]
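A shape-level sketch (not from the original docstring), assuming the documented forward() arguments; the 2D sincos path takes ids_keep and grid_size as described above, and the output shape follows the documented return contract.
>>> import torch
>>> from stable_pretraining.backbone.vit import TransformerPredictor
>>> predictor = TransformerPredictor(
...     input_dim=768, hidden_dim=384, output_dim=768,
...     depth=6, num_heads=6, num_registers=4,
...     pos_embed_type="sincos_2d",
... )
>>> context = torch.randn(2, 49, 768)  # visible tokens from an encoder
>>> ids_keep = torch.stack([torch.randperm(196)[:49] for _ in range(2)])
>>> out = predictor(context, ids_keep=ids_keep, grid_size=(14, 14))
>>> # out: [B, num_registers + N, output_dim] = (2, 53, 768)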