Source code for stable_pretraining.utils.flops

from torch.utils.flop_counter import FlopCounterMode
from torch.utils._python_dispatch import TorchDispatchMode
from contextlib import contextmanager


[docs] class FLOPBudgetExceeded(Exception): """Exception raised when FLOP budget is exceeded.""" def __init__(self, budget: int, current: int, operation: str = ""): self.budget = budget self.current = current self.operation = operation super().__init__( f"FLOP budget exceeded: {current:,} FLOPs used (budget: {budget:,})" + (f" at operation: {operation}" if operation else "") )
[docs] class BudgetedFlopCounterMode(TorchDispatchMode): """FLOP counter with budget enforcement using composition. Wraps FlopCounterMode for counting, adds budget checking. """ def __init__(self, flop_counter: FlopCounterMode, budget: int): self._flop_counter = flop_counter self.budget = budget self._last_op = "" @property def total_flops(self) -> int: return self._flop_counter.get_total_flops() def __torch_dispatch__(self, func, types, args=(), kwargs=None): self._last_op = getattr(func, "__name__", str(func)) # Call the function (FlopCounterMode will intercept and count) result = func(*args, **(kwargs or {})) # Check budget after FlopCounterMode has counted current = self.total_flops if current > self.budget: raise FLOPBudgetExceeded(self.budget, current, self._last_op) return result
[docs] @contextmanager def flop_budget(budget: int, display: bool = False): """Context manager that counts FLOPs and raises when budget is exceeded. :param budget: Maximum number of FLOPs allowed :param display: Whether to print FLOP breakdown on exit :yields: Counter object with .total_flops property Example: try: with flop_budget(1e9) as counter: for i in range(1000): out = model(x) print(f"FLOPs so far: {counter.total_flops:,}") except FLOPBudgetExceeded as e: print(f"Stopped at {e.current:,} FLOPs") """ budget = int(budget) # FlopCounterMode does the counting flop_counter = FlopCounterMode(display=display) # Our mode checks the budget after each op budget_checker = BudgetedFlopCounterMode(flop_counter, budget) # Stack: budget_checker (top) -> flop_counter (bottom) -> actual ops # When op is called: budget_checker intercepts -> calls func -> # flop_counter intercepts and counts -> actual op runs -> # returns to budget_checker which checks budget with flop_counter: with budget_checker: yield budget_checker