Source code for stable_pretraining.callbacks.unused_parameters
from typing import Dict, List
import torch
from torch import nn
from lightning.pytorch.callbacks import Callback
from loguru import logger
class LogUnusedParametersOnce(Callback):
"""Lightning callback that logs parameters which do NOT receive gradients.
- Registers hooks on all leaf parameters (requires_grad=True).
- After the first backward pass, logs unused parameters via loguru.
- Removes all hooks and disables itself for the rest of training.
Works with both automatic and manual optimization.
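
    Example (a minimal usage sketch; assumes a LightningModule ``model`` and a
    train dataloader defined elsewhere)::

        from lightning.pytorch import Trainer

        trainer = Trainer(callbacks=[LogUnusedParametersOnce()])
        trainer.fit(model, train_dataloader)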
"""

    def __init__(self, verbose: bool = True):
        super().__init__()
        self._hooks: List[torch.utils.hooks.RemovableHandle] = []
        self._used_flags: Dict[nn.Parameter, bool] = {}
        self._enabled: bool = True
        self._verbose = verbose
        self._backward_called: bool = False

    def _register_hooks(self, model: nn.Module):
        """Attach hooks to all leaf parameters that require gradient."""
        assert not self._hooks, "Hooks already registered"
        self._backward_called = False
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if not p.is_leaf:
                continue
            self._used_flags[p] = False

            def make_hook(param):
                def hook(grad):
                    self._used_flags[param] = True
                    self._backward_called = True
                return hook

            h = p.register_hook(make_hook(p))
            self._hooks.append(h)
        if self._verbose:
            logger.info(
                f"[LogUnusedParametersOnce] Registered hooks on "
                f"{len(self._used_flags)} leaf parameters."
            )
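
    # Note: the mechanism above relies on torch.Tensor.register_hook, which calls
    # the hook with the parameter's gradient whenever autograd computes it. A tiny
    # standalone illustration (not part of this module):
    #
    #     w = torch.randn(3, requires_grad=True)
    #     w.register_hook(lambda g: print("got grad of shape", g.shape))
    #     (w * 2).sum().backward()  # prints: got grad of shape torch.Size([3])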

    def _remove_hooks(self):
        """Remove all hooks and clear state."""
        for h in self._hooks:
            h.remove()
        self._hooks.clear()
        self._used_flags.clear()

    def _report_and_disable(self, pl_module: nn.Module):
        """Report unused parameters to loguru and disable further tracking."""
        name_by_param = {p: n for n, p in pl_module.named_parameters()}
        unused_names = [
            name_by_param[p] for p, used in self._used_flags.items() if not used
        ]
        if not unused_names:
            logger.info(
                "[LogUnusedParametersOnce] All tracked parameters received gradients "
                "on the first backward pass."
            )
        else:
            logger.warning(
                "[LogUnusedParametersOnce] The following parameters did NOT receive "
                "gradients on the first backward pass (potentially causing "
                "Lightning's 'unused parameters' error):"
            )
            for name in unused_names:
                logger.warning(f" - {name}")
        self._remove_hooks()
        self._enabled = False
        if self._verbose:
            logger.info("[LogUnusedParametersOnce] Hooks removed, callback disabled.")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        """Register hooks right before the first training batch starts."""
        if not self._enabled:
            return
        if trainer.global_step == 0 and batch_idx == 0:
            self._remove_hooks()
            self._used_flags.clear()
            self._register_hooks(pl_module)

    def on_after_backward(self, trainer, pl_module):
        """After backward pass, report unused params (automatic optimization)."""
        if not self._enabled:
            return
        self._report_and_disable(pl_module)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Fallback for manual optimization: check after the first batch completes."""
        if not self._enabled:
            return
        # If the hooks were already removed, on_after_backward handled the report
        # (automatic optimization) and there is nothing left to do here.
        if len(self._hooks) == 0:
            return
        if not self._backward_called:
            logger.warning(
                "[LogUnusedParametersOnce] No gradient hooks fired during the first "
                "training step. This likely means backward() was never called. "
                "Cannot verify unused parameters."
            )
            self._remove_hooks()
            self._enabled = False
            if self._verbose:
                logger.info(
                    "[LogUnusedParametersOnce] Hooks removed, callback disabled."
                )
        else:
            self._report_and_disable(pl_module)
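

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module above; the toy module,
# data, and Trainer settings are arbitrary assumptions for the demo). The
# model's ``unused_head`` never contributes to the loss, so the callback
# should log its parameters as unused after the first backward pass.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import lightning.pytorch as pl
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 2)
            self.unused_head = nn.Linear(2, 2)  # never used in training_step

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.mse_loss(self.encoder(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    data = TensorDataset(torch.randn(8, 4), torch.randn(8, 2))
    trainer = pl.Trainer(
        max_steps=1,
        callbacks=[LogUnusedParametersOnce()],
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )
    trainer.fit(ToyModule(), DataLoader(data, batch_size=4))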