"""Callback factories (``stable_pretraining.callbacks.factories``): progress bars and the default callback list."""

import os
import time

from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import RichProgressBar as _RichProgressBar

from .checkpoint_sklearn import SklearnCheckpoint, WandbCheckpoint
from .env_info import EnvironmentDumpCallback
from .hf_models import HuggingFaceCheckpointCallback
from .registry import ModuleRegistryCallback
from .trainer_info import LoggingCallback, ModuleSummary, SLURMInfo, TrainerInfo
from .unused_parameters import LogUnusedParametersOnce


class RichProgressBar(_RichProgressBar):
    """Rich progress bar that tolerates a known Rich/Lightning teardown bug.

    During teardown, Lightning's ``_stop_progress`` can reach ``Live.stop()``
    after Rich's live stack is already empty. That raises
    ``IndexError: pop from empty list``. This subclass ignores that
    ``IndexError`` so teardown finishes cleanly. Any other exception type
    raised by the parent implementation still propagates.
    """

    def _stop_progress(self) -> None:
        try:
            super()._stop_progress()
        except IndexError:
            # Intentionally ignored: this is the empty-live-stack error
            # described in the class docstring.
            return
class PrintProgressBar(Callback):
    """Plain-text progress logger for non-interactive environments (SLURM, CI).

    Prints a one-liner every ``log_every_n_steps`` training batches, plus a
    summary line at the end of each training epoch. This makes progress show
    up in slurm .out files and the wandb Logs tab.

    Only the global-zero process prints. Multi-process runs (e.g. DDP)
    therefore do not emit one copy of every line per rank.

    NOTE(review): this is a plain ``Callback``, not a Lightning ``ProgressBar``.
    Confirm the Trainer is not also adding its own default progress bar
    alongside it.

    Args:
        log_every_n_steps: Print a line after every this-many training
            batches within an epoch. Must be >= 1.

    Raises:
        ValueError: If ``log_every_n_steps`` is smaller than 1.
    """

    def __init__(self, log_every_n_steps: int = 50):
        super().__init__()
        if log_every_n_steps < 1:
            # Fail at construction rather than with a ZeroDivisionError from
            # the modulo in on_train_batch_end halfway through training.
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        # time.time() at the start of the current training epoch; None until
        # the first epoch begins.
        self._epoch_start = None

    # ------------------------------------------------------------------ helpers

    def _elapsed(self) -> float:
        """Return seconds since the current epoch started (0.0 if not started)."""
        if self._epoch_start is None:
            return 0.0
        return time.time() - self._epoch_start

    @staticmethod
    def _max_epochs_str(trainer: Trainer) -> str:
        """Return ``trainer.max_epochs`` for display, or ``"?"`` if unbounded."""
        max_epochs = trainer.max_epochs
        # Lightning uses -1 for "no epoch limit" (e.g. when training is
        # bounded by max_steps instead). Printing "/-1" would be misleading.
        if max_epochs is None or max_epochs < 0:
            return "?"
        return str(max_epochs)

    @staticmethod
    def _format_metrics(trainer: Trainer) -> str:
        """Render ``trainer.progress_bar_metrics`` as ``"k: v | k: v"``."""
        parts = []
        for key, value in trainer.progress_bar_metrics.items():
            try:
                parts.append(f"{key}: {value:.4g}")
            except (TypeError, ValueError):
                # Values that don't support the ``g`` format spec are shown
                # verbatim instead of crashing the training loop.
                parts.append(f"{key}: {value}")
        return " | ".join(parts)

    def _emit(self, trainer: Trainer, prefix: str) -> None:
        """Print ``prefix`` followed by the current progress-bar metrics."""
        metrics_str = self._format_metrics(trainer)
        line = prefix + (f" | {metrics_str}" if metrics_str else "")
        # flush=True so the line reaches redirected/buffered stdout (slurm
        # .out files) immediately.
        print(line, flush=True)

    # -------------------------------------------------------------------- hooks

    def on_train_epoch_start(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Record the wall-clock start time of the training epoch."""
        self._epoch_start = time.time()

    def on_train_batch_end(
        self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx
    ) -> None:
        """Print a progress line every ``log_every_n_steps`` batches."""
        step = batch_idx + 1
        if step % self.log_every_n_steps != 0 or not trainer.is_global_zero:
            return
        total = trainer.num_training_batches
        # num_training_batches is float("inf") for dataloaders without a length.
        total_str = "?" if total == float("inf") else str(total)
        elapsed = self._elapsed()
        # Average throughput since the start of this epoch.
        it_s = step / elapsed if elapsed > 0 else 0
        self._emit(
            trainer,
            f"[Epoch {trainer.current_epoch}/{self._max_epochs_str(trainer)}] "
            f"step {step}/{total_str} "
            f"({it_s:.1f} it/s)",
        )

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Print a summary line with the epoch's wall-clock duration."""
        if not trainer.is_global_zero:
            return
        elapsed = self._elapsed()
        self._emit(
            trainer,
            f"[Epoch {trainer.current_epoch}/{self._max_epochs_str(trainer)}] "
            f"done in {elapsed:.1f}s",
        )
def _make_progress_bar():
    """Pick the progress-bar callback that suits where stdout is going.

    If file descriptor 1 (stdout) is a terminal (local shell, interactive
    ``srun``, ...), the live Rich display is returned. Otherwise (``sbatch``,
    Hydra multirun, piped output), the plain-text ``PrintProgressBar`` is
    returned, whose lines land in slurm .out files and the wandb Logs tab.
    """
    stdout_is_terminal = os.isatty(1)
    if not stdout_is_terminal:
        return PrintProgressBar()
    return RichProgressBar()
def default():
    """Build and return the default list of callbacks.

    A fresh list is created on every call, in this order:
    the progress bar chosen by ``_make_progress_bar()``, then
    ``ModuleRegistryCallback``, ``LoggingCallback``,
    ``EnvironmentDumpCallback(async_dump=True)``, ``TrainerInfo``,
    ``SklearnCheckpoint``, ``WandbCheckpoint``, ``ModuleSummary``,
    ``SLURMInfo``, ``LogUnusedParametersOnce`` and
    ``HuggingFaceCheckpointCallback``.
    """
    progress_bar = _make_progress_bar()
    return [
        progress_bar,
        ModuleRegistryCallback(),
        LoggingCallback(),
        EnvironmentDumpCallback(async_dump=True),
        TrainerInfo(),
        SklearnCheckpoint(),
        WandbCheckpoint(),
        ModuleSummary(),
        SLURMInfo(),
        LogUnusedParametersOnce(),
        HuggingFaceCheckpointCallback(),
    ]