Source code for stable_pretraining.callbacks.utils

import collections.abc
import copy
import dataclasses
import types
from functools import partial
from typing import Any, Dict, Iterable, Optional, Union

import torch
import torchmetrics
from lightning.pytorch import Callback, LightningModule
from loguru import logger as logging

from ..optim import create_optimizer, create_scheduler, LARS

_HEADER_WIDTH = 50


def get_data_from_batch_or_outputs(
    key: Union[Iterable[str], str],
    batch: Dict[str, Any],
    outputs: Optional[Dict[str, Any]] = None,
    caller_name: str = "Callback",
) -> Optional[Any]:
    """Get data from either outputs or batch dictionary.

    In PyTorch Lightning, the outputs parameter in callbacks contains the return
    value from training_step/validation_step, while batch contains the original
    input. Since forward methods may modify batch in-place but Lightning creates
    a copy for outputs, we need to check both.

    Args:
        key: The key(s) to look for in the dictionaries
        batch: The original batch dictionary
        outputs: The outputs dictionary from training/validation step
        caller_name: Name of the calling function/class for logging

    Returns:
        The data associated with the key, or None if not found
    """
    output_as_list = True
    if type(key) is str:
        key = [key]
        output_as_list = False
    out = []
    for k in key:
        # First check outputs (which contains the forward pass results)
        if outputs is not None and k in outputs:
            out.append(outputs[k])
        elif k in batch:
            out.append(batch[k])
        else:
            msg = (
                f"{caller_name}: Key '{k}' not found in batch or outputs. "
                f"Available batch keys: {list(batch.keys())}, "
                f"Available output keys: {list(outputs.keys()) if outputs else 'None'}"
            )
            logging.warning(msg)
            raise ValueError(msg)
    if output_as_list:
        return out
    return out[0]


