Module

class stable_pretraining.module.Module(*args, forward: callable = None, hparams: dict = None, **kwargs)

Bases: LightningModule

PyTorch Lightning module using manual optimization with multi-optimizer support.

Core usage

  • Provide a custom forward(self, batch, stage) via the forward argument at init.

  • During training, forward must return a dict (the state) that contains state["loss"], a single joint loss. When multiple optimizers are configured, this one loss is backpropagated once and its gradients are used by every optimizer.
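A forward function satisfying this contract might look like the minimal sketch below. The attribute names self.backbone and self.loss_fn and the batch keys "image" and "label" are hypothetical; plain Python callables stand in for torch modules so the sketch runs on its own. The stage values ("fit", "validate", "test", "predict") are the ones passed by the step methods documented below.

```python
from types import SimpleNamespace


def forward(self, batch, stage):
    """Hypothetical forward(self, batch, stage) in the shape Module expects.

    self.backbone and self.loss_fn are illustrative attribute names,
    not part of the stable_pretraining API.
    """
    out = dict(batch)  # keep inputs such as "label" visible to callbacks
    out["embedding"] = self.backbone(batch["image"])
    if stage == "fit":
        # Training requires a single joint loss under the "loss" key.
        out["loss"] = self.loss_fn(out["embedding"], batch["label"])
    # "validate", "test" and "predict" may skip the loss entirely.
    return out


# Stand-ins for real modules so the sketch runs without torch.
toy = SimpleNamespace(
    backbone=lambda xs: [2 * x for x in xs],
    loss_fn=lambda emb, ys: sum((e - y) ** 2 for e, y in zip(emb, ys)) / len(ys),
)
```

Calling forward(toy, {"image": [1, 2], "label": [2, 5]}, "fit") returns a state with an "embedding" and a "loss"; the same call with stage "predict" needs no label and returns no loss.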

Optimizer configuration (self.optim)

  • Single optimizer:

    {"optimizer": str|dict|partial|Class,
     "scheduler": <see below>,
     "interval": "step"|"epoch",
     "frequency": int}
    

    Accepted optimizer forms:

    • string name (e.g., "AdamW", "SGD") from torch.optim

    • dict: {"type": "AdamW", "lr": 1e-3, ...}

    • functools.partial: partial(torch.optim.AdamW, lr=1e-3)

    • optimizer class: torch.optim.AdamW

  • Multiple optimizers:

    {
      name: {
        "modules": "regex",                  # assign params by module-name pattern (children inherit)
        "optimizer": str|dict|partial|Class, # optimizer factory (same accepted forms as above)
        "scheduler": str|dict|partial|Class, # flexible scheduler config (see below)
        "interval": "step"|"epoch",          # scheduler interval
        "frequency": int,                    # optimizer step frequency
        "monitor": str                       # (optional) for ReduceLROnPlateau
      }, ...
    }
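The four accepted optimizer forms can all be normalized into one factory that takes an iterable of parameters. The sketch below illustrates the idea and is not the library's own resolver: resolve_optimizer_factory is a hypothetical name, and a namespace argument stands in for torch.optim (with FakeAdamW as a stand-in class) so it runs without PyTorch.

```python
import inspect
from functools import partial
from types import SimpleNamespace


def resolve_optimizer_factory(spec, namespace):
    """Turn any accepted optimizer spec into a callable factory(params).

    Hypothetical helper; `namespace` plays the role of torch.optim.
    """
    if isinstance(spec, str):
        # "AdamW" -> namespace.AdamW, keeping the optimizer's own defaults
        return getattr(namespace, spec)
    if isinstance(spec, dict):
        kwargs = dict(spec)  # copy so the caller's config is not mutated
        cls = getattr(namespace, kwargs.pop("type"))
        return partial(cls, **kwargs)
    if isinstance(spec, partial):
        return spec  # already a factory with bound hyperparameters
    if inspect.isclass(spec):
        return spec
    raise TypeError(f"Unsupported optimizer spec: {spec!r}")


class FakeAdamW:
    """Stand-in for torch.optim.AdamW that only stores its arguments."""

    def __init__(self, params, lr=1e-3, weight_decay=0.0):
        self.params, self.lr, self.weight_decay = list(params), lr, weight_decay


fake_optim = SimpleNamespace(AdamW=FakeAdamW)
```

With the real module, namespace would simply be torch.optim, and each factory would be called with the parameters collected for that optimizer.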
    

Parameter assignment (multi-optimizer)

  • Modules are matched by regex on their qualified name. Children inherit the parent’s assignment unless they match a more specific pattern. Only direct parameters of each module are collected to avoid duplication.
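The assignment rule can be sketched in plain Python over qualified module names. This is an illustration under stated assumptions, not the library's code: assign_modules and collect_param_groups are hypothetical names, patterns are tested with re.fullmatch, and the first matching pattern wins; the real implementation may use a different matching function or tie-break.

```python
import re


def assign_modules(module_names, patterns):
    """Map each qualified module name to an optimizer name, or None.

    module_names must list parents before children (the order
    torch.nn.Module.named_modules() yields); "" denotes the root module.
    patterns maps optimizer name -> regex. Assumption: re.fullmatch is used
    and the first matching pattern wins.
    """
    assignment = {}
    for name in module_names:
        own = next(
            (opt for opt, pat in patterns.items() if re.fullmatch(pat, name)), None
        )
        parent = name.rpartition(".")[0]  # "" for top-level modules
        # A module's own match overrides whatever its parent was assigned.
        assignment[name] = own if own is not None else assignment.get(parent)
    return assignment


def collect_param_groups(assignment, direct_params):
    """Group each module's *direct* parameters under its optimizer.

    direct_params maps module name -> parameter names owned directly by that
    module (as with named_parameters(recurse=False)), so nothing is collected
    twice even though children inherit their parent's optimizer.
    """
    groups = {}
    for module, opt in assignment.items():
        if opt is None:
            continue  # unmatched modules contribute no parameters
        groups.setdefault(opt, []).extend(direct_params.get(module, []))
    return groups
```

Run on the model structure listed in the configure_optimizers example below, this yields the same mapping: encoder and its descendants go to encoder_opt, classifier_head and its children to head_opt, and decoder to None.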

Schedulers (flexible)

  • Accepted forms: a string name (e.g., "CosineAnnealingLR", "StepLR"), a dict {"type": "...", ...}, a functools.partial, or a scheduler class. When parameters are omitted for common schedulers (CosineAnnealingLR, OneCycleLR, StepLR, ExponentialLR, ReduceLROnPlateau, LinearLR, ConstantLR), smart defaults are filled in. For ReduceLROnPlateau, a monitor key is added (default: "val_loss"); you may specify monitor either at the top level of the optimizer config or inside the scheduler dict itself.

  • The resulting Lightning scheduler dict includes interval and frequency (or scheduler_frequency).
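The monitor rule above can be sketched as a small function that assembles a Lightning-style scheduler config dict ("scheduler", "interval", "frequency", "monitor" are keys Lightning understands). Everything else is an assumption: build_scheduler_config is a hypothetical name, a top-level monitor is assumed to take precedence over one in the scheduler dict, missing interval/frequency fall back to Lightning's own defaults ("epoch", 1), and ReduceLROnPlateau is detected by class name only to keep the sketch free of torch.

```python
def build_scheduler_config(scheduler, opt_cfg, scheduler_cfg=None):
    """Assemble a Lightning-style lr_scheduler config dict.

    Hypothetical helper. `scheduler` is an already-built scheduler object,
    `opt_cfg` the per-optimizer config, `scheduler_cfg` the scheduler dict
    (if one was given). Assumption: a top-level "monitor" wins over one
    inside the scheduler dict.
    """
    scheduler_cfg = scheduler_cfg or {}
    config = {
        "scheduler": scheduler,
        "interval": opt_cfg.get("interval", "epoch"),  # Lightning's default
        "frequency": opt_cfg.get("frequency", 1),  # Lightning's default
    }
    # Simplification: real code could use isinstance with
    # torch.optim.lr_scheduler.ReduceLROnPlateau instead of the class name.
    if type(scheduler).__name__ == "ReduceLROnPlateau":
        config["monitor"] = opt_cfg.get(
            "monitor", scheduler_cfg.get("monitor", "val_loss")
        )
    return config
```

Only plateau-style schedulers get a monitor key, since Lightning needs it to know which logged metric to pass to scheduler.step().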

Training loop behavior

  • Manual optimization (automatic_optimization = False).

  • Gradient accumulation: the loss is scaled by 1/N, where N = Trainer.accumulate_grad_batches, and optimizers step only on accumulation boundaries (every N batches).

  • Per-optimizer step frequency: each optimizer steps only when its frequency boundary is met (in addition to accumulation boundary).

  • Gradient clipping: uses Trainer’s gradient_clip_val and gradient_clip_algorithm before each step.

  • Returns the state dict from forward unchanged for logging/inspection.
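The interaction between gradient accumulation and per-optimizer step frequency can be sketched as below, assuming both boundaries are counted on batch_idx + 1 (the rule documented for training_step). optimizer_steps and step_schedule are hypothetical helper names used only for illustration.

```python
def optimizer_steps(batch_idx, accumulate_grad_batches, frequency):
    """Return True if an optimizer should step after this batch.

    Assumption: both boundaries are counted on (batch_idx + 1), matching the
    (batch_idx + 1) % frequency == 0 rule documented for training_step.
    """
    n = batch_idx + 1
    on_accumulation_boundary = n % accumulate_grad_batches == 0
    on_frequency_boundary = n % frequency == 0
    return on_accumulation_boundary and on_frequency_boundary


def step_schedule(num_batches, accumulate_grad_batches, frequencies):
    """List the batch indices at which each named optimizer steps."""
    return {
        name: [
            b for b in range(num_batches)
            if optimizer_steps(b, accumulate_grad_batches, freq)
        ]
        for name, freq in frequencies.items()
    }
```

With accumulate_grad_batches=2 and frequency=3, an optimizer steps only every 6 batches (the least common multiple of the two), which is worth keeping in mind when combining both settings.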

after_manual_backward()

Hook called immediately after manual_backward in training_step.

Override in a subclass to insert logic that must run after gradients are computed but before any optimizer step or zero_grad — for example, gradient norm logging, custom gradient clipping, or EMA teacher weight updates that depend on the current gradient. The default implementation does nothing.

configure_optimizers()

Configure optimizers and schedulers for manual optimization.

Returns:

Optimizer configuration with an optional learning-rate scheduler. For a single optimizer, a dict with "optimizer" and "lr_scheduler" keys; for multiple optimizers, a tuple (optimizers, schedulers).

Return type:

dict or tuple

Example

Single-optimizer and multi-optimizer configurations, with module pattern matching and schedulers:

>>> from functools import partial
>>> import torch
>>> # Simple single optimizer with scheduler
>>> self.optim = {
...     "optimizer": partial(torch.optim.AdamW, lr=1e-3),
...     "scheduler": "CosineAnnealingLR",  # Uses smart defaults
...     "interval": "step",
...     "frequency": 1,
... }
>>> # Multi-optimizer with custom scheduler configs
>>> self.optim = {
...     "encoder_opt": {
...         "modules": "encoder",  # Matches 'encoder' and all children
...         "optimizer": {"type": "AdamW", "lr": 1e-3},
...         "scheduler": {
...             "type": "OneCycleLR",
...             "max_lr": 1e-3,
...             "total_steps": 10000,
...         },
...         "interval": "step",
...         "frequency": 1,
...     },
...     "head_opt": {
...         "modules": ".*head$",  # Matches modules ending with 'head'
...         "optimizer": "SGD",
...         "scheduler": {
...             "type": "ReduceLROnPlateau",
...             "mode": "max",
...             "patience": 5,
...             "factor": 0.5,
...         },
...         "monitor": "val_accuracy",  # Required for ReduceLROnPlateau
...         "interval": "epoch",
...         "frequency": 2,
...     },
... }

With model structure:

  • encoder -> encoder_opt (matches “encoder”)

  • encoder.layer1 -> encoder_opt (inherits from parent)

  • encoder.layer1.conv -> encoder_opt (inherits from encoder.layer1)

  • classifier_head -> head_opt (matches “.*head$”)

  • classifier_head.linear -> head_opt (inherits from parent)

  • decoder -> None (no match, no parameters collected)

forward(*args, **kwargs)

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

named_parameters(with_callbacks=True, prefix: str = '', recurse: bool = True)

Override to globally exclude callback-related parameters.

Excludes parameters that belong to self.callbacks_modules or self.callbacks_metrics. This prevents accidental optimization of callback/metric internals, even if external code calls self.parameters() or self.named_parameters() directly.

Parameters:
  • with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.

  • prefix (str, optional) – Prefix to prepend to parameter names. Defaults to “”.

  • recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.

Yields:

tuple[str, torch.nn.Parameter] – Name and parameter pairs.
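The filtering can be sketched as a name-prefix check over (name, parameter) pairs. This is a sketch under assumptions, not the override itself: filter_named_parameters and CALLBACK_PREFIXES are hypothetical names, and the prefixes assume the callback containers are registered as the submodules self.callbacks_modules and self.callbacks_metrics, so their parameters carry those names as qualified-name prefixes.

```python
# Assumed qualified-name prefixes of callback-owned parameters.
CALLBACK_PREFIXES = ("callbacks_modules.", "callbacks_metrics.")


def filter_named_parameters(named_params, with_callbacks=True):
    """Yield (name, param) pairs, dropping callback-owned ones on request.

    named_params is any iterable of (qualified_name, param) pairs, such as
    the output of torch.nn.Module.named_parameters().
    """
    for name, param in named_params:
        if not with_callbacks and name.startswith(CALLBACK_PREFIXES):
            continue  # skip probe/metric internals
        yield name, param
```

Filtering by qualified-name prefix keeps the override cheap: it needs no knowledge of which callback owns a parameter, only where the callback containers are registered.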

on_train_start()

Validate and log the optimizer configuration at the start of training.

Runs once before the first training step. Fills in any missing per-optimizer metadata (gradient clip value, clip algorithm, step frequency) by falling back to the Trainer’s global settings. Logs a summary table of optimizer index, name, class, clip value, and clip algorithm so misconfigured setups are caught early rather than silently misbehaving mid-run.
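The fallback logic can be sketched as a dict merge in which per-optimizer settings win and the Trainer's globals fill the gaps. fill_optimizer_defaults and the keys "clip_value", "clip_algorithm" and "frequency" are hypothetical names, and a missing frequency defaulting to 1 is an assumption; "norm" and "value" are the clip algorithms Lightning's Trainer accepts.

```python
def fill_optimizer_defaults(optim_meta, trainer_clip_val, trainer_clip_algorithm):
    """Fill missing per-optimizer settings from the Trainer's globals.

    Hypothetical helper: optim_meta maps optimizer name -> partial settings
    dict. Per-optimizer values take precedence over Trainer-wide values.
    """
    filled = {}
    for name, meta in optim_meta.items():
        filled[name] = {
            "clip_value": meta.get("clip_value", trainer_clip_val),
            "clip_algorithm": meta.get("clip_algorithm", trainer_clip_algorithm),
            "frequency": meta.get("frequency", 1),  # assumed default
        }
    return filled
```

Resolving these values once, before the first step, is what makes the logged summary table an accurate record of what each optimizer will actually use.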

parameters(with_callbacks=True, recurse: bool = True)

Override to route through the filtered named_parameters implementation.

Parameters:
  • with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.

  • recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.

Yields:

torch.nn.Parameter – Module parameters.

predict_step(batch, batch_idx)

Run the forward pass for a single prediction batch.

Passes stage="predict" to forward so forward functions can skip loss computation and return only inference outputs (e.g., embeddings). Called by Trainer.predict(), for example for large-scale feature extraction on unlabeled data.

Parameters:
  • batch – Input batch dict from the prediction dataloader.

  • batch_idx – Index of the current batch within the epoch.

Returns:

Output dict returned by forward.

Return type:

dict

rescale_loss_for_grad_acc(loss)

Scale loss down by the gradient accumulation factor before manual_backward.

When Trainer(accumulate_grad_batches=N) is set, gradients from N consecutive batches are summed before each optimizer step. Dividing each loss by N makes the accumulated gradient equal to the gradient of the average loss over those N batches (for equal-size batches), so its magnitude matches that of a single large batch and the effective learning rate does not grow with N.

Parameters:

loss – The raw loss tensor returned by forward.

Returns:

loss / accumulate_grad_batches.

Return type:

torch.Tensor
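The equivalence can be checked numerically on a one-parameter linear model with a mean-squared-error loss, using exact fractions instead of tensors. The model, data and grad_mse helper are illustrative choices, not part of the library; the point is that summing gradients of loss / N over N equal-size micro-batches reproduces the full-batch gradient, while omitting the 1/N factor inflates it N-fold.

```python
from fractions import Fraction


def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) for the 1-parameter model y_hat = w*x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / Fraction(len(xs))


w = Fraction(1, 2)
xs, ys = [1, 2, 3, 4], [2, 3, 5, 8]
N = 2  # plays the role of accumulate_grad_batches

# Accumulate: backward on loss / N for each of the N equal-size micro-batches.
micro = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]
accumulated = sum(grad_mse(w, bx, by) / N for bx, by in micro)

# Reference: one backward on the full batch of 4 samples.
full = grad_mse(w, xs, ys)

# Without the 1/N factor the accumulated gradient is N times larger.
unscaled = sum(grad_mse(w, bx, by) for bx, by in micro)
```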

test_step(batch, batch_idx)

Run the forward pass for a single test batch.

Mirrors validation_step but passes stage="test" to forward, allowing forward functions to distinguish test-time behaviour if needed. The returned dict is forwarded to Lightning’s on_test_batch_end callback hooks.

Parameters:
  • batch – Input batch dict from the test dataloader.

  • batch_idx – Index of the current batch within the epoch.

Returns:

Output dict returned by forward.

Return type:

dict

training_step(batch, batch_idx)

Run one training step with manual optimization across all configured optimizers.

Calls forward(batch, stage="fit") to obtain a state dict, then performs a single manual_backward on state["loss"]. Each optimizer steps only when its frequency boundary is met ((batch_idx + 1) % frequency == 0). Gradient clipping is applied per-optimizer using either the per-optimizer override or the Trainer’s gradient_clip_val. Learning rate is logged as hparams/lr_{name} after each step. zero_grad is called only on optimizers that actually stepped this iteration.

Parameters:
  • batch – Input batch dict from the training dataloader. Must be a dict — a non-dict batch raises ValueError.

  • batch_idx – Index of the current batch within the epoch. Injected into the batch dict as batch["batch_idx"] before forwarding.

Returns:

The state dict returned by forward, passed unchanged to Lightning’s callback hooks (on_train_batch_end).

Return type:

dict
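The control flow described above can be sketched with stand-in objects that record calls instead of touching tensors. This simplification assumes accumulate_grad_batches=1 and omits the backward pass and gradient clipping (marked by comments); RecordingOptimizer and training_step_sketch are hypothetical names, and only the param_groups[0]["lr"] layout mirrors real torch.optim optimizers.

```python
class RecordingOptimizer:
    """Stand-in optimizer that records calls instead of updating weights."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]  # same layout as torch.optim optimizers
        self.calls = []

    def step(self):
        self.calls.append("step")

    def zero_grad(self):
        self.calls.append("zero_grad")


def training_step_sketch(batch, batch_idx, forward, optimizers, frequencies, log):
    """Simplified control flow of one manual-optimization step (hypothetical)."""
    if not isinstance(batch, dict):
        raise ValueError("training batch must be a dict")
    batch["batch_idx"] = batch_idx
    state = forward(batch, stage="fit")
    # self.manual_backward(state["loss"]) would run here, once, for all optimizers.
    for name, opt in optimizers.items():
        if (batch_idx + 1) % frequencies[name] != 0:
            continue  # off this optimizer's boundary: keep its gradients
        # per-optimizer or Trainer-wide gradient clipping would run here
        opt.step()
        log(f"hparams/lr_{name}", opt.param_groups[0]["lr"])
        opt.zero_grad()  # only optimizers that stepped are zeroed
    return state
```

Zeroing only the optimizers that stepped lets a lower-frequency optimizer keep accumulating gradients across batches until its own boundary is reached.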

validation_step(batch, batch_idx)

Run the forward pass for a single validation batch.

Calls forward(batch, stage="validate") with gradients disabled (Lightning handles torch.no_grad()). The returned dict is passed to every registered callback via on_validation_batch_end, making all keys — including "embedding" and "label" — available to OnlineProbe, OnlineKNN, RankMe, and similar evaluation callbacks without any extra wiring.

Parameters:
  • batch – Input batch dict from the validation dataloader.

  • batch_idx – Index of the current batch within the epoch.

Returns:

Output dict returned by forward.

Return type:

dict
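Because the whole state dict reaches on_validation_batch_end as outputs, a callback can read "embedding" and "label" directly. The sketch below is a hypothetical collector, not one of the library's callbacks; its hook methods use the argument lists of Lightning's Callback.on_validation_epoch_start and on_validation_batch_end, and it is written as a plain class so it runs without Lightning installed.

```python
class EmbeddingCollector:
    """Minimal callback-style collector (hypothetical, not a library class).

    `outputs` is the dict returned by validation_step, i.e. whatever
    forward returned for that batch.
    """

    def __init__(self):
        self.embeddings, self.labels = [], []

    def on_validation_epoch_start(self, trainer, pl_module):
        # Start each validation epoch with an empty buffer.
        self.embeddings.clear()
        self.labels.clear()

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        self.embeddings.extend(outputs["embedding"])
        self.labels.extend(outputs["label"])
```

An evaluator such as a k-NN or linear probe could then be fitted on the collected embeddings and labels at the end of the epoch.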