# Source code for stable_pretraining.registry.query

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Read-only query API for the filesystem-backed run registry.

Usage::

    import stable_pretraining as spt

    reg = spt.open_registry()  # lazily scans before querying

    runs = reg.query(tag="sweep:12345")
    best = reg.query(tag="sweep:12345", sort_by="summary.val_acc", limit=5)

    df = reg.to_dataframe(tag="resnet50")

``open_registry()`` triggers an incremental scan of
``{cache_dir}/runs/**`` before returning.  A short in-process TTL
short-circuits back-to-back calls so scripts stay snappy.
"""

from __future__ import annotations

import dataclasses
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from . import _scanner
from ._store import Store


@dataclasses.dataclass(frozen=True)
class RunRecord:
    """Immutable view of a single training run, hydrated from the cache.

    Instances are frozen, so attributes cannot be rebound after construction.
    The dict/list fields themselves are ordinary mutable containers, which
    also means a record is not hashable.

    Attributes:
        run_id: Identifier of the run.
        status: Lifecycle status of the run.
        created_at: Creation timestamp (presumably epoch seconds -- confirm
            against the writer side).
        updated_at: Last-update timestamp, same unit as ``created_at``.
        alive: Liveness flag of the run.
        run_dir: Directory of the run, if known.
        checkpoint_path: Path of a checkpoint for the run, if known.
        config: Configuration mapping of the run.
        hparams: Hyper-parameter mapping of the run.
        summary: Summary mapping of the run.
        tags: Tags attached to the run.
        notes: Free-form notes attached to the run.
    """

    # --- identity and lifecycle
    run_id: str
    status: str
    created_at: float
    updated_at: float
    alive: bool

    # --- filesystem locations
    run_dir: Optional[str]
    checkpoint_path: Optional[str]

    # --- structured payloads
    config: Dict[str, Any]
    hparams: Dict[str, Any]
    summary: Dict[str, Any]

    # --- user annotations
    tags: List[str]
    notes: str
class Registry:
    """Read-only query interface over the registry cache.

    Instantiate via :func:`open_registry` rather than directly — the factory
    runs a lazy filesystem scan first so you don't query stale data.

    A registry can be used as a context manager; the underlying store is
    closed when the ``with`` block exits.
    """

    def __init__(self, store: Store) -> None:
        self._store = store

    # ------------------------------------------------------------------ queries

    def query(
        self,
        *,
        tag: Optional[str] = None,
        status: Optional[str] = None,
        alive: Optional[bool] = None,
        hparams: Optional[Dict[str, Any]] = None,
        sort_by: Optional[str] = None,
        descending: bool = True,
        limit: Optional[int] = None,
    ) -> List[RunRecord]:
        """Query runs matching filters.

        Args:
            tag: Include runs that carry this tag (uses substring match on
                the stored JSON tag array).
            status: Filter by ``status`` column (``running``, ``completed``,
                ``failed``, ``orphaned``, ``interrupted``).
            alive: Filter by heartbeat-based liveness.
            hparams: ``{key: value}`` pairs the flattened hparams must match
                (AND, client-side).
            sort_by: Column name or ``summary.<k>`` / ``hparams.<k>`` /
                ``config.<k>``.
            descending: Sort order.
            limit: Max rows. When ``hparams`` is given the limit is applied
                after the hparams filter, so up to ``limit`` *matching* runs
                are returned.

        Returns:
            The matching runs, in the order produced by the store.
        """
        # The hparams filter runs client-side. Pushing ``limit`` down to the
        # store would truncate the candidate set *before* filtering and could
        # drop matching runs, so the limit is withheld from the store whenever
        # an hparams filter is active and applied to the filtered list instead.
        store_limit = None if hparams else limit
        rows = self._store.query_runs(
            tag=tag,
            status=status,
            alive=alive,
            sort_by=sort_by,
            descending=descending,
            limit=store_limit,
        )
        records = [_row_to_record(r) for r in rows]
        if hparams:
            records = [
                r
                for r in records
                if all(r.hparams.get(k) == v for k, v in hparams.items())
            ]
            # A negative slice bound would silently drop rows from the end,
            # so only non-negative limits are applied here.
            if limit is not None and limit >= 0:
                records = records[:limit]
        return records

    def get(self, run_id: str) -> Optional[RunRecord]:
        """Return the run with ``run_id``, or ``None`` if the store has no such row."""
        row = self._store.get_run(run_id)
        return _row_to_record(row) if row else None

    def to_dataframe(self, **query_kwargs: Any):
        """Return a DataFrame with flattened ``hparams.*`` / ``summary.*`` cols.

        Args:
            **query_kwargs: Forwarded unchanged to :meth:`query`.

        Returns:
            A ``pandas.DataFrame`` with one row per matching run, or an empty
            DataFrame (no columns) when nothing matches.
        """
        # pandas is imported lazily so the registry itself does not require it.
        import pandas as pd

        records = self.query(**query_kwargs)
        if not records:
            return pd.DataFrame()

        rows: list[Dict[str, Any]] = []
        for r in records:
            row: Dict[str, Any] = {
                "run_id": r.run_id,
                "status": r.status,
                "alive": r.alive,
                "created_at": r.created_at,
                "updated_at": r.updated_at,
                "run_dir": r.run_dir,
                "checkpoint_path": r.checkpoint_path,
                "tags": r.tags,
                "notes": r.notes,
            }
            # Prefix nested keys so they cannot collide with the base columns.
            row.update({f"hparams.{k}": v for k, v in (r.hparams or {}).items()})
            row.update({f"summary.{k}": v for k, v in (r.summary or {}).items()})
            rows.append(row)
        return pd.DataFrame(rows)

    # ------------------------------------------------------------------ dunder

    def __len__(self) -> int:
        """Return the run count reported by the store."""
        return self._store.count()

    def __getitem__(self, run_id: str) -> RunRecord:
        """Return the run with ``run_id``; raise ``KeyError`` if it is absent."""
        rec = self.get(run_id)
        if rec is None:
            raise KeyError(run_id)
        return rec

    def __repr__(self) -> str:
        return f"Registry(db_path={self._store.db_path!r}, runs={len(self)})"

    def __enter__(self) -> Registry:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Returns None so an exception raised inside the block propagates.
        self.close()

    def close(self) -> None:
        """Close the underlying store."""
        self._store.close()
# --------------------------------------------------------------------- factory
def open_registry(
    db_path: Optional[Union[str, Path]] = None,
    *,
    cache_dir: Optional[Union[str, Path]] = None,
    scan: bool = True,
    scan_ttl_s: float = 2.0,
) -> Registry:
    """Open the registry for querying.

    Args:
        db_path: Location of the cache DB. When omitted it is
            ``{cache_dir}/registry.db``.
        cache_dir: Root under which runs live (``{cache_dir}/runs/...``).
            When omitted, the value set through ``spt.set(cache_dir=...)``
            is used.
        scan: Whether to run an incremental scan first so the cache mirrors
            the filesystem. Turn it off when a scan has just been done.
        scan_ttl_s: Forwarded to the scanner as its TTL: a scan performed
            within this many seconds in the current process is not repeated.

    Returns:
        A read-only :class:`Registry`.
    """
    cache_root, db_file = _resolve_paths(cache_dir, db_path)

    if scan:
        # Refresh the cache before the store is opened for reading.
        _scanner.scan_for_query(cache_root, db_file, ttl_s=scan_ttl_s)

    read_only_store = Store(db_file, readonly=True)
    return Registry(read_only_store)
def _resolve_paths(
    cache_dir: Optional[Union[str, Path]],
    db_path: Optional[Union[str, Path]],
) -> tuple[Path, Path]:
    """Resolve ``(cache_dir, db_path)`` from args + global config.

    Args:
        cache_dir: Explicit cache root, or ``None`` to fall back to the
            globally configured ``cache_dir``.
        db_path: Explicit DB path, or ``None`` to use
            ``{cache_dir}/registry.db``.

    Returns:
        ``(cache_dir, db_path)`` as absolute, user-expanded paths.

    Raises:
        ValueError: If ``cache_dir`` is not given and no global ``cache_dir``
            can be obtained from the config.
    """
    if cache_dir is None:
        # The global config is only a fallback for the cache root, so it is
        # consulted only when the caller did not pass one explicitly.
        try:
            from .._config import get_config

            cache_dir = get_config().cache_dir
        except Exception:
            # Best-effort lookup: any failure to load the config is treated
            # as "not configured" and reported through the ValueError below.
            cache_dir = None
        if cache_dir is None:
            raise ValueError(
                "No cache_dir provided and spt.set(cache_dir=...) is not "
                "configured. Pass an explicit cache_dir or set it globally."
            )

    cache_dir = Path(cache_dir).expanduser().resolve()
    if db_path is None:
        db_path = cache_dir / "registry.db"
    else:
        db_path = Path(db_path).expanduser().resolve()
    return cache_dir, db_path


def _row_to_record(d: Dict[str, Any]) -> RunRecord:
    """Build a :class:`RunRecord` from a store row mapping.

    ``run_id`` is required. Every other key may be missing or hold a falsy
    value (e.g. ``None``), in which case a type-appropriate default is used.
    """
    return RunRecord(
        run_id=d["run_id"],
        # ``or`` (not a ``get`` default) so an explicit ``None`` is normalised
        # too, consistently with the remaining fields.
        status=d.get("status") or "unknown",
        created_at=float(d.get("created_at") or 0.0),
        updated_at=float(d.get("updated_at") or 0.0),
        alive=bool(d.get("alive")),
        run_dir=d.get("run_dir"),
        checkpoint_path=d.get("checkpoint_path"),
        config=d.get("config") or {},
        hparams=d.get("hparams") or {},
        summary=d.get("summary") or {},
        tags=d.get("tags") or [],
        notes=d.get("notes") or "",
    )