"""Stdlib HTTP server for the spt-web viewer.
Routes:
GET / → assets/index.html
GET /assets/<file> → static asset
GET /api/runs → JSON list of runs (sidecar summaries)
GET /api/scan-status → JSON {phase, total, done, initial_done}
GET /api/metrics?run_id=… → JSON sparse metrics (warm-cache friendly)
GET /api/metrics-stream?run_id=… → NDJSON chunks streamed during CSV parse
GET /api/logs?run_id=… → JSON list of available .out/.err streams
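GET /api/log-content?run_id=…&stream_id=… → text/plain (last ~4 MiB)
GET /api/media?run_id=… → JSON media data for a run
GET /api/media-file?run_id=…&path=… → raw bytes of one media file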
GET /api/stream → Server-Sent Events stream of update deltas
ThreadingHTTPServer spawns a thread per request; SSE handlers hold a
thread for the connection lifetime, which is fine for a local viewer.
"""
from __future__ import annotations
import json
import math
import mimetypes
import queue
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse
from .scan import RunScanner
def _sanitize_for_json(obj: Any) -> Any:
    """Return a copy of *obj* in which every non-finite float is ``None``.

    ``json.dumps`` writes ``NaN`` / ``Infinity`` as bare tokens, and a
    browser's ``JSON.parse`` rejects those (``Unexpected token 'N'``).
    Training metrics routinely carry such values (exploding losses,
    deliberate ``inf`` bounds, ...), so they are mapped to ``None``, which
    serialises to ``null``.

    Dicts are rebuilt with their keys untouched and their values sanitised;
    lists and tuples both come back as lists; anything else is returned as-is.
    """
    # Containers first: recurse into every element.
    if isinstance(obj, dict):
        return {key: _sanitize_for_json(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(_sanitize_for_json, obj))
    # Leaves: only NaN / +Inf / -Inf floats are replaced.
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    return obj
def _safe_dumps(obj: Any) -> str:
    """Serialise *obj* to JSON text after non-finite floats are nulled out."""
    cleaned = _sanitize_for_json(obj)
    return json.dumps(cleaned)
# Absolute, fully resolved location of the ``assets`` directory that sits next
# to this module; static-asset requests are confined to this directory.
ASSETS_DIR = Path(__file__).parent.joinpath("assets").resolve()
class _Handler(BaseHTTPRequestHandler):
    """GET-only request handler that maps the viewer's routes onto a scanner.

    ``scanner`` is a class attribute; it must be assigned (on this class or a
    subclass) before the first request is handled, because every ``/api/*``
    route reads from it.
    """

    server_version = "spt-web/0.1"
    scanner: RunScanner = None  # type: ignore[assignment]

    # Read size used when streaming media files to the client.
    _FILE_CHUNK_SIZE = 64 * 1024
    # Seconds of SSE silence after which a keep-alive comment is sent.
    _SSE_PING_INTERVAL = 15.0

    def log_message(self, format: str, *args: Any) -> None:  # noqa: A002
        # Default logging is too noisy with SSE pings.
        pass

    def do_GET(self) -> None:  # noqa: N802
        """Parse the request URL and dispatch it to the matching route.

        A client that disconnects mid-response (broken pipe / connection
        reset) is ignored silently.
        """
        try:
            url = urlparse(self.path)
            self._route(url.path, parse_qs(url.query))
        except (BrokenPipeError, ConnectionResetError):
            pass

    # ---- routing ----

    def _route(self, path: str, qs: dict[str, list[str]]) -> None:
        """Send the response for *path*; unknown paths get a 404."""
        if path in ("/", "/index.html"):
            self._serve_asset("index.html", "text/html; charset=utf-8")
        elif path.startswith("/assets/"):
            self._serve_asset(path[len("/assets/"):])
        elif path == "/api/runs":
            self._serve_json(self.scanner.runs_json())
        elif path == "/api/scan-status":
            self._serve_json(self.scanner.progress_json())
        elif path == "/api/metrics":
            self._serve_metrics(qs)
        elif path == "/api/metrics-stream":
            run_id = self._param(qs, "run_id")
            if not run_id:
                self._serve_json({"error": "missing run_id"}, 400)
            else:
                self._serve_metrics_stream(run_id)
        elif path == "/api/media":
            self._serve_run_json(qs, self.scanner.media_json)
        elif path == "/api/logs":
            self._serve_run_json(qs, self.scanner.logs_index)
        elif path == "/api/log-content":
            self._serve_log_content(qs)
        elif path == "/api/media-file":
            self._serve_media_file(qs)
        elif path == "/api/stream":
            self._serve_sse()
        elif path == "/favicon.ico":
            self.send_response(204)
            self.end_headers()
        else:
            self.send_error(404, "Not Found")

    @staticmethod
    def _param(qs: dict[str, list[str]], name: str) -> Optional[str]:
        """Return the first value of query parameter *name*, or ``None``."""
        return qs.get(name, [None])[0]

    def _serve_metrics(self, qs: dict[str, list[str]]) -> None:
        """Serve ``/api/metrics``: the scanner's pre-serialised JSON bytes."""
        run_id = self._param(qs, "run_id")
        if not run_id:
            self._serve_json({"error": "missing run_id"}, 400)
            return
        # Pre-serialised + cached path: avoids re-running json.dumps
        # over a multi-MiB metrics dict on every request.
        body = self.scanner.metrics_json_bytes(run_id)
        if body is None:
            self._serve_json({"error": "run not found"}, 404)
            return
        self._send_bytes(body, "application/json")

    def _serve_run_json(self, qs: dict[str, list[str]], lookup: Any) -> None:
        """Serve a per-run JSON route backed by ``lookup(run_id)``.

        Responds 400 (JSON error) when ``run_id`` is missing and 404 (JSON
        error) when *lookup* returns ``None``.
        """
        run_id = self._param(qs, "run_id")
        if not run_id:
            self._serve_json({"error": "missing run_id"}, 400)
            return
        data = lookup(run_id)
        if data is None:
            self._serve_json({"error": "run not found"}, 404)
            return
        self._serve_json(data)

    def _serve_log_content(self, qs: dict[str, list[str]]) -> None:
        """Serve ``/api/log-content`` as ``text/plain``."""
        run_id = self._param(qs, "run_id")
        stream_id = self._param(qs, "stream_id")
        if not run_id or not stream_id:
            self.send_error(400, "missing run_id or stream_id")
            return
        body = self.scanner.log_content(run_id, stream_id)
        if body is None:
            self.send_error(404, "Not Found")
            return
        self._send_bytes(body, "text/plain; charset=utf-8")

    def _serve_media_file(self, qs: dict[str, list[str]]) -> None:
        """Serve ``/api/media-file``: the file the scanner resolves for *path*."""
        run_id = self._param(qs, "run_id")
        rel = self._param(qs, "path")
        if not run_id or not rel:
            self.send_error(400, "missing run_id or path")
            return
        target = self.scanner.media_file_path(run_id, rel)
        if target is None:
            self.send_error(404, "Not Found")
            return
        self._serve_file(target)

    # ---- helpers ----

    def _start_response(
        self,
        status: int,
        ctype: str,
        length: int,
        cache_control: str = "no-store",
    ) -> None:
        """Write the status line and the standard header set, then end headers."""
        self.send_response(status)
        self.send_header("Content-Type", ctype)
        self.send_header("Content-Length", str(length))
        self.send_header("Cache-Control", cache_control)
        self.end_headers()

    def _send_bytes(self, body: bytes, ctype: str, status: int = 200) -> None:
        """Send *body* as a complete, non-cacheable response."""
        self._start_response(status, ctype, len(body))
        self.wfile.write(body)

    def _serve_asset(self, name: str, ctype: Optional[str] = None) -> None:
        """Serve the static asset *name* from ``ASSETS_DIR``.

        Responds 403 when *name* resolves outside ``ASSETS_DIR`` and 404 when
        it is not a regular file. *ctype* overrides the guessed Content-Type.
        """
        # Path traversal guard. Compare path components, not string prefixes:
        # a prefix test would also accept a sibling such as ``assets_private``.
        try:
            target = (ASSETS_DIR / name).resolve()
            target.relative_to(ASSETS_DIR)
        except ValueError:
            self.send_error(403, "Forbidden")
            return
        if not target.is_file():
            self.send_error(404, "Not Found")
            return
        if ctype is None:
            guessed, _ = mimetypes.guess_type(name)
            ctype = guessed or "application/octet-stream"
        self._send_bytes(target.read_bytes(), ctype)

    def _serve_file(self, target: Path) -> None:
        """Send a file's bytes with a guessed Content-Type and a long cache.

        Media files are content-addressed (path includes the step) so they
        don't change once written — safe to cache aggressively.

        Responds 404 if the file cannot be opened.
        """
        ctype, _ = mimetypes.guess_type(target.name)
        ctype = ctype or "application/octet-stream"
        # Open before sending any header so a vanished/unreadable file becomes
        # a clean 404 rather than a failure after the 200 status line.
        try:
            handle = target.open("rb")
        except OSError:
            self.send_error(404, "Not Found")
            return
        with handle:
            size = handle.seek(0, 2)  # whence=2: seek to end, returns the size
            handle.seek(0)
            self._start_response(
                200, ctype, size, "public, max-age=86400, immutable"
            )
            # Stream in chunks so very large videos don't blow up memory.
            for chunk in iter(lambda: handle.read(self._FILE_CHUNK_SIZE), b""):
                self.wfile.write(chunk)

    def _serve_json(self, obj: Any, status: int = 200) -> None:
        """Serialise *obj* with ``_safe_dumps`` and send it as JSON."""
        self._send_bytes(
            _safe_dumps(obj).encode("utf-8"), "application/json", status
        )

    def _write_chunk(self, line_bytes: bytes) -> None:
        """Write one chunked-transfer frame containing *line_bytes* and flush."""
        # Chunked encoding frame: <hex-size>\r\n<data>\r\n
        frame = b"".join(
            (f"{len(line_bytes):x}\r\n".encode("ascii"), line_bytes, b"\r\n")
        )
        self.wfile.write(frame)
        self.wfile.flush()

    def _serve_metrics_stream(self, run_id: str) -> None:
        """Stream metrics as NDJSON using chunked transfer-encoding.

        One JSON object per line; each object is either a metrics chunk
        (``{"chunk": N, "metrics": {...}}``) or the terminal ``{"done": true}``.
        Browsers consume this with ``fetch().body.getReader()`` and progressively
        merge chunks into the chart, so the user sees the first points within
        a few hundred ms instead of waiting for the whole CSV to parse.

        NOTE(review): this class does not set ``protocol_version``, so the
        status line uses the ``BaseHTTPRequestHandler`` default while the body
        is chunk-framed — confirm every client in use decodes that.
        """
        # We probe the iterator's first value before sending headers so a 404
        # (run not found) can still be returned cleanly.
        gen = self.scanner.metrics_stream(run_id)
        first = next(gen, None)
        if first is None:
            self._serve_json({"error": "run not found"}, 404)
            return
        self.send_response(200)
        self.send_header("Content-Type", "application/x-ndjson; charset=utf-8")
        self.send_header("Cache-Control", "no-store")
        self.send_header("Transfer-Encoding", "chunked")
        # Disable proxy buffering so chunks reach the browser as soon as we
        # flush; otherwise reverse proxies may coalesce them.
        self.send_header("X-Accel-Buffering", "no")
        self.end_headers()
        try:
            self._write_chunk(_safe_dumps(first).encode("utf-8") + b"\n")
            for item in gen:
                self._write_chunk(_safe_dumps(item).encode("utf-8") + b"\n")
            # Terminating zero-length chunk.
            self.wfile.write(b"0\r\n\r\n")
            self.wfile.flush()
        except OSError:
            # Includes BrokenPipeError / ConnectionResetError: the client left.
            pass

    def _serve_sse(self) -> None:
        """Hold the connection open and forward scanner events as SSE.

        Subscribes to the scanner, emits an initial ``ready`` event, then
        relays each queued event as ``event: <type>`` / ``data: <json>``.
        A comment ping is written whenever the queue stays empty for
        ``_SSE_PING_INTERVAL`` seconds. The subscription is always released
        when the loop ends.
        """
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.send_header("Cache-Control", "no-cache")
        self.send_header("Connection", "keep-alive")
        # Disable proxy buffering (nginx, etc.) just in case.
        self.send_header("X-Accel-Buffering", "no")
        self.end_headers()
        q = self.scanner.subscribe()
        try:
            self.wfile.write(b"event: ready\ndata: {}\n\n")
            self.wfile.flush()
            while True:
                try:
                    event = q.get(timeout=self._SSE_PING_INTERVAL)
                except queue.Empty:
                    self.wfile.write(b": ping\n\n")
                    self.wfile.flush()
                    continue
                payload = (
                    f"event: {event['type']}\ndata: {_safe_dumps(event['data'])}\n\n"
                )
                self.wfile.write(payload.encode("utf-8"))
                self.wfile.flush()
        except OSError:
            # Includes BrokenPipeError / ConnectionResetError: the client left.
            pass
        finally:
            self.scanner.unsubscribe(q)
class _Server(ThreadingHTTPServer):
    """Thread-per-request HTTP server used by the viewer."""

    # Let the listening socket be re-bound to a recently used address.
    allow_reuse_address = True
    # Request threads are daemonic, so open connections never block exit.
    daemon_threads = True
# ---- entry point ----
def serve(
    directory: Path,
    host: str = "127.0.0.1",
    port: int = 4242,
    poll_interval: float = 1.0,
) -> None:
    """Scan *directory* and serve the web viewer until interrupted.

    Blocks in ``serve_forever()``; Ctrl-C (``KeyboardInterrupt``) ends the
    loop. The scanner is stopped and the listening socket closed on the way
    out, whatever the reason for leaving.

    Args:
        directory: Directory handed to the scanner; ``~`` is expanded and the
            path is resolved before use.
        host: Interface to bind.
        port: TCP port to bind.
        poll_interval: Forwarded to ``RunScanner`` as ``poll_interval``.

    Raises:
        NotADirectoryError: If *directory* is not an existing directory.
    """
    directory = Path(directory).expanduser().resolve()
    if not directory.is_dir():
        raise NotADirectoryError(f"{directory} is not a directory")

    scanner = RunScanner(directory, poll_interval=poll_interval)

    # Bind the scanner on a per-call subclass so ``_Handler`` itself is never
    # mutated.
    class Handler(_Handler):
        pass

    Handler.scanner = scanner

    scanner.start()
    try:
        srv = _Server((host, port), Handler)
    except Exception:
        # Binding failed (port in use, invalid port, ...): don't leave the
        # already-started scanner running behind the propagating error.
        scanner.stop()
        raise

    try:
        print(f"[spt web] serving {directory}", flush=True)
        # Report the port actually bound, which differs from *port* when an
        # ephemeral port (0) was requested.
        print(f"[spt web] http://{host}:{srv.server_address[1]}", flush=True)
        srv.serve_forever()
    except KeyboardInterrupt:
        print("\n[spt web] shutting down", flush=True)
    finally:
        scanner.stop()
        srv.server_close()