Source code for stable_datasets.formatting

"""Formatters that convert Arrow-native values to user-facing types.

Middle layer of the three-layer split
(:class:`StorageBackend` -> :class:`Formatter` -> :class:`StableDataset`).
Formatters consume Arrow values and emit PIL images, numpy arrays, torch
tensors, or raw Python, never touching files or storage themselves.
"""

from __future__ import annotations

from pathlib import Path

from .schema import Features, FeatureType


def _zip_cols_to_rows(cols: dict, n: int) -> list[dict]:
    """Build list of row dicts from column-oriented dict."""
    keys = list(cols.keys())
    return [{k: cols[k][i] for k in keys} for i in range(n)]


def _extract_columns(table) -> dict:
    """Extract Python column lists one column at a time."""
    return {name: table.column(name).to_pylist() for name in table.column_names}


[docs] class Formatter: """Base formatter. Subclasses convert Arrow-native values to user-facing types.""" format_type = "default" def __init__(self, features: Features, decode_images: bool = True, cache_dir: Path | None = None): self.features = features self.decode_images = decode_images self.cache_dir = Path(cache_dir) if cache_dir is not None else None
[docs] def format_row(self, row: dict) -> dict: """Format a single row dict (from backend.get_row).""" return {name: self._format_value(name, value) for name, value in row.items()}
[docs] def format_batch(self, table) -> list[dict]: """Format a batch (from backend.take). Returns list of row dicts. Column-first: extract columns once, decode each column in bulk, then zip into per-row dicts at the end. """ cols = self._format_columns(_extract_columns(table)) return _zip_cols_to_rows(cols, table.num_rows)
def _format_value(self, name: str, value): feat = self.features.get(name) if isinstance(feat, FeatureType): return feat.format( value, format_type=self.format_type, decode_images=self.decode_images, cache_dir=self.cache_dir, ) return value def _format_columns(self, cols: dict) -> dict: return {name: [self._format_value(name, value) for value in values] for name, values in cols.items()}
[docs] class PythonFormatter(Formatter): """Default format: Image -> PIL, Array3D -> numpy, scalars -> Python native.""" format_type = "default"
[docs] class RawFormatter(Formatter): """Raw format: all values as-is from Arrow (bytes for images, bytes for Array3D).""" format_type = "raw" def __init__(self, features: Features, decode_images: bool = False, cache_dir: Path | None = None): super().__init__(features, decode_images=False, cache_dir=cache_dir)
[docs] def format_row(self, row: dict) -> dict: return row
[docs] def format_batch(self, table) -> list[dict]: return _zip_cols_to_rows(_extract_columns(table), table.num_rows)
[docs] class TorchFormatter(Formatter): """Torch format: Image -> CHW float32 tensor, Array3D -> float32 tensor, scalars -> tensors.""" format_type = "torch"
[docs] class NumpyFormatter(Formatter): """Numpy format: Image -> HWC numpy array, rest as-is.""" format_type = "numpy"
[docs] def get_formatter( format_type: str | None, features: Features, decode_images: bool = True, cache_dir: Path | None = None, ) -> Formatter: """Factory for formatter instances.""" if format_type is None: return PythonFormatter(features, decode_images=decode_images, cache_dir=cache_dir) if format_type == "raw": return RawFormatter(features, decode_images=False, cache_dir=cache_dir) if format_type == "torch": return TorchFormatter(features, decode_images=decode_images, cache_dir=cache_dir) if format_type == "numpy": return NumpyFormatter(features, decode_images=decode_images, cache_dir=cache_dir) raise ValueError(f"Unknown format type: {format_type!r}")