# Source code for stable_pretraining.callbacks.registry

from typing import Optional, Dict, Any
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback

# Process-wide mapping from a registration name to a LightningModule.
# Written by ModuleRegistryCallback.setup, cleared by its teardown, and read
# by get_module() and log().
_MODULE_REGISTRY: Dict[str, LightningModule] = {}


def get_module(name: str = "default") -> Optional[LightningModule]:
    """Look up a module in the global registry.

    Args:
        name: Key the module was registered under.

    Returns:
        The module stored under ``name``, or ``None`` when the registry has
        no entry for that key.
    """
    try:
        return _MODULE_REGISTRY[name]
    except KeyError:
        return None
def log(name: str, value: Any, module_name: str = "default", **kwargs) -> None:
    """Forward a metric to the ``log`` method of a registered module.

    Args:
        name: Metric name passed through to the module's ``log``.
        value: Metric value passed through to the module's ``log``.
        module_name: Registry key of the module that should do the logging.
        **kwargs: Extra keyword arguments handed unchanged to the module's
            ``log``.

    Nothing happens (and nothing is raised) when no module is registered
    under ``module_name``.
    """
    target = _MODULE_REGISTRY.get(module_name)
    if target is None:
        # Best-effort: no registered module means there is nowhere to log to.
        return
    target.log(name, value, **kwargs)
class ModuleRegistryCallback(Callback):
    """Callback that automatically registers the module for global logging access.

    While the callback is active, the module it was set up with is reachable
    through the module-level registry under ``name``.

    Args:
        name: Registry key under which the module is stored.
    """

    def __init__(self, name: str = "default"):
        # Keep cooperative initialisation intact for mixins / subclasses.
        super().__init__()
        self.name = name

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Register module at the start of any stage (fit, validate, test, predict).

        Any module previously stored under ``self.name`` is replaced.
        """
        _MODULE_REGISTRY[self.name] = pl_module

    def teardown(
        self, trainer: Trainer, pl_module: LightningModule, stage: str
    ) -> None:
        """Clean up registry when done.

        The entry is removed only if it still refers to ``pl_module``; if
        another callback has since registered a different module under the
        same name, that registration is left untouched.
        """
        if _MODULE_REGISTRY.get(self.name) is pl_module:
            del _MODULE_REGISTRY[self.name]