Module

class stable_pretraining.module.Module(*args, forward: callable = None, hparams: dict = None, parallelize_fn: callable = None, **kwargs)

Bases: LightningModule

PyTorch Lightning module using manual optimization with multi-optimizer support.

Core usage

  • Provide a custom forward(self, batch, stage) via the forward argument at init.

  • During training, forward must return a dict (the state) holding a single joint loss under state["loss"]. When multiple optimizers are configured, this one loss is backpropagated once and shared by all of them.
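A minimal sketch of a forward function that honours this contract. The `self.backbone` attribute and the "image"/"label" batch keys are hypothetical, and plain floats stand in for tensors so the sketch runs without PyTorch. Only the "fit" stage (the stage training_step passes) computes a loss; "predict" returns inference outputs only.

```python
from types import SimpleNamespace


def forward(self, batch, stage):
    """Toy forward: always embed, add a loss only in the "fit" stage.

    `self.backbone` and the "image"/"label" keys are hypothetical;
    plain floats stand in for tensors so this runs without PyTorch.
    """
    embedding = [self.backbone(x) for x in batch["image"]]
    state = {"embedding": embedding}
    if "label" in batch:
        state["label"] = batch["label"]
    if stage == "fit":
        # One joint loss (here: mean squared error) shared by every optimizer.
        errors = [(e - y) ** 2 for e, y in zip(embedding, batch["label"])]
        state["loss"] = sum(errors) / len(errors)
    return state


# Stand-in for the Module instance: a "backbone" that doubles its input.
module = SimpleNamespace(backbone=lambda x: 2.0 * x)
train_state = forward(module, {"image": [1.0, 2.0], "label": [2.0, 5.0]}, "fit")
predict_state = forward(module, {"image": [1.0, 2.0]}, "predict")
```

With the real Module, a function of this shape would be supplied through the forward argument at construction.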

Optimizer configuration (self.optim)

  • Single optimizer:

    {"optimizer": str|dict|partial|Class,
     "scheduler": <see below>,
     "interval": "step"|"epoch",
     "frequency": int}
    

    Optimizer accepted forms:

    • string name (e.g., "AdamW", "SGD") from torch.optim

    • dict: {"type": "AdamW", "lr": 1e-3, ...}

    • functools.partial: partial(torch.optim.AdamW, lr=1e-3)

    • optimizer class: torch.optim.AdamW

  • Multiple optimizers:

    {
      name: {
        "modules": "regex",                  # assign params by module-name pattern (children inherit)
        "optimizer": str|dict|partial|Class, # optimizer factory (same accepted forms as above)
        "scheduler": str|dict|partial|Class, # flexible scheduler config (see below)
        "interval": "step"|"epoch",          # scheduler interval
        "frequency": int,                    # optimizer step frequency
        "monitor": str                       # (optional) for ReduceLROnPlateau
      }, ...
    }
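The four accepted optimizer forms can all be reduced to "call something with the parameters". The sketch below is an illustrative resolver, not the library's implementation: `build_optimizer` and `FakeAdamW` are hypothetical, and `registry` plays the role that `torch.optim` would play (names looked up with `getattr`).

```python
import functools
import inspect
from types import SimpleNamespace


def build_optimizer(spec, params, registry):
    """Instantiate an optimizer from any of the four accepted spec forms.

    `registry` stands in for torch.optim: names are looked up with getattr.
    """
    if isinstance(spec, str):  # "AdamW"
        return getattr(registry, spec)(params)
    if isinstance(spec, dict):  # {"type": "AdamW", "lr": 1e-3, ...}
        kwargs = dict(spec)  # copy so the user's config is not mutated
        optimizer_cls = getattr(registry, kwargs.pop("type"))
        return optimizer_cls(params, **kwargs)
    if isinstance(spec, functools.partial) or inspect.isclass(spec):
        return spec(params)  # partial(AdamW, lr=1e-3) or the bare AdamW class
    raise TypeError(f"Unsupported optimizer spec: {spec!r}")


class FakeAdamW:
    """Stand-in for torch.optim.AdamW that just records its arguments."""

    def __init__(self, params, lr=1e-3, weight_decay=0.0):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


registry = SimpleNamespace(AdamW=FakeAdamW)
specs = [
    "AdamW",
    {"type": "AdamW", "lr": 3e-4, "weight_decay": 0.05},
    functools.partial(FakeAdamW, lr=1e-2),
    FakeAdamW,
]
optimizers = [build_optimizer(s, [0.0], registry) for s in specs]
```

Copying the dict before popping "type" keeps the same config reusable, for example across runs or when the optimizer is rebuilt.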
    

Parameter assignment (multi-optimizer)

  • Modules are matched by regex on their qualified name. Children inherit the parent’s assignment unless they match a more specific pattern. Only direct parameters of each module are collected to avoid duplication.
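The inheritance rule can be sketched in plain Python over qualified module names (as `named_modules()` would yield them). The helpers `assign_optimizers` and `collect_direct_params` are hypothetical; using `re.match` and letting the first matching pattern win are assumptions, not confirmed library behaviour. The test uses the module tree from the configure_optimizers example.

```python
import re
from typing import Dict, List, Optional


def assign_optimizers(
    module_names: List[str], patterns: Dict[str, str]
) -> Dict[str, Optional[str]]:
    """Map each qualified module name to an optimizer name, or None.

    A module's own regex match wins; otherwise it inherits its parent's
    assignment. Patterns are tried in insertion order (assumed).
    """
    assignment: Dict[str, Optional[str]] = {}
    # Visit parents before children so a child can look its parent up.
    for name in sorted(module_names, key=lambda n: n.count(".")):
        own = next((opt for opt, pat in patterns.items() if re.match(pat, name)), None)
        parent = name.rpartition(".")[0]
        assignment[name] = own if own is not None else assignment.get(parent)
    return assignment


def collect_direct_params(
    direct_params: Dict[str, List[str]], assignment: Dict[str, Optional[str]]
) -> Dict[str, List[str]]:
    """Group each module's *direct* parameters under its optimizer.

    Using only direct parameters means no parameter is collected twice.
    """
    groups: Dict[str, List[str]] = {}
    for module, params in direct_params.items():
        opt = assignment.get(module)
        if opt is not None:
            groups.setdefault(opt, []).extend(params)
    return groups


direct = {
    "encoder": [],
    "encoder.layer1": [],
    "encoder.layer1.conv": ["encoder.layer1.conv.weight"],
    "classifier_head": [],
    "classifier_head.linear": ["classifier_head.linear.weight", "classifier_head.linear.bias"],
    "decoder": ["decoder.weight"],
}
assignment = assign_optimizers(list(direct), {"encoder_opt": "encoder", "head_opt": r".*head$"})
groups = collect_direct_params(direct, assignment)
```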

Schedulers (flexible)

  • Accepted forms: string name (e.g., "CosineAnnealingLR", "StepLR"), dict with {"type": "...", ...}, functools.partial, or a scheduler class. For common schedulers (CosineAnnealingLR, OneCycleLR, StepLR, ExponentialLR, ReduceLROnPlateau, LinearLR, ConstantLR), smart defaults fill in any omitted parameters. For ReduceLROnPlateau, a monitor key is added (default: "val_loss"); you may specify monitor either alongside the optimizer config (top level) or inside the scheduler dict itself.

  • The resulting Lightning scheduler dict includes interval and frequency (or scheduler_frequency).
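A sketch of how a built scheduler might be wrapped into Lightning's lr_scheduler config dict (keys "scheduler", "interval", "frequency", "monitor"). The helper name is hypothetical, and three details are assumptions of this sketch: a monitor in the scheduler dict takes precedence over the top-level one, "scheduler_frequency" overrides "frequency" for the scheduler entry, and missing keys fall back to Lightning's own defaults ("epoch", 1). A real resolver must also drop "monitor" from the scheduler kwargs, because torch's ReduceLROnPlateau constructor does not accept it.

```python
from typing import Any, Dict, Optional


def to_lightning_scheduler_config(
    scheduler: Any,
    scheduler_type: str,
    opt_cfg: Dict[str, Any],
    scheduler_cfg: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Wrap a built scheduler in Lightning's lr_scheduler config dict.

    opt_cfg is one per-optimizer entry of self.optim; scheduler_cfg is the
    scheduler's dict form, if that form was used.
    """
    config = {
        "scheduler": scheduler,
        "interval": opt_cfg.get("interval", "epoch"),
        # Assumed: an explicit scheduler_frequency wins over frequency.
        "frequency": opt_cfg.get("scheduler_frequency", opt_cfg.get("frequency", 1)),
    }
    if scheduler_type == "ReduceLROnPlateau":
        # Assumed precedence: scheduler dict, then top level, then "val_loss".
        inner = scheduler_cfg or {}
        config["monitor"] = inner.get("monitor", opt_cfg.get("monitor", "val_loss"))
    return config


plateau = to_lightning_scheduler_config(
    "sched",
    "ReduceLROnPlateau",
    {"monitor": "val_accuracy", "interval": "epoch", "frequency": 2},
    {"type": "ReduceLROnPlateau", "mode": "max"},
)
```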

Training loop behavior

  • Manual optimization (automatic_optimization = False).

  • Gradient accumulation: scales the loss by 1/N, where N = Trainer.accumulate_grad_batches, and steps optimizers only on accumulation boundaries.

  • Per-optimizer step frequency: each optimizer steps only when its frequency boundary is met (in addition to accumulation boundary).

  • Gradient clipping: before each step, clips with the per-optimizer override if one is set, otherwise with the Trainer's gradient_clip_val and gradient_clip_algorithm.

  • Returns the state dict from forward unchanged for logging/inspection.
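The combined gating of accumulation and per-optimizer frequency can be sketched as a pure function. Counting both boundaries from batch_idx follows the (batch_idx + 1) % frequency rule stated for training_step; applying the same form to the accumulation boundary is an assumption. The helper names are hypothetical.

```python
from typing import Dict, List


def steps_per_batch(
    num_batches: int, accumulate_grad_batches: int, frequencies: Dict[str, int]
) -> Dict[int, List[str]]:
    """For each batch index, list the optimizers that would step.

    An optimizer steps only on an accumulation boundary that is also
    one of its own frequency boundaries.
    """
    schedule: Dict[int, List[str]] = {}
    for batch_idx in range(num_batches):
        on_accum = (batch_idx + 1) % accumulate_grad_batches == 0
        schedule[batch_idx] = [
            name
            for name, freq in frequencies.items()
            if on_accum and (batch_idx + 1) % freq == 0
        ]
    return schedule


def scaled_loss(loss: float, accumulate_grad_batches: int) -> float:
    """Scale the joint loss so N accumulated backward passes average out."""
    return loss / accumulate_grad_batches


schedule = steps_per_batch(8, 2, {"encoder_opt": 1, "head_opt": 4})
```

Under this rule an optimizer effectively steps every lcm(N, frequency) batches: with N = 2, "head_opt" (frequency 4) steps at batch indices 3 and 7 only.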

after_manual_backward()

Hook called immediately after manual_backward in training_step.

Override in a subclass to insert logic that must run after gradients are computed but before any optimizer step or zero_grad — for example, gradient norm logging, custom gradient clipping, or EMA teacher weight updates that depend on the current gradient. The default implementation does nothing.
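Gradient norm logging is one of the listed uses. The hypothetical helper below computes the global L2 norm (square root of the sum of squared gradient entries) in plain Python so it runs without PyTorch; the commented override shows how the same idea might look inside a Module subclass with real tensors, as a sketch rather than a tested recipe.

```python
import math
from typing import Iterable, Optional, Sequence


def global_grad_norm(grads: Iterable[Optional[Sequence[float]]]) -> float:
    """Global L2 norm over all gradients; parameters without a grad are skipped."""
    total = 0.0
    for g in grads:
        if g is None:  # e.g. a frozen parameter or one unused in this step
            continue
        total += sum(x * x for x in g)
    return math.sqrt(total)


# A tensor version inside a Module subclass might read:
#     def after_manual_backward(self):
#         grads = [p.grad for p in self.parameters(with_callbacks=False)
#                  if p.grad is not None]
#         norms = torch.stack([torch.linalg.vector_norm(g) for g in grads])
#         self.log("grad_norm", torch.linalg.vector_norm(norms))
```

Because the hook runs before clipping and zero_grad, the logged value is the raw, unclipped gradient norm.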

configure_model() -> None

Lightning hook for FSDP2 — shard the model when a device mesh exists.

Under Trainer(strategy="fsdp2"), ModelParallelStrategy builds a device mesh and exposes it as self.device_mesh before calling this hook. We dispatch to the configured parallelize_fn (default: stable_pretraining.utils.fsdp2.default_parallelize_fn()) to apply fully_shard. Under any other strategy (single-device / DDP) there is no mesh and this is a no-op.

Runs before configure_optimizers, so optimizers are built over the sharded DTensor parameters, as FSDP2 requires.

configure_optimizers()

Configure optimizers and schedulers for manual optimization.

Returns:

Optimizer configuration with an optional learning rate scheduler. For a single optimizer, returns a dict with "optimizer" and "lr_scheduler" keys. For multiple optimizers, returns a tuple (optimizers, schedulers).

Return type:

dict or tuple

Example

Single- and multi-optimizer configurations with module pattern matching and schedulers:

>>> from functools import partial
>>> import torch
>>> # Simple single optimizer with scheduler
>>> self.optim = {
...     "optimizer": partial(torch.optim.AdamW, lr=1e-3),
...     "scheduler": "CosineAnnealingLR",  # Uses smart defaults
...     "interval": "step",
...     "frequency": 1,
... }
>>> # Multi-optimizer with custom scheduler configs
>>> self.optim = {
...     "encoder_opt": {
...         "modules": "encoder",  # Matches 'encoder' and all children
...         "optimizer": {"type": "AdamW", "lr": 1e-3},
...         "scheduler": {
...             "type": "OneCycleLR",
...             "max_lr": 1e-3,
...             "total_steps": 10000,
...         },
...         "interval": "step",
...         "frequency": 1,
...     },
...     "head_opt": {
...         "modules": ".*head$",  # Matches modules ending with 'head'
...         "optimizer": "SGD",
...         "scheduler": {
...             "type": "ReduceLROnPlateau",
...             "mode": "max",
...             "patience": 5,
...             "factor": 0.5,
...         },
...         "monitor": "val_accuracy",  # Watched metric (defaults to "val_loss" if omitted)
...         "interval": "epoch",
...         "frequency": 2,
...     },
... }

With this model structure:

  • encoder -> encoder_opt (matches “encoder”)

  • encoder.layer1 -> encoder_opt (inherits from parent)

  • encoder.layer1.conv -> encoder_opt (inherits from encoder.layer1)

  • classifier_head -> head_opt (matches “.*head$”)

  • classifier_head.linear -> head_opt (inherits from parent)

  • decoder -> None (no match, no parameters collected)

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

named_parameters(with_callbacks=True, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True)

Override to globally exclude callback-related parameters.

Excludes parameters that belong to self.callbacks_modules or self.callbacks_metrics. This prevents accidental optimization of callback/metric internals, even if external code calls self.parameters() or self.named_parameters() directly.

Parameters:
  • with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.

  • prefix (str, optional) – Prefix to prepend to parameter names. Defaults to “”.

  • recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.

  • remove_duplicate (bool, optional) – Whether to deduplicate shared parameters. Must be accepted and forwarded because PyTorch’s fully_shard (FSDP2) wrap path calls named_parameters(remove_duplicate=False); without it this override would raise TypeError before sharding. Defaults to True.

Yields:

tuple[str, torch.nn.Parameter] – Name and parameter pairs.
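A minimal sketch of the filtering step, assuming a parameter counts as callback-owned when one of the dotted segments of its qualified name is callbacks_modules or callbacks_metrics. In a real override the input pairs would come from super().named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate); here plain tuples stand in for them and the helper name is hypothetical.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

P = TypeVar("P")
CALLBACK_CONTAINERS = {"callbacks_modules", "callbacks_metrics"}


def filter_callback_params(
    named: Iterable[Tuple[str, P]], with_callbacks: bool = True
) -> Iterator[Tuple[str, P]]:
    """Yield (name, param) pairs, dropping callback-owned ones if requested.

    Matching on dotted segments (an assumption) keeps working when a
    non-empty prefix is prepended to every name.
    """
    for name, param in named:
        if not with_callbacks and CALLBACK_CONTAINERS.intersection(name.split(".")):
            continue
        yield name, param


named = [
    ("backbone.weight", 1),
    ("callbacks_modules.probe.weight", 2),
    ("model.callbacks_metrics.acc.weight", 3),
]
```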

on_after_batch_transfer(batch, dataloader_idx)

Apply the active GPU-side batch transform after Lightning moves the batch.

Resolution order (first match wins):

  1. dataset.gpu_transform — set on the dataset behind the active DataLoader. Recommended: pair augmentation with the dataset so train/val/test/predict each carry their own spec naturally.

  2. self.trainer.datamodule.gpu_transform — set on the DataModule. May be a callable or a {"train": ..., "val": ...} dict.

Setting self.gpu_transform on the Module is not supported and is rejected at on_train_start; attach it to the dataset (or the DataModule for third-party datasets) instead.

Lazy device placement: when the resolved transform is an nn.Module (e.g. a GPUCompose), it is moved to self.device on first use. Buffers (e.g. GPUNormalize’s mean/std) therefore live on the correct GPU under DDP without manual wiring.

When nothing resolves, this is a zero-cost passthrough.
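The resolution order can be sketched without Lightning by passing the dataset, the DataModule (or None) and a stage key explicitly. The helper names are hypothetical; "train" and "val" come from the dict example above, while how the Module maps the running loop to a key is not shown. Lazy device placement (moving an nn.Module transform to self.device on first use) is omitted.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def resolve_gpu_transform(dataset: Any, datamodule: Any, stage_key: str) -> Optional[Callable]:
    """Dataset transform first, then the DataModule's (callable or per-stage dict)."""
    transform = getattr(dataset, "gpu_transform", None)
    if transform is not None:
        return transform
    dm_transform = getattr(datamodule, "gpu_transform", None)
    if isinstance(dm_transform, dict):
        return dm_transform.get(stage_key)  # None if this stage has no entry
    return dm_transform


def apply_gpu_transform(batch: dict, dataset: Any, datamodule: Any, stage_key: str) -> dict:
    transform = resolve_gpu_transform(dataset, datamodule, stage_key)
    # Nothing resolved: hand back the very same batch object untouched.
    return batch if transform is None else transform(batch)


double = lambda b: {**b, "image": [2 * x for x in b["image"]]}
batch = {"image": [1, 2]}
ds_with = SimpleNamespace(gpu_transform=double)
ds_plain = SimpleNamespace()
dm = SimpleNamespace(gpu_transform={"train": double})
```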

Parameters:
  • batch – Batch dict already moved to self.device by Lightning.

  • dataloader_idx – Index of the dataloader that produced this batch.

Returns:

The (possibly augmented) batch dict.

on_train_start()

Validate and log the optimizer configuration at the start of training.

Runs once before the first training step. Fills in any missing per-optimizer metadata (gradient clip value, clip algorithm, step frequency) by falling back to the Trainer’s global settings. Logs a summary table of optimizer index, name, class, clip value, and clip algorithm so misconfigured setups are caught early rather than silently misbehaving mid-run.
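A sketch of the fill-in and summary step. The per-optimizer key names ("gradient_clip_val", "gradient_clip_algorithm", "frequency", "optimizer") and the default frequency of 1 are assumptions for illustration; the Trainer values mirror Lightning's gradient_clip_val and gradient_clip_algorithm ("norm" or "value").

```python
from typing import Any, Dict, List, Optional, Tuple

Row = Tuple[int, str, str, Optional[float], Optional[str]]


def fill_optimizer_defaults(
    opt_cfgs: Dict[str, Dict[str, Any]],
    trainer_clip_val: Optional[float],
    trainer_clip_algorithm: Optional[str],
) -> List[Row]:
    """Fill missing per-optimizer settings from Trainer globals; return table rows."""
    rows: List[Row] = []
    for idx, (name, cfg) in enumerate(opt_cfgs.items()):
        cfg.setdefault("gradient_clip_val", trainer_clip_val)
        cfg.setdefault("gradient_clip_algorithm", trainer_clip_algorithm)
        cfg.setdefault("frequency", 1)  # assumed default: step on every boundary
        rows.append(
            (idx, name, cfg["optimizer"], cfg["gradient_clip_val"], cfg["gradient_clip_algorithm"])
        )
    return rows


def format_rows(rows: List[Row]) -> str:
    """Render the rows as a fixed-width summary table for logging."""
    header = f"{'idx':<4}{'name':<12}{'class':<8}{'clip':<6}algorithm"
    body = [f"{i:<4}{n:<12}{c:<8}{str(v):<6}{a}" for i, n, c, v, a in rows]
    return "\n".join([header] + body)


cfgs = {
    "encoder_opt": {"optimizer": "AdamW", "gradient_clip_val": 0.5},
    "head_opt": {"optimizer": "SGD"},
}
rows = fill_optimizer_defaults(cfgs, 1.0, "norm")
```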

parameters(with_callbacks=True, recurse: bool = True)

Override to route through the filtered named_parameters implementation.

Parameters:
  • with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.

  • recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.

Yields:

torch.nn.Parameter – Module parameters.

predict_step(batch, batch_idx)

Run the forward pass for a single prediction batch.

Passes stage="predict" to forward so forward functions can omit loss computation and return only inference outputs (e.g., embeddings). Used by Trainer.predict() for large-scale feature extraction without a label set.

Parameters:
  • batch – Input batch dict from the prediction dataloader.

  • batch_idx – Index of the current batch within the epoch.

Returns:

Output dict returned by forward.

Return type:

dict

setup(stage: str) -> None

Lightning setup hook — reject the legacy FSDP1 strategy.

FSDP1 (Trainer(strategy="fsdp")) uses a flat training-state machine that asserts a single forward/backward per step, which breaks the multi-forward methods that are this library’s bread and butter (DINO, I-JEPA, LeJEPA, …). FSDP2 (strategy="fsdp2") has no such restriction, so we fail fast with a clear redirect.

Parameters:

stage – The Lightning stage ("fit"/"validate"/"test"/"predict").

test_step(batch, batch_idx)

Run the forward pass for a single test batch.

Mirrors validation_step but passes stage="test" to forward, allowing forward functions to distinguish test-time behaviour if needed. The returned dict is forwarded to Lightning’s on_test_batch_end callback hooks.

Parameters:
  • batch – Input batch dict from the test dataloader.

  • batch_idx – Index of the current batch within the epoch.

Returns:

Output dict returned by forward.

Return type:

dict

training_step(batch, batch_idx)

Run one training step with manual optimization across all configured optimizers.

Calls forward(batch, stage="fit") to obtain a state dict, then performs a single manual_backward on state["loss"]. Each optimizer steps only when its frequency boundary is met ((batch_idx + 1) % frequency == 0). Gradient clipping is applied per-optimizer using either the per-optimizer override or the Trainer’s gradient_clip_val. Learning rate is logged as hparams/lr_{name} after each step. zero_grad is called only on optimizers that actually stepped this iteration.

Parameters:
  • batch – Input batch dict from the training dataloader. Must be a dict — a non-dict batch raises ValueError.

  • batch_idx – Index of the current batch within the epoch. Injected into the batch dict as batch["batch_idx"] before forwarding.

Returns:

The state dict returned by forward, passed unchanged to Lightning’s callback hooks (on_train_batch_end).

Return type:

dict
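The per-optimizer part of a step, after the single backward pass, can be sketched with a stand-in optimizer. `FakeOptimizer` and `step_optimizers` are hypothetical; the frequency test, the hparams/lr_{name} log key and "zero_grad only on optimizers that stepped" follow the description above, and `param_groups[0]["lr"]` mirrors where torch optimizers keep their learning rate. Gradient clipping, which would run just before `opt.step()`, is omitted.

```python
from typing import Dict, Tuple


class FakeOptimizer:
    """Stand-in exposing step, zero_grad and a torch-like param_groups list."""

    def __init__(self, lr: float):
        self.param_groups = [{"lr": lr}]
        self.steps = 0
        self.zero_grads = 0

    def step(self) -> None:
        self.steps += 1

    def zero_grad(self) -> None:
        self.zero_grads += 1


def step_optimizers(
    batch_idx: int, optimizers: Dict[str, Tuple[FakeOptimizer, int]]
) -> Dict[str, float]:
    """Step, log and zero only the optimizers whose frequency boundary is met."""
    logs: Dict[str, float] = {}
    for name, (opt, frequency) in optimizers.items():
        if (batch_idx + 1) % frequency != 0:
            continue  # not due: its gradients are left in place (not zeroed)
        opt.step()
        logs[f"hparams/lr_{name}"] = opt.param_groups[0]["lr"]
        opt.zero_grad()
    return logs


enc, head = FakeOptimizer(1e-3), FakeOptimizer(0.1)
opts = {"encoder_opt": (enc, 1), "head_opt": (head, 2)}
logs0 = step_optimizers(0, opts)
logs1 = step_optimizers(1, opts)
```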

validation_step(batch, batch_idx)

Run the forward pass for a single validation batch.

Calls forward(batch, stage="validate") with gradients disabled (Lightning handles torch.no_grad()). The returned dict is passed to every registered callback via on_validation_batch_end, making all keys — including "embedding" and "label" — available to OnlineProbe, OnlineKNN, RankMe, and similar evaluation callbacks without any extra wiring.

Parameters:
  • batch – Input batch dict from the validation dataloader.

  • batch_idx – Index of the current batch within the epoch.

Returns:

Output dict returned by forward.

Return type:

dict