Module
- class stable_pretraining.module.Module(*args, forward: callable = None, hparams: dict = None, **kwargs)
Bases: LightningModule
PyTorch Lightning module using manual optimization with multi-optimizer support.
Core usage
Provide a custom forward(self, batch, stage) via the forward argument at init. During training, forward must return a dict with state["loss"] (a single joint loss). When multiple optimizers are configured, this joint loss is used for all optimizers.
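A minimal sketch of a function satisfying this contract. The batch keys "image" and "label" and the attributes backbone and loss_fn are hypothetical names chosen for illustration, and a SimpleNamespace stands in for the module so the sketch runs without PyTorch. The only documented requirement it exercises is that the returned dict carries "loss" when stage is "fit".

```python
from types import SimpleNamespace


def forward(self, batch, stage):
    """Hypothetical forward(self, batch, stage) following the documented contract.

    `self.backbone`, `self.loss_fn` and the batch keys are illustrative names,
    not part of the stable_pretraining API.
    """
    out = {"embedding": self.backbone(batch["image"])}
    if "label" in batch:
        out["label"] = batch["label"]
    # During training the dict must carry a single joint loss under "loss".
    if stage == "fit":
        out["loss"] = self.loss_fn(out["embedding"], batch["label"])
    return out


# Toy stand-in for the module: doubles inputs and uses a mean squared error.
toy = SimpleNamespace(
    backbone=lambda xs: [2 * x for x in xs],
    loss_fn=lambda emb, ys: sum((e - y) ** 2 for e, y in zip(emb, ys)) / len(ys),
)
```

Called with stage="fit" the result includes "loss"; with stage="predict" it returns only the embedding, which matches how predict_step is described further down.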
Optimizer configuration (self.optim)
Single optimizer:
{"optimizer": str|dict|partial|Class, "scheduler": <see below>, "interval": "step"|"epoch", "frequency": int}
Optimizer accepted forms:
- string name (e.g., "AdamW", "SGD") from torch.optim
- dict: {"type": "AdamW", "lr": 1e-3, ...}
- functools.partial: partial(torch.optim.AdamW, lr=1e-3)
- optimizer class: torch.optim.AdamW
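One way the four accepted forms could be normalized into a single factory call is sketched below. build_optimizer is a hypothetical helper, not the library's code, and a SimpleNamespace holding a FakeAdamW class plays the role of torch.optim so the example runs without PyTorch.

```python
import functools
import inspect
from types import SimpleNamespace


def build_optimizer(spec, params, registry):
    """Hypothetical helper turning the four accepted forms into an optimizer.

    `registry` plays the role of torch.optim (name -> class lookup).
    """
    if isinstance(spec, str):                      # "AdamW"
        return getattr(registry, spec)(params)
    if isinstance(spec, dict):                     # {"type": "AdamW", "lr": 1e-3}
        kwargs = dict(spec)                        # copy so the config is not mutated
        cls = getattr(registry, kwargs.pop("type"))
        return cls(params, **kwargs)
    if isinstance(spec, functools.partial):        # partial(AdamW, lr=1e-3)
        return spec(params)
    if inspect.isclass(spec):                      # AdamW
        return spec(params)
    raise TypeError(f"Unsupported optimizer spec: {spec!r}")


class FakeAdamW:
    """Stand-in for torch.optim.AdamW so the sketch runs without PyTorch."""

    def __init__(self, params, lr=1e-3, weight_decay=0.01):
        self.params, self.lr, self.weight_decay = list(params), lr, weight_decay


fake_optim = SimpleNamespace(AdamW=FakeAdamW)
```

The string and class forms rely entirely on the optimizer's own defaults, so the dict and partial forms are the ones to reach for when hyperparameters matter.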
Multiple optimizers:
{
    name: {
        "modules": "regex",                   # assign params by module-name pattern (children inherit)
        "optimizer": str|dict|partial|Class,  # optimizer factory (same accepted forms as above)
        "scheduler": str|dict|partial|Class,  # flexible scheduler config (see below)
        "interval": "step"|"epoch",           # scheduler interval
        "frequency": int,                     # optimizer step frequency
        "monitor": str                        # (optional) for ReduceLROnPlateau
    },
    ...
}
Parameter assignment (multi-optimizer)
Modules are matched by regex on their qualified name. Children inherit the parent’s assignment unless they match a more specific pattern. Only direct parameters of each module are collected to avoid duplication.
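The assignment rule can be sketched over plain qualified names. assign_modules is a hypothetical helper; the ordering of names (parents before children), the use of re.fullmatch, and "first matching pattern wins" are assumptions rather than documented behaviour.

```python
import re


def assign_modules(module_names, patterns):
    """Map each qualified module name to an optimizer name (or None).

    Assumptions: `module_names` lists parents before children, a module is
    "matched" when re.fullmatch succeeds, and the first matching pattern in
    config order wins.
    """
    assignment = {}
    for name in module_names:
        matched = next(
            (opt for opt, pat in patterns.items() if re.fullmatch(pat, name)), None
        )
        if matched is None and "." in name:
            # No own match: inherit from the parent module ("a.b.c" -> "a.b").
            matched = assignment.get(name.rsplit(".", 1)[0])
        assignment[name] = matched
    return assignment
```

Fed the module structure from the configure_optimizers example, this reproduces that example's mapping. In PyTorch, collecting only direct parameters per module corresponds to calling module.parameters(recurse=False).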
Schedulers (flexible)
Accepted forms: string name (e.g., "CosineAnnealingLR", "StepLR"), dict with {"type": "...", ...}, functools.partial, or a scheduler class. Smart defaults are applied when params are omitted for common schedulers (CosineAnnealingLR, OneCycleLR, StepLR, ExponentialLR, ReduceLROnPlateau, LinearLR, ConstantLR). For ReduceLROnPlateau, a monitor key is added (default: "val_loss"). You may specify monitor either alongside the optimizer config (top level) or inside the scheduler dict itself. The resulting Lightning scheduler dict includes interval and frequency (or scheduler_frequency).
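A minimal sketch of how the monitor could be resolved and the Lightning scheduler dict assembled. Both helpers are hypothetical; checking the scheduler dict before the top-level key, the fallback values for interval and frequency, and the scheduler_frequency lookup order are all assumptions. The keys "scheduler", "interval", "frequency" and "monitor" are the ones Lightning reads from an lr_scheduler config.

```python
def resolve_monitor(opt_cfg, default="val_loss"):
    """Pick the metric a ReduceLROnPlateau scheduler should watch.

    Assumed precedence: scheduler dict, then top-level key, then "val_loss".
    """
    sched = opt_cfg.get("scheduler")
    if isinstance(sched, dict) and "monitor" in sched:
        return sched["monitor"]
    return opt_cfg.get("monitor", default)


def lightning_scheduler_dict(scheduler, opt_cfg):
    """Assemble the dict Lightning expects under "lr_scheduler" (hypothetical)."""
    out = {
        "scheduler": scheduler,                        # an already-built scheduler
        "interval": opt_cfg.get("interval", "step"),   # assumed default
        # Docs mention frequency or scheduler_frequency; this order is assumed.
        "frequency": opt_cfg.get("scheduler_frequency", opt_cfg.get("frequency", 1)),
    }
    sched = opt_cfg.get("scheduler")
    name = sched.get("type") if isinstance(sched, dict) else sched
    if name == "ReduceLROnPlateau":
        out["monitor"] = resolve_monitor(opt_cfg)
    return out
```

Only string and dict scheduler specs are inspected for their type here; a class or partial spec would need its own name lookup.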
Training loop behavior
- Manual optimization (automatic_optimization = False).
- Gradient accumulation: scales the loss by 1/N, where N = Trainer.accumulate_grad_batches, and steps on the accumulation boundary.
- Per-optimizer step frequency: each optimizer steps only when its frequency boundary is met (in addition to the accumulation boundary).
- Gradient clipping: uses the Trainer’s gradient_clip_val and gradient_clip_algorithm before each step.
- Returns the state dict from forward unchanged for logging/inspection.
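The two stepping conditions can be simulated without a Trainer. Both helpers are hypothetical, and counting both boundaries from batch_idx + 1 (as training_step does for frequency) is an assumption for the accumulation boundary.

```python
def optimizer_steps(batch_idx, accumulate_grad_batches, frequency):
    """True if an optimizer should step after this batch (sketch)."""
    on_accum_boundary = (batch_idx + 1) % accumulate_grad_batches == 0
    on_freq_boundary = (batch_idx + 1) % frequency == 0
    return on_accum_boundary and on_freq_boundary


def step_schedule(num_batches, accumulate_grad_batches, frequencies):
    """Batch indices at which each named optimizer steps."""
    return {
        name: [
            i for i in range(num_batches)
            if optimizer_steps(i, accumulate_grad_batches, freq)
        ]
        for name, freq in frequencies.items()
    }
```

With accumulate_grad_batches=2 over 12 batches, an optimizer with frequency 1 steps after every second batch, while one with frequency 3 steps only when both boundaries coincide (batch indices 5 and 11).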
- after_manual_backward()
Hook called immediately after manual_backward in training_step.
Override in a subclass to insert logic that must run after gradients are computed but before any optimizer step or zero_grad, for example gradient norm logging, custom gradient clipping, or EMA teacher weight updates that depend on the current gradient. The default implementation does nothing.
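The toy below is not stable_pretraining code: it only mirrors the documented order (manual_backward, then after_manual_backward, then clipping, stepping and zero_grad) with pretend gradients, so the override point is visible in isolation. ToyManualLoop and GradNormLogger are hypothetical names.

```python
class ToyManualLoop:
    """Toy stand-in recording the documented order of one training step."""

    def __init__(self):
        self.events = []
        self.grads = []

    def manual_backward(self, loss):
        self.events.append("backward")
        self.grads = [0.3, -0.4]  # pretend gradients produced by backward

    def after_manual_backward(self):
        pass  # default: do nothing

    def training_step(self, loss):
        self.manual_backward(loss)
        self.after_manual_backward()
        self.events += ["clip", "step", "zero_grad"]


class GradNormLogger(ToyManualLoop):
    """Override the hook to read gradients before they are clipped or cleared."""

    def after_manual_backward(self):
        norm = sum(g * g for g in self.grads) ** 0.5
        self.events.append(f"grad_norm={norm:.2f}")
```

In a real subclass the gradients would come from p.grad for each p in self.parameters(); the point here is only that the hook sees them before clipping changes or zero_grad clears them.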
- configure_optimizers()
Configure optimizers and schedulers for manual optimization.
- Returns:
Optimizer configuration with optional learning rate scheduler. For a single optimizer, returns a dict with optimizer and lr_scheduler. For multiple optimizers, returns a tuple of (optimizers, schedulers).
Example
Multi-optimizer configuration with module pattern matching and schedulers:
>>> # Simple single optimizer with scheduler
>>> self.optim = {
...     "optimizer": partial(torch.optim.AdamW, lr=1e-3),
...     "scheduler": "CosineAnnealingLR",  # Uses smart defaults
...     "interval": "step",
...     "frequency": 1,
... }
>>> # Multi-optimizer with custom scheduler configs
>>> self.optim = {
...     "encoder_opt": {
...         "modules": "encoder",  # Matches 'encoder' and all children
...         "optimizer": {"type": "AdamW", "lr": 1e-3},
...         "scheduler": {
...             "type": "OneCycleLR",
...             "max_lr": 1e-3,
...             "total_steps": 10000,
...         },
...         "interval": "step",
...         "frequency": 1,
...     },
...     "head_opt": {
...         "modules": ".*head$",  # Matches modules ending with 'head'
...         "optimizer": "SGD",
...         "scheduler": {
...             "type": "ReduceLROnPlateau",
...             "mode": "max",
...             "patience": 5,
...             "factor": 0.5,
...         },
...         "monitor": "val_accuracy",  # Metric for ReduceLROnPlateau (defaults to "val_loss" if omitted)
...         "interval": "epoch",
...         "frequency": 2,
...     },
... }
With model structure:
- encoder -> encoder_opt (matches “encoder”)
- encoder.layer1 -> encoder_opt (inherits from parent)
- encoder.layer1.conv -> encoder_opt (inherits from encoder.layer1)
- classifier_head -> head_opt (matches “.*head$”)
- classifier_head.linear -> head_opt (inherits from parent)
- decoder -> None (no match, no parameters collected)
- forward(*args, **kwargs)
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- named_parameters(with_callbacks=True, prefix: str = '', recurse: bool = True)
Override to globally exclude callback-related parameters.
Excludes parameters that belong to self.callbacks_modules or self.callbacks_metrics. This prevents accidental optimization of callback/metric internals, even if external code calls self.parameters() or self.named_parameters() directly.
- Parameters:
with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.
prefix (str, optional) – Prefix to prepend to parameter names. Defaults to “”.
recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.
- Yields:
tuple[str, torch.nn.Parameter] – Name and parameter pairs.
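The filtering idea can be sketched on plain (name, value) pairs. filter_named_parameters is a hypothetical helper; it assumes callback parameters appear under qualified names starting with callbacks_modules. or callbacks_metrics. (i.e. prefix='') and follows the with_callbacks=False semantics described in the parameter list above.

```python
CALLBACK_PREFIXES = ("callbacks_modules.", "callbacks_metrics.")


def filter_named_parameters(named_params, with_callbacks=True):
    """Yield (name, param) pairs, dropping callback-owned ones if requested.

    Assumes callback parameters are registered under the attributes
    self.callbacks_modules / self.callbacks_metrics.
    """
    for name, param in named_params:
        if not with_callbacks and name.startswith(CALLBACK_PREFIXES):
            continue  # skip probe/metric internals
        yield name, param
```

Because parameters() is documented as routing through named_parameters, filtering in one place is enough to keep callback internals out of every optimizer built from it.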
- on_train_start()
Validate and log the optimizer configuration at the start of training.
Runs once before the first training step. Fills in any missing per-optimizer metadata (gradient clip value, clip algorithm, step frequency) by falling back to the Trainer’s global settings. Logs a summary table of optimizer index, name, class, clip value, and clip algorithm so misconfigured setups are caught early rather than silently misbehaving mid-run.
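A sketch of the fallback-and-summarize step. The helpers and the per-optimizer key names (gradient_clip_val, gradient_clip_algorithm, frequency) are hypothetical, and the fallback frequency of 1 is an assumption; "norm" and "value" are the clip algorithms Lightning accepts.

```python
def fill_optimizer_defaults(opt_configs, trainer_clip_val, trainer_clip_algorithm):
    """Fill missing per-optimizer metadata from Trainer-wide settings (sketch)."""
    rows = []
    for idx, (name, cfg) in enumerate(opt_configs.items()):
        cfg.setdefault("gradient_clip_val", trainer_clip_val)
        cfg.setdefault("gradient_clip_algorithm", trainer_clip_algorithm)
        cfg.setdefault("frequency", 1)  # assumed fallback
        rows.append((idx, name, cfg["optimizer"], cfg["gradient_clip_val"],
                     cfg["gradient_clip_algorithm"]))
    return rows


def format_table(rows):
    """Render the summary as a simple pipe-separated table."""
    header = ("idx", "name", "class", "clip_val", "clip_algo")
    lines = [" | ".join(header)]
    lines += [" | ".join(str(v) for v in row) for row in rows]
    return "\n".join(lines)
```

Per-optimizer overrides win because setdefault only fills keys that are missing.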
- parameters(with_callbacks=True, recurse: bool = True)
Override to route through the filtered named_parameters implementation.
- predict_step(batch, batch_idx)
Run the forward pass for a single prediction batch.
Passes stage="predict" to forward so forward functions can omit loss computation and return only inference outputs (e.g., embeddings). Used by Trainer.predict() for large-scale feature extraction without a label set.
- Parameters:
batch – Input batch dict from the prediction dataloader.
batch_idx – Index of the current batch within the epoch.
- Returns:
Output dict returned by forward.
- rescale_loss_for_grad_acc(loss)
Scale loss down by the gradient accumulation factor before manual_backward.
When Trainer(accumulate_grad_batches=N) is set, gradients from N consecutive steps are summed before an optimizer step. Dividing each loss by N makes the accumulated gradient equal to the gradient of the mean loss over the N micro-batches (with equal micro-batch sizes, the same as one batch N times larger), so the effective learning rate does not grow with N.
- Parameters:
loss – The raw loss tensor returned by forward.
- Returns:
loss / accumulate_grad_batches.
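A small numeric check of this claim, using a one-parameter linear model with a mean squared error so gradients can be written by hand. Since the gradient is linear in the loss, dividing each micro-batch loss by N divides its gradient by N. The model, data and N = 2 are arbitrary choices for illustration.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)**2) over a batch, for a 1-parameter linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 5.0, 7.0]
N = 2  # plays the role of accumulate_grad_batches

# Gradient of the mean loss over all 4 samples at once.
full_batch = grad_mse(w, xs, ys)

# Two micro-batches of 2 samples; each gradient is scaled by 1/N, then summed.
accumulated = sum(
    grad_mse(w, xs[i:i + 2], ys[i:i + 2]) / N for i in range(0, len(xs), 2)
)

# Without rescaling the summed gradient is N times too large.
unscaled = sum(grad_mse(w, xs[i:i + 2], ys[i:i + 2]) for i in range(0, len(xs), 2))
```

Here full_batch and accumulated are both -18.0, while unscaled is -36.0.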
- test_step(batch, batch_idx)
Run the forward pass for a single test batch.
Mirrors validation_step but passes stage="test" to forward, allowing forward functions to distinguish test-time behaviour if needed. The returned dict is forwarded to Lightning’s on_test_batch_end callback hooks.
- Parameters:
batch – Input batch dict from the test dataloader.
batch_idx – Index of the current batch within the epoch.
- Returns:
Output dict returned by forward.
- training_step(batch, batch_idx)
Run one training step with manual optimization across all configured optimizers.
Calls forward(batch, stage="fit") to obtain a state dict, then performs a single manual_backward on state["loss"]. Each optimizer steps only when its frequency boundary is met ((batch_idx + 1) % frequency == 0). Gradient clipping is applied per-optimizer using either the per-optimizer override or the Trainer’s gradient_clip_val. Learning rate is logged as hparams/lr_{name} after each step. zero_grad is called only on optimizers that actually stepped this iteration.
- Parameters:
batch – Input batch dict from the training dataloader. Must be a dict; a non-dict batch raises ValueError.
batch_idx – Index of the current batch within the epoch. Injected into the batch dict as batch["batch_idx"] before forwarding.
- Returns:
The state dict returned by forward, passed unchanged to Lightning’s callback hooks (on_train_batch_end).
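Because a non-dict batch raises ValueError, a dataloader that yields (input, target) tuples needs a collate step producing a dict. Below, dict_collate and prepare_training_batch are hypothetical helpers, and the "image"/"label" keys are illustrative; the second function only mirrors the documented type check and batch_idx injection.

```python
def dict_collate(samples):
    """Hypothetical collate_fn turning (x, y) pairs into a dict batch.

    Uses plain lists; a real version would stack tensors (e.g. torch.stack).
    """
    xs, ys = zip(*samples)
    return {"image": list(xs), "label": list(ys)}


def prepare_training_batch(batch, batch_idx):
    """Mirror of the documented checks performed before forward(batch, stage="fit")."""
    if not isinstance(batch, dict):
        raise ValueError(f"Expected a dict batch, got {type(batch).__name__}")
    batch["batch_idx"] = batch_idx  # injected in place before forwarding
    return batch
```

Passing dict_collate as a DataLoader's collate_fn would satisfy the dict requirement; the key names must simply match whatever the forward function reads.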
- validation_step(batch, batch_idx)
Run the forward pass for a single validation batch.
Calls forward(batch, stage="validate") with gradients disabled (Lightning handles torch.no_grad()). The returned dict is passed to every registered callback via on_validation_batch_end, making all keys, including "embedding" and "label", available to OnlineProbe, OnlineKNN, RankMe, and similar evaluation callbacks without any extra wiring.
- Parameters:
batch – Input batch dict from the validation dataloader.
batch_idx – Index of the current batch within the epoch.
- Returns:
Output dict returned by forward.