def detach_tensors(obj: Any) -> Any:
    """Recursively traverse an object and return an equivalent structure with all torch tensors detached.

    - Preserves structure, types, and shared references.
    - Handles cycles and arbitrary Python objects (including __dict__ and __slots__).
    - Does not mutate the input; only rebuilds containers if needed.
    - torch.nn.Parameter is replaced with a detached Tensor (not Parameter).
    - Optionally supports attrs classes if 'attr' is installed.

    Args:
        obj: The input object (can be arbitrarily nested).

    Returns:
        A new object with all torch tensors detached, or the original object if no tensors found.
    Performance notes:
        - Uses memoization to avoid redundant work and preserve shared/cyclic structure.
        - Avoids unnecessary copies: unchanged subtrees are returned as-is (same id).
        - Shallow-copies objects with __dict__ or __slots__ (does not call __init__).
    """
    memo: Dict[int, Any] = {}
    # Feature-detect attrs support
    try:
        import attr

        _HAS_ATTRS = True
    except ImportError:
        _HAS_ATTRS = False

    def _detach_impl(o: Any) -> Any:
        oid = id(o)
        if oid in memo:
            return memo[oid]
        # Tensors (including Parameter)
        if isinstance(o, torch.Tensor):
            result = o.detach()
            memo[oid] = result
            return result
        # defaultdict: must preserve default_factory and handle cycles
        if isinstance(o, collections.defaultdict):
            result = type(o)(o.default_factory)
            memo[oid] = result
            changed = False
            for k, v in o.items():
                new_v = _detach_impl(v)
                changed = changed or (new_v is not v)
                result[k] = new_v
            # Always return the new result, even if not changed, to ensure correct default_factory and keys
            return result
        # dict/OrderedDict/other Mapping (excluding defaultdict)
        if isinstance(o, collections.abc.Mapping):
            # For custom mapping subclasses, try to preserve type
            result = type(o)()
            memo[oid] = result
            changed = False
            for k, v in o.items():
                new_v = _detach_impl(v)
                changed = changed or (new_v is not v)
                result[k] = new_v
            # For plain dict, if nothing changed, return original
            if not changed and type(o) is dict:
                memo[oid] = o
                return o
            return result
        # Dataclasses (handle frozen and init=False fields)
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            # Step 1: create a shallow copy via dataclasses.replace (no field overrides)
            try:
                copy_obj = dataclasses.replace(o)
            except Exception:
                # fallback for dataclasses with no fields
                copy_obj = copy.copy(o)
            memo[oid] = copy_obj
            changed = False
            for f in dataclasses.fields(o):
                v = getattr(o, f.name)
                new_v = _detach_impl(v)
                if new_v is not v:
                    object.__setattr__(copy_obj, f.name, new_v)
                    changed = True
            if not changed:
                memo[oid] = o
                return o
            return copy_obj
        # attrs classes (if available)
        if _HAS_ATTRS and attr.has(o) and not isinstance(o, type):
            # Use attr.evolve to create a shallow copy, then set fields
            copy_obj = attr.evolve(o)
            memo[oid] = copy_obj
            changed = False
            for f in attr.fields(type(o)):
                v = getattr(o, f.name)
                new_v = _detach_impl(v)
                if new_v is not v:
                    object.__setattr__(copy_obj, f.name, new_v)
                    changed = True
            if not changed:
                memo[oid] = o
                return o
            return copy_obj
        # Namedtuple (but not plain tuple)
        if isinstance(o, tuple) and hasattr(o, "_fields"):
            values = []
            changed = False
            for v in o:
                new_v = _detach_impl(v)
                changed = changed or (new_v is not v)
                values.append(new_v)
            if not changed:
                memo[oid] = o
                return o
            result = type(o)(*values)
            memo[oid] = result
            return result
        # List
        if isinstance(o, list):
            result = []
            memo[oid] = result
            changed = False
            for v in o:
                new_v = _detach_impl(v)
                changed = changed or (new_v is not v)
                result.append(new_v)
            if not changed:
                memo[oid] = o
                return o
            return result
        # Tuple (not namedtuple)
        if isinstance(o, tuple):
            values = []
            changed = False
            for v in o:
                new_v = _detach_impl(v)
                changed = changed or (new_v is not v)
                values.append(new_v)
            if not changed:
                memo[oid] = o
                return o
            result = tuple(values)
            memo[oid] = result
            return result
        # Set
        if isinstance(o, set):
            result = set()
            memo[oid] = result
            changed = False
            for v in o:
                new_v = _detach_impl(v)
                changed = changed or (new_v is not v)
                result.add(new_v)
            if not changed:
                memo[oid] = o
                return o
            return result
        # Frozenset
        if isinstance(o, frozenset):
            values = []
            changed = False
            for v in o:
                new_v = _detach_impl(v)
                changed = changed or (new_v is not v)
                values.append(new_v)
            if not changed:
                memo[oid] = o
                return o
            result = frozenset(values)
            memo[oid] = result
            return result
        # Generic objects with __dict__ or __slots__
        if hasattr(o, "__dict__") or hasattr(o, "__slots__"):
            result = copy.copy(o)
            memo[oid] = result
            changed = False
            # __dict__ attributes
            if hasattr(result, "__dict__"):
                for k, v in result.__dict__.items():
                    new_v = _detach_impl(v)
                    if new_v is not v:
                        setattr(result, k, new_v)
                        changed = True
            # __slots__ attributes
            if hasattr(result, "__slots__"):
                for slot in result.__slots__:
                    if hasattr(result, slot):
                        v = getattr(result, slot)
                        new_v = _detach_impl(v)
                        if new_v is not v:
                            setattr(result, slot, new_v)
                            changed = True
            if not changed:
                memo[oid] = o
                return o
            return result
        # All other types: return as is
        memo[oid] = o
        return o

    return _detach_impl(obj)


def log_header(name: str, width: int = _HEADER_WIDTH) -> None:
    """Log a unified section header: ``── Name ────────────``."""
    pad = max(width - len(name) - 4, 2)
    logging.info(f"── {name} " + "─" * pad)


