# Source code for stable_pretraining.callbacks.hf_models

import shutil
import inspect
from pathlib import Path
from typing import Dict, Any
from loguru import logger
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from transformers import PreTrainedModel
from .utils import log_header


class HuggingFaceCheckpointCallback(Callback):
    """Export HF-compatible checkpoints for PreTrainedModel submodules.

    Identifies submodules inheriting from Hugging Face's `PreTrainedModel`
    and exports them into standalone, "zero-knowledge" loadable HF
    directories. This callback automates the synchronization between
    Lightning training and the Hugging Face ecosystem, handling weight
    stripping (removing DDP/Lightning prefixes) and dependency copying.

    Args:
        save_dir (str): Root directory where HF models will be exported.
            Default is "hf_exports".
        verbose (bool): If True, logs a discovery table at the start of
            training. Default is True.

    Example:
        >>> # Setup your model with a HF submodule
        >>> class MySystem(pl.LightningModule):
        ...     def __init__(self, config):
        ...         super().__init__()
        ...         self.backbone = MyCustomHFModel(config)  # Inherits PreTrainedModel
        >>> # Add callback to trainer
        >>> hf_cb = HuggingFaceCheckpointCallback(save_dir="checkpoints/hf_models")
        >>> trainer = pl.Trainer(callbacks=[hf_cb])
        >>> trainer.fit(model, dataloader)
        >>> # Later, load without your source code library:
        >>> from transformers import AutoModel
        >>> model = AutoModel.from_pretrained(
        ...     "checkpoints/hf_models/step_5000/backbone", trust_remote_code=True
        ... )
    """

    def __init__(self, save_dir: str = "hf_exports", verbose: bool = True):
        super().__init__()
        self.save_dir = Path(save_dir)
        self.verbose = verbose
        log_header("HuggingFaceCheckpoint")
        logger.info(f" save_dir: <cyan>{self.save_dir}</cyan>")

    def _get_hf_submodules(
        self, pl_module: pl.LightningModule
    ) -> Dict[str, PreTrainedModel]:
        """Identifies top-level children that are instances of PreTrainedModel."""
        return {
            name: module
            for name, module in pl_module.named_children()
            if isinstance(module, PreTrainedModel)
        }

    def _log_discovery_table(self, submodules: Dict[str, PreTrainedModel]):
        """Renders a diagnostic table of discovered HF submodules using Loguru."""
        if not submodules:
            # NOTE(review): this message was line-wrapped mid-string by the docs
            # extraction; reconstructed here as a single-line warning.
            logger.warning(
                "! No Hugging Face (PreTrainedModel) submodules found in LightningModule."
            )
            return
        # Formatting a manual Markdown table for the console
        header = f"| {'Module Name':<18} | {'Class Type':<22} | {'Config Type':<22} |"
        sep = f"|{'-' * 20}|{'-' * 24}|{'-' * 24}|"
        logger.info(" HF Submodule Discovery Summary:")
        logger.info(sep)
        logger.info(header)
        logger.info(sep)
        for name, mod in submodules.items():
            logger.info(
                f"| {name:<18} | {mod.__class__.__name__:<22} | {mod.config.__class__.__name__:<22} |"
            )
        logger.info(sep)

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Discovers modules and logs the status table once at start."""
        submodules = self._get_hf_submodules(pl_module)
        if self.verbose:
            self._log_discovery_table(submodules)

    def _copy_dependency_tree(self, model: PreTrainedModel, save_path: Path):
        """Copy source files so relative imports resolve in the export dir.

        Locates the source files for the model and its immediate neighbors
        to ensure relative imports (e.g. from .layers import X) resolve.
        """
        model_file = Path(inspect.getfile(model.__class__)).resolve()
        package_root = model_file.parent
        # Capture all python scripts in the model's directory.
        # This handles siblings like 'pos_embed.py' or 'swiglu.py'.
        for py_file in package_root.glob("*.py"):
            shutil.copy2(py_file, save_path / py_file.name)
        # Fix: the config class may be defined in a different directory than
        # the model class. Without its source file next to the weights, the
        # "AutoConfig" auto_map entry written in on_save_checkpoint would
        # point at a module that was never exported, breaking
        # trust_remote_code loading. Copy it explicitly when needed.
        config_file = Path(inspect.getfile(model.config.__class__)).resolve()
        if config_file.parent != package_root:
            shutil.copy2(config_file, save_path / config_file.name)

    def on_save_checkpoint(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        checkpoint: Dict[str, Any],
    ):
        """Create an atomic HF-compatible export for every found submodule.

        Triggered by Lightning's checkpointing logic. Only rank 0 performs
        the export to avoid filesystem race conditions.
        """
        if trainer.global_rank != 0:
            return

        step = trainer.global_step
        hf_step_dir = self.save_dir / f"step_{step}"

        # Ensure atomic overwrite: clear directory if it exists.
        if hf_step_dir.exists():
            logger.debug(f"Overwriting previous HF directory: {hf_step_dir}")
            shutil.rmtree(hf_step_dir)
        hf_step_dir.mkdir(parents=True, exist_ok=True)

        hf_submodules = self._get_hf_submodules(pl_module)
        for name, model in hf_submodules.items():
            model_save_path = hf_step_dir / name
            model_save_path.mkdir(parents=True, exist_ok=True)

            # Extract module/config filenames for the AutoModel map.
            module_fn = Path(inspect.getfile(model.__class__)).stem
            config_fn = Path(inspect.getfile(model.config.__class__)).stem

            # Update auto_map so AutoModel knows which .py file contains the classes.
            model.config.auto_map = {
                "AutoConfig": f"{config_fn}.{model.config.__class__.__name__}",
                "AutoModel": f"{module_fn}.{model.__class__.__name__}",
            }

            # 1. Save Weights & Config.json
            # Note: Using the model instance (not pl_module) strips all
            # lightning/DDP prefixes automatically.
            model.save_pretrained(model_save_path)

            # 2. Copy code dependencies
            self._copy_dependency_tree(model, model_save_path)

            logger.success(
                f"Exported HF submodule '<green>{name}</green>' at step {step} -> {model_save_path}"
            )