# Source code for stable_datasets.schema

"""Feature and metadata schema definitions.

Each feature type maps itself to a PyArrow type for Arrow IPC serialization.
"""

from __future__ import annotations

from collections import OrderedDict
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Literal, NewType, Protocol

import pyarrow as pa

from .features import Array3D, ClassLabel, FeatureType, Image, Sequence, Value, Video, VideoRef


# Public API of this module: the feature types imported from ``.features``
# (Array3D, ClassLabel, FeatureType, Image, Sequence, Value, Video, VideoRef)
# together with the schema/metadata names defined below. Kept in sorted order.
__all__ = [
    "Array3D",
    "BuilderConfig",
    "ClassLabel",
    "DatasetInfo",
    "DatasetSource",
    "DownloadInfo",
    "FeatureType",
    "Features",
    "Image",
    "Sequence",
    "URL",
    "Value",
    "Version",
    "Video",
    "VideoDecodeConfig",
    "VideoDecodeFn",
    "VideoDecodeFnBatched",
    "VideoRef",
    "collect_dataset_citations",
]


class Version:
    """Semantic version string (``major.minor.patch``).

    The three dot-separated components are parsed with ``int`` and stored as
    ``major``, ``minor`` and ``patch``. Equality and hashing use those parsed
    integers, while ``str()`` returns the text exactly as it was supplied.

    Raises:
        ValueError: If the string does not split into exactly three parts, or
            a part cannot be parsed by ``int``.
    """

    def __init__(self, version_str: str):
        components = version_str.split(".")
        if len(components) != 3:
            raise ValueError(f"Version string must be 'major.minor.patch', got '{version_str}'")
        # Parse every component before assigning so a bad part leaves no
        # partially-initialised attributes behind.
        numbers = tuple(int(component) for component in components)
        self.major, self.minor, self.patch = numbers
        self._str = version_str

    def __str__(self) -> str:
        return self._str

    def __repr__(self) -> str:
        return f"Version('{self._str}')"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Version):
            return NotImplemented
        mine = (self.major, self.minor, self.patch)
        theirs = (other.major, other.minor, other.patch)
        return mine == theirs

    def __hash__(self) -> int:
        # Hash the same triple that __eq__ compares so equal versions collide.
        return hash((self.major, self.minor, self.patch))
[docs] @dataclass class DownloadInfo: """Download source metadata for one raw asset. ``url`` is attempted first. Any ``fallbacks`` are tried in order if the primary URL fails. """ url: str fallbacks: list[str] = field(default_factory=list) checksum: str | None = None filename: str | None = None def __post_init__(self): if not isinstance(self.url, str) or not self.url: raise TypeError("DownloadInfo.url must be a non-empty string.") if not isinstance(self.fallbacks, list) or not all(isinstance(url, str) and url for url in self.fallbacks): raise TypeError("DownloadInfo.fallbacks must be a list of non-empty strings.") if self.checksum is not None and not isinstance(self.checksum, str): raise TypeError("DownloadInfo.checksum must be a string when provided.") if self.filename is not None and not isinstance(self.filename, str): raise TypeError("DownloadInfo.filename must be a string when provided.")
[docs] def all_urls(self) -> list[str]: return [self.url, *self.fallbacks]
# Distinct static type for URL strings. ``NewType`` only matters to type
# checkers: at runtime ``URL(x)`` returns ``x`` unchanged, so values are plain ``str``.
URL = NewType("URL", str)
@dataclass(frozen=True)
class DatasetSource(Mapping[str, object]):
    """Typed source and download metadata for one dataset builder.

    Instances are frozen and also behave as a read-only mapping over their
    fields. ``homepage``, ``assets`` and ``citation`` are always iterated;
    ``license`` is iterated only when non-empty and ``checksums`` only when it
    is not ``None``. Item lookup, however, accepts all five field names.

    During construction every plain-string asset is wrapped in a
    ``DownloadInfo`` and both ``assets`` and ``checksums`` are replaced by
    read-only ``MappingProxyType`` views over validated copies.

    Raises:
        TypeError: If any field fails validation (wrong type or empty where a
            non-empty string is required).
    """

    homepage: URL | str
    assets: dict[str, DownloadInfo | str]
    citation: str
    license: str = ""
    checksums: dict[str, str] | None = None

    def __post_init__(self):
        if not (isinstance(self.homepage, str) and self.homepage):
            raise TypeError("DatasetSource.homepage must be a non-empty string.")
        if not (isinstance(self.citation, str) and self.citation):
            raise TypeError("DatasetSource.citation must be a non-empty string.")
        if not isinstance(self.license, str):
            raise TypeError("DatasetSource.license must be a string.")
        if not isinstance(self.assets, Mapping):
            raise TypeError("DatasetSource.assets must be a mapping.")

        # Normalise every asset to a DownloadInfo; bare strings are treated as URLs.
        asset_table = {}
        for name, entry in self.assets.items():
            if not (isinstance(name, str) and name):
                raise TypeError("DatasetSource asset keys must be non-empty strings.")
            if isinstance(entry, str):
                entry = DownloadInfo(url=entry)
            elif not isinstance(entry, DownloadInfo):
                raise TypeError(
                    f"DatasetSource.assets['{name}'] must be a URL string or DownloadInfo, got {type(entry).__name__}."
                )
            asset_table[name] = entry

        frozen_checksums = None
        if self.checksums is not None:
            if not isinstance(self.checksums, Mapping):
                raise TypeError("DatasetSource.checksums must be a mapping when provided.")
            checksum_table = {}
            for name, digest in self.checksums.items():
                if not (isinstance(name, str) and name):
                    raise TypeError("DatasetSource.checksums keys must be non-empty strings.")
                if not (isinstance(digest, str) and digest):
                    raise TypeError("DatasetSource.checksums values must be non-empty strings.")
                checksum_table[name] = digest
            frozen_checksums = MappingProxyType(checksum_table)

        # The dataclass is frozen, so normalised values have to bypass the
        # generated __setattr__.
        assign = object.__setattr__
        assign(self, "homepage", str(self.homepage))
        assign(self, "assets", MappingProxyType(asset_table))
        assign(self, "checksums", frozen_checksums)

    def __getitem__(self, key: str) -> object:
        if key in ("homepage", "assets", "citation", "license", "checksums"):
            return getattr(self, key)
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        yield from ("homepage", "assets", "citation")
        if self.license:
            yield "license"
        if self.checksums is not None:
            yield "checksums"

    def __len__(self) -> int:
        has_license = bool(self.license)
        has_checksums = self.checksums is not None
        return 3 + int(has_license) + int(has_checksums)

    def get(self, key: str, default=None):
        """Return ``self[key]``, or ``default`` when the key is unknown."""
        try:
            value = self[key]
        except KeyError:
            return default
        return value