# Registry of callbacks whose position in ``trainer.callbacks`` matters.
# Each entry: class name → human-readable ordering rule.
#
# Scope: ONLY include callbacks where two callbacks act in the **same**
# Lightning hook and the order of writes/reads inside that hook matters.
# Lightning runs each hook to completion across all callbacks before moving
# to the next, so producer/consumer pairs split across different hooks
# (e.g., OnlineQueue creates the snapshot in on_validation_epoch_start;
# consumers read it in on_validation_batch_end) are NOT order-sensitive —
# the producer hook finishes for every callback before any consumer hook
# runs. Don't list those here.
ORDER_SENSITIVE_CALLBACKS: Dict[str, str] = {
    "TeacherStudentCallback": (
        "EMA update fires in on_train_batch_end; place AFTER any callback "
        "that reads the teacher's parameters in that same hook"
    ),
    "OnlineProbe": (
        "trains its own probe on the current batch's embeddings inside "
        "on_train_batch_end — place AFTER callbacks that mutate the "
        "embedding in the same hook (e.g., normalization probes)"
    ),
    "OnlineWriter": (
        "writes batch outputs to disk in on_train_batch_end — place LAST "
        "among per-batch callbacks so it captures all mutations"
    ),
    "CleanUpCallback": (
        "deletes files in on_train_end / teardown — must come AFTER any "
        "callback that saves artefacts in the same hook (checkpoint "
        "callbacks, hf_models, etc.)"
    ),
}


def log_callbacks_order(callbacks: Iterable[Callback]) -> None:
    """Log the callback execution order with annotations on order-sensitive ones.

    Lightning runs ``trainer.callbacks`` in registration order. Within a
    single hook, callbacks fire in that order; across hooks, Lightning
    finishes each hook across all callbacks before moving to the next. So
    only same-hook read/write dependencies are order-sensitive (post-backward
    EMA updates, end-of-training cleanup, batch-output writers). This helper
    surfaces the actual order at runtime so misplacements are easy to spot.

    The list of order-sensitive callback class names + their constraints is
    kept in :data:`ORDER_SENSITIVE_CALLBACKS`.
    """
    log_header("Callbacks (in order)")
    callbacks = list(callbacks)
    if not callbacks:
        logging.info("  (none registered)")
        return
    width = len(str(len(callbacks)))
    for i, cb in enumerate(callbacks):
        cls = type(cb).__name__
        rule = ORDER_SENSITIVE_CALLBACKS.get(cls)
        marker = "⚑" if rule is not None else " "
        logging.info(f"  {marker} [{i:>{width}}] {cls}")
        if rule is not None:
            logging.info(f"       └─ order rule: {rule}")
    logging.info(
        "  ⚑ marks order-sensitive callbacks; see AGENTS.md → Callback "
        "ordering for the full rules."
    )


def resolve_verbose(verbose: Optional[bool]) -> bool:
    """Resolve a callback's ``verbose`` flag.

    * ``True`` / ``False`` — honour the explicit choice.
    * ``None`` — derive from the global config: verbose if the global
      log level is INFO or lower (i.e. more detailed).
    """
    if verbose is not None:
        return verbose
    from .._config import get_config, _VALID_LOG_LEVELS

    level = get_config().verbose
    # INFO is index 2; anything <= INFO means "chatty enough for verbose"
    return _VALID_LOG_LEVELS.index(level) <= _VALID_LOG_LEVELS.index("INFO")


