# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Lightning logger for the filesystem-backed run registry.
:class:`RegistryLogger` is a thin subclass of Lightning's
:class:`~lightning.pytorch.loggers.CSVLogger`. It writes the standard
CSV + hparams artifacts **and** an indexable ``sidecar.json``, a
fast-readable ``summary.json`` (per-metric stats), and a ``heartbeat``
file in the run directory.
Nothing in the training path touches SQLite or a network server:
* ``log_hyperparams`` → CSV hparams + sidecar snapshot.
* ``log_metrics`` → CSV metrics row + per-metric stats accumulator
(last / min / max / count) + heartbeat touch.
* ``save`` → CSV flush + sidecar rewrite + summary.json
rewrite (both atomic). Lightning calls this at
epoch boundaries and on the flush cadence.
* ``finalize`` → terminal status in sidecar + final summary flush.
* ``after_save_checkpoint`` → ``checkpoint_path`` in sidecar.
A separate scanner (see :mod:`stable_pretraining.registry._scanner`)
turns sidecars into a fast-queryable SQLite cache. Deleting that cache
is harmless — rerun ``spt registry scan --full`` to rebuild.
"""
from __future__ import annotations

import csv as _csv
import json
import math
import os
import shutil
import time
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional, Union

from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.loggers.csv_logs import ExperimentWriter as _LightningCSVWriter
from lightning.pytorch.utilities.rank_zero import rank_zero_only
from loguru import logger as logging

from . import _sidecar
class _AppendingExperimentWriter(_LightningCSVWriter):
    """CSV writer that preserves any existing ``metrics.csv`` on init.

    Lightning's stock ``ExperimentWriter`` deletes the file in
    ``_check_log_dir_exists`` whenever the log dir is non-empty. That makes
    SLURM-preempt-and-requeue runs lose all prior training history because
    the resumed process re-creates the writer, which truncates the existing
    file before the first append.

    This subclass skips the deletion. To keep the parent's ``new_keys``
    detection from rewriting the file on the first save of a fresh process,
    ``metrics_keys`` is also bootstrapped from the existing CSV header — so
    the parent only triggers a header rewrite when the schema *actually*
    changes (e.g. a brand-new metric appears mid-run), not just because its
    in-memory ``metrics_keys`` is empty after a restart.
    """

    def _check_log_dir_exists(self) -> None:  # type: ignore[override]
        # Intentional no-op: do not delete a prior metrics.csv on resume.
        return

    def __init__(self, log_dir: str) -> None:
        super().__init__(log_dir=log_dir)
        # Bootstrap metrics_keys from the existing header (if any) so the
        # parent's ``new_keys = current_keys - metrics_keys`` check doesn't
        # mistake a fresh process start for a schema change.
        try:
            if self._fs.isfile(self.metrics_file_path):
                with self._fs.open(self.metrics_file_path, "r", newline="") as f:
                    header = next(_csv.reader(f), None)
                if header:
                    self.metrics_keys = list(header)
        except Exception as e:
            # Best-effort: a failed bootstrap must not block training, but it
            # should be visible — without the bootstrapped header the first
            # save() goes through the parent's header-rewrite path instead of
            # a plain append.
            logging.warning(
                f"[RegistryLogger] could not read existing header of "
                f"{self.metrics_file_path}: {e!r}"
            )

    def _record_new_keys(self) -> set[str]:  # type: ignore[override]
        """Append newly-seen keys to ``metrics_keys`` *without re-sorting*.

        Lightning's parent calls ``self.metrics_keys.sort()`` after each
        update, which silently reorders columns relative to the existing
        on-disk CSV header. When a resumed process appends rows in sorted
        order while the file's header keeps insertion order, the column
        values get scrambled. Preserving insertion order keeps appended rows
        aligned with the original header.

        Returns:
            The keys present in the buffered ``metrics`` rows that were not
            in ``metrics_keys`` before this call.
        """
        current_keys = set().union(*self.metrics)
        new_keys = current_keys - set(self.metrics_keys)
        # New keys are sorted among themselves so two appends in the same
        # process are deterministic, but the existing prefix is never touched.
        self.metrics_keys.extend(sorted(new_keys))
        return new_keys
class RegistryLogger(CSVLogger):
    """CSV logger with a filesystem-indexable sidecar.

    The sidecar is an atomically-rewritten JSON file that captures the
    run's hparams, latest metric values (``summary``), status, and
    checkpoint path. It is the source of truth for the registry
    scanner. A separate ``summary.json`` carries per-metric
    last / min / max / count stats.

    Args:
        run_dir: Directory this run writes to (created if missing). CSV
            logs, ``sidecar.json``, ``summary.json`` and ``heartbeat`` all
            live here.
        run_id: Unique identifier for this run (typically the SLURM job
            id or a deterministic hash). Used as the primary key in
            the registry cache.
        tags: Free-form string tags for grouping runs (e.g. model
            architecture, experiment name, sweep id). Any
            ``SLURM_ARRAY_JOB_ID`` env var is auto-appended as
            ``"sweep:<id>"`` for array-job convenience.
        notes: Optional free-text description.
        flush_logs_every_n_steps: How often the CSV is flushed; the
            sidecar and ``summary.json`` are rewritten on the same cadence
            (both piggyback on :meth:`save`). The heartbeat is touched on
            every ``log_metrics`` call (cheap).
    """

    def __init__(
        self,
        run_dir: Union[str, Path],
        run_id: str,
        *,
        tags: Optional[list[str]] = None,
        notes: Optional[str] = None,
        flush_logs_every_n_steps: int = 50,
    ) -> None:
        run_dir = Path(run_dir).expanduser().resolve()
        run_dir.mkdir(parents=True, exist_ok=True)
        # save_dir + name="" + version="" ⇒ CSVLogger.log_dir == run_dir.
        # Matches the existing Manager-auto-CSV layout.
        super().__init__(
            save_dir=str(run_dir),
            name="",
            version="",
            flush_logs_every_n_steps=flush_logs_every_n_steps,
        )
        self._run_dir = run_dir
        self._run_id = str(run_id)
        self._tags: list[str] = list(tags or [])
        array_job = os.environ.get("SLURM_ARRAY_JOB_ID")
        if array_job and f"sweep:{array_job}" not in self._tags:
            self._tags.append(f"sweep:{array_job}")
        self._notes = notes or ""
        self._hparams: dict[str, Any] = {}
        # Last value per metric key, written into the sidecar (kept for
        # backward compatibility with the scanner / SQLite cache).
        self._summary: dict[str, Any] = {}
        # Per-metric extended stats for summary.json: each entry is a dict
        # with last / min / max / count. Updated incrementally on every
        # log_metrics.
        self._metric_stats: dict[str, dict[str, Any]] = {}
        self._checkpoint_path: Optional[str] = None
        self._status = "running"
        # Last step / epoch seen via log_metrics — step is used to attach a
        # step to media events when log_image / log_video doesn't supply one
        # (matches Lightning's WandbLogger behaviour). Epoch is read from
        # the metrics dict (Lightning auto-injects an "epoch" key) and is
        # surfaced as a top-level field in summary.json.
        self._last_step: int = 0
        self._last_epoch: int = 0
        # Replace Lightning's truncate-on-init writer with our appending one
        # so SLURM preempt/requeue cycles don't erase prior training history.
        self._experiment = _AppendingExperimentWriter(log_dir=str(run_dir))
        # First-write timestamp, preserved across sidecar rewrites so the
        # registry can order runs chronologically regardless of how often we
        # flush. Recovered from an existing sidecar so a requeued process
        # writing into the same run_dir keeps the original creation time.
        self._created_at: Optional[float] = self._read_prior_created_at()
        # First-write flag for summary.json — used to log a one-shot info
        # line on creation, then debug lines on subsequent rewrites so we
        # don't spam every flush.
        self._summary_written: bool = False
        logging.info(
            f"[RegistryLogger] run_id={self._run_id} "
            f"run_dir={self._run_dir} — sidecar.json + summary.json + "
            "metrics.csv will live here"
        )

    # -- identity ---------------------------------------------------------------

    @property
    def run_id(self) -> str:
        """Identifier of this run (the constructor's ``run_id``, as ``str``)."""
        return self._run_id

    @property
    def run_dir(self) -> Path:
        """Resolved absolute directory this run writes to."""
        return self._run_dir

    # -- lightning hooks --------------------------------------------------------

    @rank_zero_only
    def log_hyperparams(
        self, params: Union[dict[str, Any], Any], *args: Any, **kw: Any
    ) -> None:
        """Persist hparams to ``hparams.yaml`` and snapshot them into the sidecar.

        Repeated calls *merge* into the hparams recorded so far (later keys
        win) instead of replacing them, mirroring how Lightning's CSV writer
        updates its own hparams dict, so ``sidecar.json`` and
        ``hparams.yaml`` stay in agreement.
        """
        # Persist to CSVLogger's hparams.yaml.
        super().log_hyperparams(params, *args, **kw)
        self._hparams.update(_flatten_params(params))
        self._write_sidecar()

    @rank_zero_only
    def log_metrics(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
        """Write a CSV row, update the in-memory summaries and touch the heartbeat.

        Non-scalar values (see :func:`_to_scalar`) still go to the CSV but
        are skipped by the sidecar summary and ``summary.json`` stats.
        """
        # CSV-side: write the raw per-step row.
        super().log_metrics(metrics, step)
        if step is not None:
            self._last_step = int(step)
        # Lightning auto-injects "epoch" into the metrics dict; surface it
        # as a top-level field in summary.json. Non-finite values are
        # ignored because int() would raise on them.
        if "epoch" in metrics:
            ep = _to_scalar(metrics["epoch"])
            if ep is not None and math.isfinite(ep):
                self._last_epoch = int(ep)
        for k, v in metrics.items():
            scalar = _to_scalar(v)
            if scalar is None:
                continue
            # Sidecar-side: last-value-per-key summary.
            self._summary[k] = scalar
            # summary.json-side: extended stats with last/min/max/count.
            self._update_metric_stats(k, scalar)
        # Heartbeat: cheap, fire-and-forget; used by the scanner to
        # distinguish running / stalled / dead without contacting SLURM.
        # A failed touch must not abort the training step.
        try:
            _sidecar.touch_heartbeat(self._run_dir)
        except OSError as e:
            logging.debug(f"[RegistryLogger] heartbeat touch failed: {e!r}")

    def _update_metric_stats(self, key: str, value: float) -> None:
        """Fold ``value`` into the last / min / max / count stats of ``key``.

        NaN values update ``last`` and ``count`` but never min / max, and a
        NaN min / max (only possible while every value seen so far was NaN)
        is replaced by the first real value — otherwise a NaN first sample
        would pin min / max to NaN for the rest of the run.
        """
        stats = self._metric_stats.get(key)
        if stats is None:
            self._metric_stats[key] = {
                "last": value,
                "min": value,
                "max": value,
                "count": 1,
            }
            return
        stats["last"] = value
        stats["count"] += 1
        if math.isnan(value):
            return
        if math.isnan(stats["min"]) or value < stats["min"]:
            stats["min"] = value
        if math.isnan(stats["max"]) or value > stats["max"]:
            stats["max"] = value

    @rank_zero_only
    def save(self) -> None:
        """Flush the CSV, then rewrite ``sidecar.json`` and ``summary.json``."""
        super().save()
        self._write_sidecar()
        # Lightning calls save() at flush cadence and at epoch boundaries —
        # piggyback the summary flush on the same path so summary.json is
        # always fresh after each epoch.
        self._write_summary_safe()

    @rank_zero_only
    def finalize(self, status: str) -> None:
        """Record the terminal run status and flush every artifact.

        ``"success"`` is stored as ``"completed"``; any other status string
        is stored unchanged.
        """
        # Map Lightning status strings to our canonical vocabulary.
        self._status = {"success": "completed", "failed": "failed"}.get(status, status)
        # The status is set *before* delegating so any flush the parent
        # performs through save() already records the terminal status.
        super().finalize(status)
        # Write explicitly as well: the parent may no-op (e.g. when it has no
        # experiment writer), and the terminal status must still land on disk.
        self._write_sidecar()
        self._write_summary_safe()

    def after_save_checkpoint(self, checkpoint_callback: Any) -> None:
        """Record the latest checkpoint path in the sidecar.

        Prefers ``best_model_path`` and falls back to ``last_model_path``.
        This hook is not rank-zero gated itself (it may fire on every rank);
        only the sidecar write it triggers is rank-zero-only.
        """
        path = getattr(checkpoint_callback, "best_model_path", None) or getattr(
            checkpoint_callback, "last_model_path", None
        )
        if path:
            self._checkpoint_path = str(path)
            self._write_sidecar_safe()

    # -- media (images / videos) -----------------------------------------------

    @rank_zero_only
    def log_image(
        self,
        key: str,
        images: list,
        step: Optional[int] = None,
        caption: Optional[list] = None,
        **_: Any,
    ) -> None:
        """Save images under ``{run_dir}/media/<safe_tag>/<step>_<i>.png``.

        Compatible with Lightning's :class:`WandbLogger.log_image` signature,
        so existing callbacks that gate on ``hasattr(logger, "log_image")``
        will start writing media to disk without code changes.

        Accepts numpy arrays (HWC or CHW, uint8 or float[0,1]), PIL images,
        torch tensors, or paths to existing files. Each successfully saved
        entry is also appended to ``media.jsonl`` so the registry / web
        viewer can index events without walking the filesystem. A failure to
        save one entry is logged and skipped.
        """
        s = self._resolve_step(step)
        media_dir = self._media_dir(key)
        media_dir.mkdir(parents=True, exist_ok=True)
        cap = list(caption) if caption else []
        for i, img in enumerate(images):
            dst = media_dir / f"{s:08d}_{i}.png"
            try:
                _save_image_to(img, dst)
            except Exception as e:
                # Don't kill training on a media-save error.
                logging.warning(
                    f"[RegistryLogger.log_image] failed to save {key}[{i}]: {e}"
                )
                continue
            self._append_media_event(
                {
                    "step": s,
                    "tag": key,
                    "type": "image",
                    "path": str(dst.relative_to(self._run_dir)),
                    "caption": cap[i] if i < len(cap) else None,
                }
            )

    @rank_zero_only
    def log_video(
        self,
        key: str,
        videos: list,
        step: Optional[int] = None,
        caption: Optional[list] = None,
        fps: Optional[int] = None,
        format: Optional[str] = None,
        **_: Any,
    ) -> None:
        """Save videos under ``{run_dir}/media/<safe_tag>/<step>_<i>.<ext>``.

        Inputs may be filesystem paths to already-encoded files (preferred —
        zero re-encoding cost) or raw bytes. The extension comes from the
        source path when it has one, else from ``format`` (default
        ``"mp4"``). ``fps`` and that extension are recorded in
        ``media.jsonl`` so a viewer can play them back at the right rate.
        """
        s = self._resolve_step(step)
        media_dir = self._media_dir(key)
        media_dir.mkdir(parents=True, exist_ok=True)
        cap = list(caption) if caption else []
        default_ext = (format or "mp4").lstrip(".")
        for i, vid in enumerate(videos):
            ext = default_ext
            if isinstance(vid, (str, Path)):
                src_ext = Path(vid).suffix.lstrip(".")
                if src_ext:
                    ext = src_ext
            dst = media_dir / f"{s:08d}_{i}.{ext}"
            try:
                _save_video_to(vid, dst)
            except Exception as e:
                # Don't kill training on a media-save error.
                logging.warning(
                    f"[RegistryLogger.log_video] failed to save {key}[{i}]: {e}"
                )
                continue
            self._append_media_event(
                {
                    "step": s,
                    "tag": key,
                    "type": "video",
                    "path": str(dst.relative_to(self._run_dir)),
                    "caption": cap[i] if i < len(cap) else None,
                    "fps": fps,
                    "format": ext,
                }
            )

    def _resolve_step(self, step: Optional[int]) -> int:
        """Return ``step`` as an int, or the last step seen by log_metrics."""
        if step is not None:
            return int(step)
        return int(self._last_step)

    def _media_dir(self, key: str) -> Path:
        """Map a media tag to its directory under ``{run_dir}/media``."""
        # Replace path separators so the tag becomes a single directory name.
        safe = key.replace("/", "__").replace("\\", "__")
        # "", "." and ".." would resolve to the media root or escape it.
        if safe in ("", ".", ".."):
            safe = safe.replace(".", "_") or "_"
        return self._run_dir / "media" / safe

    @rank_zero_only
    def _append_media_event(self, event: dict[str, Any]) -> None:
        """Append a JSONL line to ``{run_dir}/media.jsonl``.

        JSONL (one event per line) keeps writes O(1) — no read-merge-write —
        and is robust to crashes (a partially-written line is just discarded
        on the next read). Write failures are logged, not raised.
        """
        manifest = self._run_dir / "media.jsonl"
        try:
            with manifest.open("a", encoding="utf-8") as f:
                f.write(json.dumps(event) + "\n")
        except OSError as e:
            logging.warning(f"[RegistryLogger] media.jsonl write failed: {e}")

    # -- sidecar ----------------------------------------------------------------

    def _read_prior_created_at(self) -> Optional[float]:
        """Return ``created_at`` from an existing ``{run_dir}/sidecar.json``.

        Best-effort: a missing, unreadable or malformed sidecar, or a
        non-numeric ``created_at``, yields ``None`` so the next write stamps
        a fresh timestamp.
        """
        # NOTE(review): assumes _sidecar.write_sidecar stores the file as
        # run_dir/"sidecar.json", as the module and class docs state.
        path = self._run_dir / "sidecar.json"
        try:
            with path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            return None
        value = data.get("created_at") if isinstance(data, dict) else None
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return float(value)
        return None

    @rank_zero_only
    def _write_sidecar(self) -> None:
        """Atomically (re)write the sidecar. Let exceptions propagate."""
        data = _sidecar.make_sidecar(
            run_id=self._run_id,
            run_dir=str(self._run_dir),
            status=self._status,
            created_at=self._created_at,
            hparams=self._hparams,
            summary=self._summary,
            tags=self._tags,
            notes=self._notes,
            checkpoint_path=self._checkpoint_path,
        )
        _sidecar.write_sidecar(self._run_dir, data)
        if self._created_at is None:
            self._created_at = data["created_at"]

    @rank_zero_only
    def _write_sidecar_safe(self) -> None:
        """Same as :meth:`_write_sidecar` but logs and swallows I/O errors.

        Used from callback hooks where a failed write should never take
        down a training run.
        """
        try:
            self._write_sidecar()
        except OSError as e:
            logging.warning(f"[RegistryLogger] sidecar.json write failed: {e!r}")

    # -- summary.json -----------------------------------------------------------

    @rank_zero_only
    def _write_summary(self) -> None:
        """Atomically (re)write ``summary.json`` with per-metric stats.

        Format is intentionally tight (a flat ``{metric: stats_dict}`` map)
        so a downstream reader can parse it with a single ``json.load`` and
        reach any metric's last/min/max in O(1).
        """
        data = {
            "schema_version": 1,
            "run_id": self._run_id,
            "run_dir": str(self._run_dir),
            "updated_at": time.time(),
            "step": self._last_step,
            "epoch": self._last_epoch,
            "metrics": dict(self._metric_stats),
        }
        path = self._run_dir / "summary.json"
        _sidecar.atomic_json_write(path, data)
        msg = (
            f"[RegistryLogger] summary.json flushed "
            f"({len(self._metric_stats)} metrics, step={self._last_step}, "
            f"epoch={self._last_epoch}) → {path}"
        )
        if not self._summary_written:
            logging.info(msg)
            self._summary_written = True
        else:
            logging.debug(msg)

    @rank_zero_only
    def _write_summary_safe(self) -> None:
        """Same as :meth:`_write_summary` but logs and swallows I/O errors.

        ``summary.json`` is auxiliary — a failed write must never take down
        a training run.
        """
        try:
            self._write_summary()
        except OSError as e:
            logging.warning(f"[RegistryLogger] summary.json write failed: {e!r}")
# --------------------------------------------------------------------- helpers
def _save_image_to(img: Any, path: Path) -> None:
    """Persist an image-like value to ``path``.

    Accepted inputs:

    * a file path (``str`` / ``Path``) — copied byte-for-byte (streamed, no
      re-encoding), so the file keeps whatever format the source had;
    * a ``PIL.Image`` (duck-typed via ``save`` / ``mode`` / ``size``) —
      saved as PNG;
    * a ``numpy.ndarray`` in HWC or CHW layout — saved as PNG. ``uint8`` is
      used as-is, other integer dtypes are clipped to ``[0, 255]``, booleans
      map to 0 / 255, and floating arrays are assumed to be in ``[0, 1]``
      (NaN becomes 0);
    * a ``torch.Tensor`` — detached, moved to CPU and converted to numpy
      (bfloat16 is cast to float32 first since numpy has no bfloat16), then
      handled as above.

    Raises:
        RuntimeError: if numpy is not importable for a non-path, non-PIL input.
        TypeError: if the input is none of the above.
    """
    if isinstance(img, (str, Path)):
        # Streamed copy: avoids reading the whole source into memory.
        shutil.copyfile(img, path)
        return
    # PIL.Image — duck-type to avoid a hard dependency on PIL at import time.
    if hasattr(img, "save") and hasattr(img, "mode") and hasattr(img, "size"):
        img.save(path, format="PNG")
        return
    # torch.Tensor → numpy
    try:  # pragma: no cover — torch is optional at logger-import time
        import torch  # type: ignore
    except ImportError:
        pass
    else:
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()
            if img.dtype == torch.bfloat16:
                # Tensor.numpy() cannot represent bfloat16.
                img = img.float()
            img = img.numpy()
    try:
        import numpy as np  # type: ignore
    except ImportError as e:
        raise RuntimeError(
            "log_image requires numpy or PIL to save non-path inputs"
        ) from e
    if not isinstance(img, np.ndarray):
        raise TypeError(f"unsupported image type for log_image: {type(img)}")
    arr = img
    if arr.dtype == np.bool_:
        arr = arr.astype(np.uint8) * np.uint8(255)
    elif np.issubdtype(arr.dtype, np.integer):
        if arr.dtype != np.uint8:
            # Integer images are taken to be on the 0..255 scale; treating
            # them as [0, 1] floats would saturate every non-zero pixel.
            arr = np.clip(arr, 0, 255).astype(np.uint8)
    else:
        # Floats in [0, 1]; NaN has no defined uint8 cast, so map it to 0.
        arr = (np.clip(np.nan_to_num(arr, nan=0.0), 0.0, 1.0) * 255).astype(np.uint8)
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        # CHW → HWC heuristic (only flips when the last axis isn't already a channel count).
        arr = np.transpose(arr, (1, 2, 0))
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr.squeeze(-1)
    from PIL import Image  # imported lazily so logger import doesn't drag PIL.

    Image.fromarray(arr).save(path, format="PNG")
def _save_video_to(vid: Any, path: Path) -> None:
    """Persist an already-encoded video to ``path`` without re-encoding.

    Args:
        vid: Either a filesystem path (``str`` / ``Path``) to an encoded file,
            which is copied in a streamed fashion so large videos are never
            loaded fully into memory, or the encoded bytes themselves
            (``bytes``, ``bytearray`` or ``memoryview``).
        path: Destination file.

    Raises:
        TypeError: for any other input type. Callers that build frames in
            memory should encode them first (e.g. via imageio / opencv) and
            pass the resulting path.
    """
    if isinstance(vid, (str, Path)):
        shutil.copyfile(vid, path)
        return
    if isinstance(vid, (bytes, bytearray, memoryview)):
        Path(path).write_bytes(bytes(vid))
        return
    raise TypeError(
        f"unsupported video type for log_video: {type(vid)} (pass a path or raw bytes)"
    )
def _flatten_params(params: Any) -> dict[str, Any]:
    """Flatten a (possibly nested) hparams object to a flat JSON-safe dict.

    Accepts ``DictConfig``, any mapping, ``Namespace`` / other objects with
    a ``__dict__``, or anything else (stored as ``{"params": str(params)}``).
    ``None`` means "no hparams" and yields ``{}``. Nested keys are joined
    with ``.``; list and tuple items use their index (``"layers.0"``).
    Non-serializable leaves are stringified so the sidecar stays
    round-trippable.
    """
    if params is None:
        return {}
    try:
        from omegaconf import DictConfig, OmegaConf
    except ImportError:
        pass
    else:
        if isinstance(params, DictConfig):
            try:
                params = OmegaConf.to_container(params, resolve=True)
            except Exception:
                # An unresolvable interpolation must not crash hparam
                # logging; fall back to the raw, unresolved values.
                params = OmegaConf.to_container(params, resolve=False)
    if isinstance(params, Mapping):
        params = dict(params)
    elif hasattr(params, "__dict__"):
        params = vars(params)
    else:
        params = {"params": str(params)}
    out: dict[str, Any] = {}
    _flatten(params, "", out)
    return out


def _flatten(obj: Any, prefix: str, out: dict[str, Any]) -> None:
    """Recursively write the leaves of ``obj`` into ``out`` under dotted keys.

    ``prefix`` is the dotted path accumulated so far, including its trailing
    ``.`` (empty at the top level). Empty nested containers are recorded as
    ``{}`` / ``[]`` so an hparam such as ``tags=[]`` does not silently
    disappear from the sidecar.
    """
    if isinstance(obj, Mapping):
        if not obj and prefix:
            out[prefix.rstrip(".")] = {}
            return
        for k, v in obj.items():
            _flatten(v, f"{prefix}{k}.", out)
    elif isinstance(obj, (list, tuple)):
        if not obj and prefix:
            out[prefix.rstrip(".")] = []
            return
        for i, v in enumerate(obj):
            _flatten(v, f"{prefix}{i}.", out)
    else:
        key = prefix.rstrip(".")
        try:
            json.dumps(obj)
            out[key] = obj
        except (TypeError, ValueError):
            out[key] = str(obj)
def _to_scalar(v: Any) -> Optional[float]:
    """Coerce a metric value to a Python ``float``, or ``None`` if not scalar.

    ``bool``, ``int`` and ``float`` are converted directly. Any other object
    with a callable ``item()`` (torch Tensors, numpy scalars / arrays) is
    converted when it holds exactly one element — checked via ``numel()``
    when that method exists, otherwise left to ``item()`` to reject.
    Anything else (strings, multi-element tensors, values too large to
    represent as a float) yields ``None``: the summary is kept numeric so
    downstream tools can always ``float()`` it.
    """
    # Common path: plain bool / int / float (bool is an int subclass).
    if isinstance(v, (int, float)):
        try:
            return float(v)
        except OverflowError:
            # An int beyond float range must not crash log_metrics.
            return None
    item = getattr(v, "item", None)
    if not callable(item):
        return None
    try:
        numel = getattr(v, "numel", None)
        if callable(numel) and numel() != 1:
            return None
        return float(item())
    except (RuntimeError, ValueError, TypeError, OverflowError):
        return None