def collect_dataset_citations(sources: Iterable[DatasetSource | Mapping[str, object]]) -> list[str]:
    """Collect unique citation strings in stable first-seen order.

    Each source contributes ``source["citation"]`` when it is a ``Mapping``
    and ``source.citation`` otherwise. Later duplicates are dropped.
    """
    # A dict keeps insertion order, so it serves as both the "seen" set and
    # the ordered result.
    first_seen = {}
    for source in sources:
        if isinstance(source, Mapping):
            citation = source["citation"]
        else:
            citation = source.citation
        first_seen.setdefault(citation, None)
    return list(first_seen)
class VideoDecodeFn(Protocol):
    """Per-sample video decode callback.

    Structural type for callables that accept one ``VideoRef`` and a
    ``VideoDecodeConfig`` as their first two parameters, plus the optional
    keyword-only arguments ``row`` and ``sample_index`` (both default to
    ``None``). The return value is typed as ``Any``.
    """

    def __call__(
        self,
        ref: VideoRef,
        config: VideoDecodeConfig,
        *,
        row: Mapping[str, Any] | None = None,
        sample_index: int | None = None,
    ) -> Any:
        ...
class VideoDecodeFnBatched(Protocol):
    """Batched video decode callback used by ``StableDataset.__getitems__``.

    Structural type for callables that accept a sequence of ``VideoRef``
    objects and a ``VideoDecodeConfig`` as their first two parameters, plus
    the optional keyword-only arguments ``rows`` and ``sample_indices`` (both
    default to ``None``). The result is typed as a sequence of ``Any``.
    """

    def __call__(
        self,
        refs: ABCSequence[VideoRef],
        config: VideoDecodeConfig,
        *,
        rows: ABCSequence[Mapping[str, Any]] | None = None,
        sample_indices: ABCSequence[int] | None = None,
    ) -> ABCSequence[Any]:
        ...
@dataclass(frozen=True)
class VideoDecodeConfig:
    """Read-time video decode configuration.

    This is retrieval policy only: it does not affect cache construction,
    cache fingerprints, or the persisted schema.

    All validation happens in ``__post_init__``.

    Raises:
        TypeError: If ``num_frames`` or ``frame_stride`` is not an ``int``
            (``bool`` is rejected), or ``column`` is not a non-empty string.
        ValueError: If ``num_frames`` or ``frame_stride`` is below 1, a
            literal-valued option is outside its allowed set,
            ``scale='zero_one'`` is combined with a dtype other than
            ``'float32'``, or ``resize`` is not a positive int or a tuple of
            two positive ints (``bool`` is rejected).
    """

    num_frames: int
    column: str = "video"
    sampling: Literal["uniform", "random", "center", "start"] = "uniform"
    frame_stride: int = 1
    decoder: Literal["torchcodec", "decord", "cv2"] = "torchcodec"
    output: Literal["torch", "numpy"] = "torch"
    layout: Literal["TCHW", "CTHW", "THWC"] = "TCHW"
    dtype: Literal["float32", "uint8"] = "float32"
    scale: Literal["zero_one", "none"] = "zero_one"
    resize: int | tuple[int, int] | None = None
    crop: Literal["none", "center", "random"] = "none"
    pad: Literal["error", "repeat_last", "loop"] = "error"
    seed: int | None = None
    decode_fn: VideoDecodeFn | None = None
    decode_fn_batched: VideoDecodeFnBatched | None = None

    def __post_init__(self):
        # bool is a subclass of int, so it is excluded explicitly.
        if not isinstance(self.num_frames, int) or isinstance(self.num_frames, bool):
            raise TypeError("num_frames must be an int.")
        if self.num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {self.num_frames}")
        if not isinstance(self.frame_stride, int) or isinstance(self.frame_stride, bool):
            raise TypeError("frame_stride must be an int.")
        if self.frame_stride < 1:
            raise ValueError(f"frame_stride must be >= 1, got {self.frame_stride}")
        if not isinstance(self.column, str) or not self.column:
            raise TypeError("column must be a non-empty string.")

        # Literal annotations are not enforced at runtime; check them by hand.
        _validate_literal("sampling", self.sampling, {"uniform", "random", "center", "start"})
        _validate_literal("decoder", self.decoder, {"torchcodec", "decord", "cv2"})
        _validate_literal("output", self.output, {"torch", "numpy"})
        _validate_literal("layout", self.layout, {"TCHW", "CTHW", "THWC"})
        _validate_literal("dtype", self.dtype, {"float32", "uint8"})
        _validate_literal("scale", self.scale, {"zero_one", "none"})
        _validate_literal("crop", self.crop, {"none", "center", "random"})
        _validate_literal("pad", self.pad, {"error", "repeat_last", "loop"})

        if self.scale == "zero_one" and self.dtype != "float32":
            raise ValueError("scale='zero_one' requires dtype='float32'.")

        if self.resize is not None:
            # Reject bool up front, matching num_frames/frame_stride: without
            # this, resize=True would silently pass as the integer 1.
            if isinstance(self.resize, bool):
                raise ValueError("resize must be an int or a tuple of two positive ints.")
            if isinstance(self.resize, int):
                if self.resize < 1:
                    raise ValueError("resize must be >= 1 when provided as an int.")
            elif (
                not isinstance(self.resize, tuple)
                or len(self.resize) != 2
                or any(not isinstance(v, int) or isinstance(v, bool) or v < 1 for v in self.resize)
            ):
                raise ValueError("resize must be an int or a tuple of two positive ints.")
def _validate_literal(name: str, value: str, allowed: set[str]) -> None: if value not in allowed: raise ValueError(f"{name} must be one of {sorted(allowed)}, got {value!r}.")
class Features(OrderedDict):
    """Ordered mapping of ``field_name -> FeatureType``.

    Generates a PyArrow schema via ``.to_arrow_schema()``.
    """

    def to_arrow_schema(self) -> pa.Schema:
        """Build a PyArrow schema with one field per feature, in mapping order.

        Each field takes its Arrow type from ``feat.to_arrow_type()`` and its
        field metadata from ``feat.arrow_metadata()``; a falsy metadata value
        is passed to PyArrow as ``None``.

        Raises:
            TypeError: If any value in the mapping is not a ``FeatureType``.
        """
        fields = []
        for name, feat in self.items():
            if not isinstance(feat, FeatureType):
                raise TypeError(f"Feature '{name}' must be a FeatureType, got {type(feat).__name__}")
            metadata = feat.arrow_metadata()
            fields.append(pa.field(name, feat.to_arrow_type(), metadata=metadata or None))
        return pa.schema(fields)

    def fingerprint_data(self) -> str:
        """Return the ``repr`` of this mapping as a plain ``dict``."""
        # Preserve the historical dict-style payload so cache keys stay stable.
        return repr(dict(self))
@dataclass
class DatasetInfo:
    """Metadata container for a dataset (description, features, citation, etc.).

    Only ``features`` is required; every other field has an empty default.
    No validation is performed here.
    """

    # Required: the feature schema of the dataset.
    features: Features

    # Optional descriptive fields; all default to empty.
    description: str = ""
    supervised_keys: tuple | None = None
    homepage: str = ""
    citation: str = ""
    license: str = ""
    config_name: str = ""
@dataclass
class BuilderConfig:
    """Base config for multi-variant datasets.

    Every field is optional: ``name`` defaults to ``"default"``, ``version``
    to ``None`` and ``description`` to the empty string.
    """

    # Identifier of this configuration.
    name: str = "default"

    # Optional version of this configuration; ``None`` when unspecified.
    version: Version | None = None

    # Free-form description text.
    description: str = ""