class TrainableCallback(Callback):
    """Base callback class with optimizer and scheduler management.

    This base class handles the common logic for callbacks that need their own
    optimizer and scheduler, including automatic inheritance from the main module's
    configuration when not explicitly specified.

    Subclasses should:
    1. Call super().__init__() with appropriate parameters
    2. Store their module configuration in self._module_config
    3. Override configure_model() to create their specific module
    4. Access their module via self.module property after setup
    """

    def __init__(
        self,
        module: LightningModule,
        name: str,
        optimizer: Optional[Union[str, dict, partial, torch.optim.Optimizer]] = None,
        scheduler: Optional[
            Union[str, dict, partial, torch.optim.lr_scheduler.LRScheduler]
        ] = None,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: float = None,
        gradient_clip_algorithm: str = "norm",
    ):
        """Initialize base callback with optimizer/scheduler configuration.

        Args:
            module: spt.Module.
            name: Unique identifier for this callback instance.
            optimizer: Optimizer configuration. If None, uses default LARS.
            scheduler: Scheduler configuration. If None, uses default ConstantLR.
            accumulate_grad_batches: Number of batches to accumulate gradients.
            gradient_clip_val: Value to clip the gradient (default None).
            gradient_clip_algorithm: Algorithm to clip the gradient (default `norm`).
        """
        super().__init__()
        self.name = name
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm

        # Store configurations
        self._optimizer_config = optimizer
        self._scheduler_config = scheduler
        self._pl_module = module
        self.wrap_configure_model(module)
        self.wrap_configure_optimizers(module)

    def wrap_configure_model(self, pl_module):
        fn = pl_module.configure_model

        def new_configure_model(self, callback=self, fn=fn):
            # Initialize module
            fn()
            module = callback.configure_model(self)
            # Store module in pl_module.callbacks_modules
            logging.info("  storing module in callbacks_modules")
            self.callbacks_modules[callback.name] = module
            # Metrics are optional — not all trainable callbacks expose
            # them (e.g. generative/reconstruction heads whose only output
            # is a loss scalar).
            metrics = getattr(callback, "metrics", None)
            if metrics is not None:
                logging.info("  setting up metrics")
                assert callback.name not in self.callbacks_metrics
                self.callbacks_metrics[callback.name] = format_metrics_as_dict(metrics)

        # Bind the new method to the instance
        logging.info("  wrapping configure_model")
        pl_module.configure_model = types.MethodType(new_configure_model, pl_module)

    def configure_model(self, pl_module: LightningModule) -> torch.nn.Module:
        """Initialize the module for this callback.

        Subclasses must override this method to create their specific module.

        Args:
            pl_module: The Lightning module being trained.

        Returns:
            The initialized module.
        """
        raise NotImplementedError("Subclasses must implement configure_model")

    def wrap_configure_optimizers(self, pl_module):
        fn = pl_module.configure_optimizers

        def new_configure_optimizers(self, callback=self, fn=fn):
            outputs = fn()
            if outputs is None:
                optimizers = []
                schedulers = []
            else:
                optimizers, schedulers = outputs
            # assert callback.name not in self._optimizer_name_to_index
            assert callback.name not in self._optimizer_frequencies
            # assert callback.name not in self._optimizer_names
            assert callback.name not in self._optimizer_gradient_clip_val
            assert callback.name not in self._optimizer_gradient_clip_algorithm
            assert len(optimizers) not in self._optimizer_index_to_name
            self._optimizer_index_to_name[len(optimizers)] = callback.name
            # self._optimizer_name_to_index[callback.name] = len(self._optimizer_names)
            # self._optimizer_names.append(callback.name)
            self._optimizer_frequencies[callback.name] = (
                callback.accumulate_grad_batches
            )
            self._optimizer_gradient_clip_val[callback.name] = (
                callback.gradient_clip_val
            )
            self._optimizer_gradient_clip_algorithm[callback.name] = (
                callback.gradient_clip_algorithm
            )
            optimizers.append(callback.setup_optimizer(self))
            schedulers.append(callback.setup_scheduler(optimizers[-1], self))
            return optimizers, schedulers

        # Bind the new method to the instance
        logging.info("  wrapping configure_optimizers")
        pl_module.configure_optimizers = types.MethodType(
            new_configure_optimizers, pl_module
        )

    def setup_optimizer(self, pl_module: LightningModule) -> None:
        """Initialize optimizer with default LARS if not specified."""
        if self._optimizer_config is None:
            # Use default LARS optimizer for SSL linear probes
            logging.info("  no optimizer given, using default LARS")
            return LARS(
                self.module.parameters(),
                lr=0.1,
                clip_lr=True,
                eta=0.02,
                exclude_bias_n_norm=True,
                weight_decay=0,
            )
        # Use explicitly provided optimizer config. Passing ``named_params``
        # lets ``exclude_bias_norm`` (#368) work whether set per-config or via
        # the global default.
        logging.info("  using explicitly provided optimizer")
        return create_optimizer(
            self.module.parameters(),
            self._optimizer_config,
            named_params=self.module.named_parameters(),
        )

    def setup_scheduler(self, optimizer, pl_module: LightningModule) -> None:
        """Initialize scheduler with default ConstantLR if not specified."""
        if self._scheduler_config is None:
            # Use default ConstantLR scheduler
            logging.info("  no scheduler given, using default ConstantLR")
            return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0)
        logging.info("  using explicitly provided scheduler")
        return create_scheduler(optimizer, self._scheduler_config, module=pl_module)

    @property
    def module(self):
        """Access module from pl_module.callbacks_modules.

        This property is only accessible after setup() has been called.
        The module is stored centrally in pl_module.callbacks_modules
        to avoid duplication in checkpoints.
        """
        if self._pl_module is None:
            raise AttributeError(
                f"{self.name}: module not accessible before setup(). "
                "The module is initialized during the setup phase."
            )
        return self._pl_module.callbacks_modules[self.name]

    @property
    def state_key(self) -> str:
        """Unique identifier for this callback's state during checkpointing."""
        return f"{self.__class__.__name__}[name={self.name}]"


[docs] class EarlyStopping(torch.nn.Module): """Early stopping mechanism with support for metric milestones and patience. This module provides flexible early stopping capabilities that can halt training based on metric performance. It supports both milestone-based stopping (stop if metric doesn't reach target by specific epochs) and patience-based stopping (stop if metric doesn't improve for N epochs). Args: mode: Optimization direction - 'min' for metrics to minimize (e.g., loss), 'max' for metrics to maximize (e.g., accuracy). milestones: Dict mapping epoch numbers to target metric values. Training stops if targets are not met at specified epochs. metric_name: Name of the metric to monitor if metric is a dict. patience: Number of epochs with no improvement before stopping. Example: >>> early_stop = EarlyStopping(mode="max", milestones={10: 0.8, 20: 0.9}) >>> # Stops if accuracy < 0.8 at epoch 10 or < 0.9 at epoch 20 """ def __init__( self, mode: str = "min", milestones: dict[int, float] = None, metric_name: str = None, patience: int = 10, ): super().__init__() self.mode = mode self.milestones = milestones or {} self.metric_name = metric_name self.patience = patience self.register_buffer("history", torch.zeros(patience)) def should_stop(self, metric, step): if self.metric_name is None: assert type(metric) is not dict else: assert self.metric_name in metric metric = metric[self.metric_name] if step in self.milestones: if self.mode == "min": return metric > self.milestones[step] elif self.mode == "max": return metric < self.milestones[step] return False
def format_metrics_as_dict(metrics): """Formats various metric input formats into a standardized dictionary structure. This utility function handles multiple input formats for metrics and converts them into a consistent ModuleDict structure with separate train and validation metrics. This standardization simplifies metric handling across callbacks. Args: metrics: Can be: - None: Returns empty train and val dicts - Single torchmetrics.Metric: Applied to validation only - Dict with 'train' and 'val' keys: Separated accordingly - Dict of metrics: All applied to validation - List/tuple of metrics: All applied to validation Returns: ModuleDict with '_train' and '_val' keys, each containing metric ModuleDicts. Raises: ValueError: If metrics format is invalid or contains non-torchmetric objects. """ # Handle OmegaConf types from omegaconf import ListConfig, DictConfig if isinstance(metrics, (ListConfig, DictConfig)): import omegaconf metrics = omegaconf.OmegaConf.to_container(metrics, resolve=True) if metrics is None: train = {} eval = {} elif isinstance(metrics, torchmetrics.Metric): train = {} eval = torch.nn.ModuleDict({metrics.__class__.__name__: metrics}) elif type(metrics) is dict and set(metrics.keys()) == set(["train", "val"]): train = {} eval = {} if type(metrics["train"]) in [list, tuple]: for m in metrics["train"]: if not isinstance(m, torchmetrics.Metric): raise ValueError(f"metric {m} is no a torchmetric") train[m.__class__.__name__] = m else: train[metrics["train"].__class__.__name__] = metrics["train"] if type(metrics["val"]) in [list, tuple]: for m in metrics["val"]: if not isinstance(m, torchmetrics.Metric): raise ValueError(f"metric {m} is no a torchmetric") eval[m.__class__.__name__] = m else: eval[metrics["val"].__class__.__name__] = metrics["val"] elif type(metrics) is dict: train = {} for k, v in metrics.items(): assert type(k) is str assert isinstance(v, torchmetrics.Metric) eval = metrics elif type(metrics) in [list, tuple]: train = {} eval = {} for m in metrics: if not isinstance(m, torchmetrics.Metric): raise ValueError(f"metric {m} is no a torchmetric") eval[m.__class__.__name__] = m else: raise ValueError( "metrics can only be a torchmetric of list/tuple of torchmetrics" ) return torch.nn.ModuleDict( { "_train": torch.nn.ModuleDict(train), "_val": torch.nn.ModuleDict(eval), } )