Module
- class stable_pretraining.module.Module(*args, forward: callable = None, hparams: dict = None, parallelize_fn: callable = None, **kwargs)
Bases: LightningModule
PyTorch Lightning module using manual optimization with multi-optimizer support.
Core usage
Provide a custom forward(self, batch, stage) via the forward argument at init. During training, forward must return a dict with state["loss"] (a single joint loss). When multiple optimizers are configured, this joint loss is used for all optimizers.
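Here is a minimal sketch of a function that could be passed as forward. It is pure Python so it runs without PyTorch. The backbone attribute, the "image"/"label" batch keys, and the toy loss are hypothetical, and a real forward would operate on tensors. The stage strings ("fit", "predict") are the ones the step methods below pass in.

```python
from types import SimpleNamespace


def forward(self, batch, stage):
    """Hypothetical forward: embed the inputs, add a loss only when training."""
    # self.backbone is a hypothetical attribute; any callable works here
    embedding = [self.backbone(x) for x in batch["image"]]
    state = {"embedding": embedding, "label": batch.get("label")}
    if stage == "fit":
        # Toy joint loss (mean squared embedding) standing in for a real SSL loss
        state["loss"] = sum(e * e for e in embedding) / len(embedding)
    return state


# SimpleNamespace stands in for the Module instance passed as self
fake_self = SimpleNamespace(backbone=lambda x: 2.0 * x)
train_state = forward(fake_self, {"image": [1.0, 2.0], "label": [0, 1]}, stage="fit")
predict_state = forward(fake_self, {"image": [1.0, 2.0]}, stage="predict")
print(train_state["loss"])  # 10.0
```

Branching on stage lets one function serve training, validation, test and prediction; only the "fit" branch has to produce state["loss"].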
Optimizer configuration (self.optim)
Single optimizer:
{"optimizer": str|dict|partial|Class, "scheduler": <see below>, "interval": "step"|"epoch", "frequency": int}
Optimizer accepted forms:
- string name from torch.optim (e.g., "AdamW", "SGD")
- dict: {"type": "AdamW", "lr": 1e-3, ...}
- functools.partial: partial(torch.optim.AdamW, lr=1e-3)
- optimizer class: torch.optim.AdamW
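All four forms can be normalized to one builder that takes the parameters. The sketch below shows one way to do that; it is not the library's code. FakeAdamW and the REGISTRY dict are hypothetical stand-ins for torch.optim.AdamW and getattr(torch.optim, name), so the example runs without PyTorch.

```python
from functools import partial


class FakeAdamW:
    """Hypothetical stand-in for torch.optim.AdamW: it just records its arguments."""

    def __init__(self, params, lr=1e-3, weight_decay=0.01):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


# Hypothetical name -> class lookup, standing in for getattr(torch.optim, name)
REGISTRY = {"AdamW": FakeAdamW}


def build_optimizer(spec, params):
    """Turn a str / dict / partial / class spec into an optimizer instance."""
    if isinstance(spec, str):
        return REGISTRY[spec](params)
    if isinstance(spec, dict):
        kwargs = dict(spec)  # copy so the caller's config is not mutated
        cls = REGISTRY[kwargs.pop("type")]
        return cls(params, **kwargs)
    if isinstance(spec, (partial, type)):
        # A partial and a class are both called with the parameters directly
        return spec(params)
    raise TypeError(f"Unsupported optimizer spec: {spec!r}")


params = [0.5, -0.5]
opts = [
    build_optimizer("AdamW", params),
    build_optimizer({"type": "AdamW", "lr": 1e-3}, params),
    build_optimizer(partial(FakeAdamW, lr=1e-3), params),
    build_optimizer(FakeAdamW, params),
]
```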
Multiple optimizers:
{
    name: {
        "modules": "regex",                    # assign params by module-name pattern (children inherit)
        "optimizer": str|dict|partial|Class,   # optimizer factory (same accepted forms as above)
        "scheduler": str|dict|partial|Class,   # flexible scheduler config (see below)
        "interval": "step"|"epoch",            # scheduler interval
        "frequency": int,                      # optimizer step frequency
        "monitor": str,                        # (optional) for ReduceLROnPlateau
    },
    ...
}
Parameter assignment (multi-optimizer)
Modules are matched by regex on their qualified name. Children inherit the parent’s assignment unless they match a more specific pattern. Only direct parameters of each module are collected to avoid duplication.
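The assignment rule can be sketched in pure Python over qualified module names. This is an illustration, not the library's code. It assumes re.fullmatch semantics (the library may use re.match or re.search) and that the first matching pattern wins when several match. The module names mirror the configure_optimizers example further down.

```python
import re


def assign_modules(module_names, patterns):
    """Map each qualified module name to an optimizer name (or None).

    A module that matches a pattern takes that optimizer; otherwise it
    inherits the assignment of its parent ("a.b.c" -> "a.b").
    """
    assignment = {}
    # Sorting by depth resolves every parent before its children
    for name in sorted(module_names, key=lambda n: n.count(".")):
        matched = next(
            (opt for opt, pattern in patterns.items() if re.fullmatch(pattern, name)),
            None,
        )
        if matched is None and "." in name:
            matched = assignment.get(name.rsplit(".", 1)[0])
        assignment[name] = matched
    return assignment


def collect_params(direct_params, assignment):
    """Group only each module's *direct* parameters so nothing is counted twice."""
    groups = {}
    for module, params in direct_params.items():
        opt = assignment.get(module)
        if opt is not None:
            groups.setdefault(opt, []).extend(params)
    return groups


patterns = {"encoder_opt": "encoder", "head_opt": ".*head$"}
direct_params = {
    "encoder": [],
    "encoder.layer1": [],
    "encoder.layer1.conv": ["conv.weight", "conv.bias"],
    "classifier_head": [],
    "classifier_head.linear": ["linear.weight"],
    "decoder": ["decoder.weight"],
}
assignment = assign_modules(direct_params, patterns)
groups = collect_params(direct_params, assignment)
```

Because each module contributes only its own direct parameters, a parameter can never land in two groups even though parents and children share an optimizer.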
Schedulers (flexible)
Accepted forms: string name (e.g., "CosineAnnealingLR", "StepLR"), dict with {"type": "...", ...}, functools.partial, or a scheduler class. Smart defaults are applied when params are omitted for common schedulers (CosineAnnealingLR, OneCycleLR, StepLR, ExponentialLR, ReduceLROnPlateau, LinearLR, ConstantLR). For ReduceLROnPlateau, a monitor key is added (default: "val_loss"). You may specify monitor either alongside the optimizer config (top level) or inside the scheduler dict itself.
The resulting Lightning scheduler dict includes interval and frequency (or scheduler_frequency).
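The monitor handling can be sketched as follows. Several details are assumptions, not documented behavior: a monitor inside the scheduler dict takes precedence over the top-level one, scheduler_frequency overrides frequency when present, and missing interval/frequency fall back to Lightning's own defaults ("epoch", 1). Only the string and dict scheduler forms are inspected here.

```python
def scheduler_entry(opt_config, scheduler):
    """Assemble a Lightning-style lr_scheduler dict for one optimizer.

    `scheduler` is an already-built scheduler object; `opt_config` is the
    per-optimizer dict ("scheduler", "interval", "frequency", "monitor").
    """
    entry = {
        "scheduler": scheduler,
        # Lightning's own defaults; the library's defaults may differ
        "interval": opt_config.get("interval", "epoch"),
        # Assumed: scheduler_frequency, if present, overrides frequency
        "frequency": opt_config.get("scheduler_frequency", opt_config.get("frequency", 1)),
    }
    spec = opt_config.get("scheduler")
    name = spec.get("type") if isinstance(spec, dict) else spec
    if name == "ReduceLROnPlateau":
        # Assumed precedence: scheduler dict, then top level, then "val_loss"
        inner = spec.get("monitor") if isinstance(spec, dict) else None
        entry["monitor"] = inner or opt_config.get("monitor", "val_loss")
    return entry


plateau = scheduler_entry(
    {
        "scheduler": {"type": "ReduceLROnPlateau", "mode": "max"},
        "monitor": "val_accuracy",
        "interval": "epoch",
        "frequency": 2,
    },
    scheduler="sched",
)
defaulted = scheduler_entry({"scheduler": "ReduceLROnPlateau"}, scheduler="sched")
cosine = scheduler_entry({"scheduler": "CosineAnnealingLR", "interval": "step"}, scheduler="sched")
```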
Training loop behavior
- Manual optimization (automatic_optimization = False).
- Gradient accumulation: scales the loss by 1/N, where N = Trainer.accumulate_grad_batches, and steps on the accumulation boundary.
- Per-optimizer step frequency: each optimizer steps only when its frequency boundary is met, in addition to the accumulation boundary.
- Gradient clipping: before each step, applies the per-optimizer clip settings, falling back to the Trainer's gradient_clip_val and gradient_clip_algorithm.
- Returns the state dict from forward unchanged for logging/inspection.
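A pure-Python sketch of the stepping rule follows. It assumes both boundaries are counted in batches with the (batch_idx + 1) % frequency == 0 test that training_step documents. The helper names are hypothetical.

```python
def optimizers_to_step(batch_idx, accumulate_grad_batches, frequencies):
    """Return the names of optimizers that should step after this batch.

    An optimizer steps only when the accumulation boundary AND its own
    frequency boundary are both met.
    """
    n = batch_idx + 1  # 1-based count of batches seen so far
    if n % accumulate_grad_batches != 0:
        return []  # still accumulating gradients
    return [name for name, freq in frequencies.items() if n % freq == 0]


def scaled_loss(loss, accumulate_grad_batches):
    # Dividing by N makes the gradients summed over N micro-batches equal
    # the gradient of their mean loss
    return loss / accumulate_grad_batches


freqs = {"encoder_opt": 1, "head_opt": 4}
schedule = [optimizers_to_step(i, 2, freqs) for i in range(8)]
```

With accumulate_grad_batches=2, the encoder optimizer steps every 2 batches and the head optimizer (frequency 4) every 4 batches.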
- after_manual_backward()
Hook called immediately after manual_backward in training_step.
Override in a subclass to insert logic that must run after gradients are computed but before any optimizer step or zero_grad, for example gradient norm logging, custom gradient clipping, or EMA teacher weight updates that depend on the current gradient. The default implementation does nothing.
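The sketch below uses a hypothetical stand-in class (no PyTorch) that records the order of phases so the hook's position is visible: backward, then after_manual_backward, then the optimizer step and zero_grad. In a real subclass of stable_pretraining.module.Module, the override would read p.grad from self.parameters() instead of the toy grads dict.

```python
import math


class ManualStepSketch:
    """Hypothetical stand-in for the manual-optimization step order."""

    def __init__(self):
        self.events = []
        self.grads = {}

    def manual_backward(self, grads):
        self.grads = dict(grads)  # pretend backward produced these gradients
        self.events.append("backward")

    def after_manual_backward(self):
        pass  # default: do nothing

    def training_step(self, grads):
        self.manual_backward(grads)
        self.after_manual_backward()
        self.events.append("step")
        self.grads = {}  # stands in for zero_grad
        self.events.append("zero_grad")


class GradNormLogger(ManualStepSketch):
    def after_manual_backward(self):
        # Gradients still exist here: global L2 norm over all entries
        self.grad_norm = math.sqrt(sum(g * g for g in self.grads.values()))
        self.events.append("after_manual_backward")


m = GradNormLogger()
m.training_step({"w": 3.0, "b": 4.0})
print(m.grad_norm)  # 5.0
```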
- configure_model() -> None
Lightning hook for FSDP2: shard the model when a device mesh exists.
Under Trainer(strategy="fsdp2"), ModelParallelStrategy builds a device mesh and exposes it as self.device_mesh before calling this hook. We dispatch to the configured parallelize_fn (default: stable_pretraining.utils.fsdp2.default_parallelize_fn()) to apply fully_shard. Under any other strategy (single-device / DDP) there is no mesh and this is a no-op.
Runs before configure_optimizers, so optimizers are built over the sharded DTensor parameters, as FSDP2 requires.
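A minimal sketch of the dispatch follows. It assumes parallelize_fn is called as parallelize_fn(module, device_mesh); the real default_parallelize_fn signature is not shown here. The _sharded flag is a hypothetical guard: Lightning's documentation asks that configure_model be idempotent because the hook can run more than once in the same process.

```python
class ShardingSketch:
    """Hypothetical stand-in showing the configure_model dispatch."""

    def __init__(self, parallelize_fn, device_mesh=None):
        self.parallelize_fn = parallelize_fn
        self.device_mesh = device_mesh  # set by the strategy under FSDP2
        self._sharded = False  # hypothetical idempotency guard

    def configure_model(self):
        if self.device_mesh is None or self._sharded:
            return  # no mesh (single device / DDP) or already sharded: no-op
        self.parallelize_fn(self, self.device_mesh)
        self._sharded = True


calls = []


def fake_parallelize(module, mesh):
    # Hypothetical parallelize_fn that only records the mesh it received
    calls.append(mesh)


ShardingSketch(fake_parallelize).configure_model()  # no mesh: nothing happens
sharded = ShardingSketch(fake_parallelize, device_mesh="mesh")
sharded.configure_model()
sharded.configure_model()  # second call is a no-op
```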
- configure_optimizers()
Configure optimizers and schedulers for manual optimization.
- Returns:
Optimizer configuration with optional learning rate scheduler. For a single optimizer, returns a dict with optimizer and lr_scheduler; for multiple optimizers, returns a tuple of (optimizers, schedulers).
- Return type:
Example
Single- and multi-optimizer configurations with module pattern matching and schedulers:
>>> from functools import partial
>>> import torch
>>> # Simple single optimizer with scheduler
>>> self.optim = {
...     "optimizer": partial(torch.optim.AdamW, lr=1e-3),
...     "scheduler": "CosineAnnealingLR",  # Uses smart defaults
...     "interval": "step",
...     "frequency": 1,
... }
>>> # Multi-optimizer with custom scheduler configs
>>> self.optim = {
...     "encoder_opt": {
...         "modules": "encoder",  # Matches 'encoder' and all children
...         "optimizer": {"type": "AdamW", "lr": 1e-3},
...         "scheduler": {
...             "type": "OneCycleLR",
...             "max_lr": 1e-3,
...             "total_steps": 10000,
...         },
...         "interval": "step",
...         "frequency": 1,
...     },
...     "head_opt": {
...         "modules": ".*head$",  # Matches modules ending with 'head'
...         "optimizer": "SGD",
...         "scheduler": {
...             "type": "ReduceLROnPlateau",
...             "mode": "max",
...             "patience": 5,
...             "factor": 0.5,
...         },
...         "monitor": "val_accuracy",  # Overrides the default "val_loss" monitor
...         "interval": "epoch",
...         "frequency": 2,
...     },
... }
With this model structure:
- encoder -> encoder_opt (matches "encoder")
- encoder.layer1 -> encoder_opt (inherits from parent)
- encoder.layer1.conv -> encoder_opt (inherits from encoder.layer1)
- classifier_head -> head_opt (matches ".*head$")
- classifier_head.linear -> head_opt (inherits from parent)
- decoder -> None (no match, no parameters collected)
- forward(*args, **kwargs)
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- named_parameters(with_callbacks=True, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True)
Override to globally exclude callback-related parameters.
Excludes parameters that belong to self.callbacks_modules or self.callbacks_metrics. This prevents accidental optimization of callback/metric internals, even if external code calls self.parameters() or self.named_parameters() directly.
- Parameters:
with_callbacks (bool, optional) – If False, excludes callback parameters. Defaults to True.
prefix (str, optional) – Prefix to prepend to parameter names. Defaults to “”.
recurse (bool, optional) – If True, yields parameters of this module and all submodules. If False, yields only direct parameters. Defaults to True.
remove_duplicate (bool, optional) – Whether to deduplicate shared parameters. Must be accepted and forwarded because PyTorch's fully_shard (FSDP2) wrap path calls named_parameters(remove_duplicate=False); without it, this override would raise TypeError before sharding. Defaults to True.
- Yields:
tuple[str, torch.nn.Parameter] – Name and parameter pairs.
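The sketch below shows the filtering, the keyword forwarding, and parameters() routed through the filtered generator, all in pure Python. FakeModule is a hypothetical stand-in for torch.nn.Module. Matching on the name prefixes callbacks_modules. and callbacks_metrics. is an assumption; the real override may compare module objects instead.

```python
class FakeModule:
    """Hypothetical stand-in mimicking torch.nn.Module.named_parameters."""

    def __init__(self, params):
        self._params = params  # flat list of (qualified name, parameter object)

    def named_parameters(self, prefix="", recurse=True, remove_duplicate=True):
        # recurse is accepted for signature compatibility only (the list is flat)
        seen = set()
        for name, param in self._params:
            if remove_duplicate and id(param) in seen:
                continue  # shared parameter already yielded under another name
            seen.add(id(param))
            yield (f"{prefix}.{name}" if prefix else name), param


class FilteredModule(FakeModule):
    EXCLUDED = ("callbacks_modules.", "callbacks_metrics.")

    def named_parameters(self, with_callbacks=True, prefix="", recurse=True,
                         remove_duplicate=True):
        # Forward remove_duplicate: fully_shard calls named_parameters(remove_duplicate=False)
        pairs = super().named_parameters(prefix="", recurse=recurse,
                                         remove_duplicate=remove_duplicate)
        for name, param in pairs:
            if not with_callbacks and name.startswith(self.EXCLUDED):
                continue  # skip callback/metric internals
            # Apply the prefix after filtering so the name check stays simple
            yield (f"{prefix}.{name}" if prefix else name), param

    def parameters(self, with_callbacks=True, recurse=True):
        # Route through the filtered named_parameters
        for _, param in self.named_parameters(with_callbacks=with_callbacks, recurse=recurse):
            yield param


shared = [1.0]  # one object registered under two names (e.g. tied weights)
probe = [2.0]
m = FilteredModule([
    ("encoder.weight", shared),
    ("head.weight", shared),
    ("callbacks_modules.probe.weight", probe),
])
names = [n for n, _ in m.named_parameters(with_callbacks=False)]
```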
- on_after_batch_transfer(batch, dataloader_idx)
Apply the active GPU-side batch transform after Lightning moves the batch.
Resolution order (first match wins):
1. dataset.gpu_transform: set on the dataset behind the active DataLoader. Recommended: pair augmentation with the dataset so train/val/test/predict each carry their own spec naturally.
2. self.trainer.datamodule.gpu_transform: set on the DataModule. May be a callable or a {"train": ..., "val": ...} dict.
Setting self.gpu_transform on the Module is not supported and is rejected at on_train_start; attach it to the dataset (or to the DataModule for third-party datasets) instead.
Lazy device placement: when the resolved transform is an nn.Module (e.g. a GPUCompose), it is moved to self.device on first use. Buffers (e.g. GPUNormalize's mean/std) therefore live on the correct GPU under DDP without manual wiring.
When nothing resolves, this is a zero-cost passthrough.
- Parameters:
batch – Batch dict already moved to self.device by Lightning.
dataloader_idx – Index of the dataloader that produced this batch.
- Returns:
The (possibly augmented) batch dict.
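The resolution order can be sketched in pure Python. SimpleNamespace objects stand in for the dataset and DataModule, double_images is a hypothetical transform, and the stage keys ("train", "val") passed to the dict lookup are assumptions. The lazy nn.Module device move is omitted because it needs PyTorch.

```python
from types import SimpleNamespace


def resolve_gpu_transform(dataset, datamodule, stage):
    """Return the transform for `stage`, or None; the dataset wins over the DataModule."""
    # 1. Transform attached to the dataset behind the active DataLoader
    transform = getattr(dataset, "gpu_transform", None)
    if transform is not None:
        return transform
    # 2. Transform on the DataModule: either a callable or a per-stage dict
    transform = getattr(datamodule, "gpu_transform", None)
    if isinstance(transform, dict):
        return transform.get(stage)
    return transform


def apply_gpu_transform(batch, dataset, datamodule, stage):
    transform = resolve_gpu_transform(dataset, datamodule, stage)
    # Nothing resolved: hand back the very same batch object (passthrough)
    return batch if transform is None else transform(batch)


def double_images(batch):
    # Hypothetical "augmentation" that returns a new dict
    return {**batch, "image": [2 * x for x in batch["image"]]}


batch = {"image": [1, 2]}
with_ds_transform = SimpleNamespace(gpu_transform=double_images)
plain_dataset = SimpleNamespace()
datamodule = SimpleNamespace(gpu_transform={"train": double_images})
```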
- on_train_start()
Validate and log the optimizer configuration at the start of training.
Runs once before the first training step. Fills in any missing per-optimizer metadata (gradient clip value, clip algorithm, step frequency) by falling back to the Trainer’s global settings. Logs a summary table of optimizer index, name, class, clip value, and clip algorithm so misconfigured setups are caught early rather than silently misbehaving mid-run.
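A pure-Python sketch of the fill-in and summary. The per-optimizer keys gradient_clip_val and gradient_clip_algorithm are hypothetical names, the default frequency of 1 is an assumption, and the optimizer entry from the config stands in for the optimizer class column.

```python
def fill_and_summarize(opt_configs, trainer_clip_val=None, trainer_clip_algorithm=None):
    """Fill missing per-optimizer settings from Trainer-wide values, then print a table.

    Mutates `opt_configs` in place and returns the summary rows.
    """
    rows = []
    for idx, (name, cfg) in enumerate(opt_configs.items()):
        cfg.setdefault("gradient_clip_val", trainer_clip_val)
        cfg.setdefault("gradient_clip_algorithm", trainer_clip_algorithm)
        cfg.setdefault("frequency", 1)  # assumed default: step on every boundary
        rows.append((idx, name, cfg["optimizer"],
                     cfg["gradient_clip_val"], cfg["gradient_clip_algorithm"]))
    print("idx | name | optimizer | clip_val | clip_algorithm")
    for row in rows:
        print(" | ".join(str(cell) for cell in row))
    return rows


configs = {
    "encoder_opt": {"optimizer": "AdamW"},
    "head_opt": {"optimizer": "SGD", "gradient_clip_val": 0.5},
}
rows = fill_and_summarize(configs, trainer_clip_val=1.0, trainer_clip_algorithm="norm")
```

Per-optimizer values that are already set (head_opt's 0.5 here) win over the Trainer-wide fallback.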
- parameters(with_callbacks=True, recurse: bool = True)
Override to route through the filtered named_parameters implementation.
- predict_step(batch, batch_idx)
Run the forward pass for a single prediction batch.
Passes stage="predict" to forward so forward functions can omit loss computation and return only inference outputs (e.g., embeddings). Used by Trainer.predict() for large-scale feature extraction without a label set.
- Parameters:
batch – Input batch dict from the prediction dataloader.
batch_idx – Index of the current batch within the epoch.
- Returns:
Output dict returned by forward.
- Return type:
- setup(stage: str) -> None
Lightning setup hook — reject the legacy FSDP1 strategy.
FSDP1 (Trainer(strategy="fsdp")) uses a flat training-state machine that asserts a single forward/backward per step, which breaks the multi-forward methods that are this library's bread and butter (DINO, I-JEPA, LeJEPA, …). FSDP2 (strategy="fsdp2") has no such restriction, so we fail fast with a clear redirect.
- Parameters:
stage – The Lightning stage ("fit" / "validate" / "test" / "predict").
- test_step(batch, batch_idx)
Run the forward pass for a single test batch.
Mirrors validation_step but passes stage="test" to forward, allowing forward functions to distinguish test-time behaviour if needed. The returned dict is forwarded to Lightning's on_test_batch_end callback hooks.
- Parameters:
batch – Input batch dict from the test dataloader.
batch_idx – Index of the current batch within the epoch.
- Returns:
Output dict returned by forward.
- Return type:
- training_step(batch, batch_idx)
Run one training step with manual optimization across all configured optimizers.
Calls forward(batch, stage="fit") to obtain a state dict, then performs a single manual_backward on state["loss"]. Each optimizer steps only when its frequency boundary is met ((batch_idx + 1) % frequency == 0). Gradient clipping is applied per optimizer using either the per-optimizer override or the Trainer's gradient_clip_val. The learning rate is logged as hparams/lr_{name} after each step. zero_grad is called only on optimizers that actually stepped this iteration.
- Parameters:
batch – Input batch dict from the training dataloader. Must be a dict; a non-dict batch raises ValueError.
batch_idx – Index of the current batch within the epoch. Injected into the batch dict as batch["batch_idx"] before forwarding.
- Returns:
The state dict returned by forward, passed unchanged to Lightning's callback hooks (on_train_batch_end).
- Return type:
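Two of the documented training_step details, the dict check with batch_idx injection and the hparams/lr_{name} log keys, can be sketched in pure Python. Both helpers are hypothetical. Reading the learning rate from each optimizer's first param group (the param_groups structure torch optimizers expose) is an assumption.

```python
def prepare_batch(batch, batch_idx):
    """Validate the batch and inject batch_idx before it reaches forward."""
    if not isinstance(batch, dict):
        raise ValueError(f"training_step expects a dict batch, got {type(batch).__name__}")
    batch["batch_idx"] = batch_idx
    return batch


def lr_log_entries(param_groups_by_name):
    """Build hparams/lr_{name} log entries from each optimizer's first param group."""
    return {
        f"hparams/lr_{name}": groups[0]["lr"]
        for name, groups in param_groups_by_name.items()
    }


batch = prepare_batch({"image": [0.1, 0.2]}, batch_idx=7)
logs = lr_log_entries({"encoder_opt": [{"lr": 1e-3}], "head_opt": [{"lr": 0.1}]})
```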
- validation_step(batch, batch_idx)
Run the forward pass for a single validation batch.
Calls forward(batch, stage="validate") with gradients disabled (Lightning handles torch.no_grad()). The returned dict is passed to every registered callback via on_validation_batch_end, making all keys, including "embedding" and "label", available to OnlineProbe, OnlineKNN, RankMe, and similar evaluation callbacks without any extra wiring.
- Parameters:
batch – Input batch dict from the validation dataloader.
batch_idx – Index of the current batch within the epoch.
- Returns:
Output dict returned by forward.
- Return type: