"""Core feature descriptors shared across modalities."""
from __future__ import annotations
from pathlib import Path
import pyarrow as pa
class FeatureType:
    """Common interface shared by every feature type descriptor.

    A descriptor states which PyArrow type a column maps to
    (:meth:`to_arrow_type`), how a raw Python value is converted before
    storage (:meth:`encode`), and which string identifies the descriptor
    for fingerprinting (:meth:`fingerprint_data`).
    """

    def to_arrow_type(self) -> pa.DataType:
        """Return the PyArrow type for this feature.

        The base class has no storage type; subclasses override this.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def encode(self, value, *, cache_dir: Path | None = None):
        """Convert ``value`` into the object that gets stored.

        If ``value`` exposes an ``item()`` method (e.g. a NumPy scalar), the
        result of calling it is returned; otherwise ``value`` is returned
        unchanged. ``cache_dir`` is part of the shared signature and is
        unused by this base implementation.
        """
        # EAFP form of "hasattr(value, 'item')": only a missing attribute
        # (AttributeError) means "leave the value as-is".
        try:
            unwrap = value.item
        except AttributeError:
            return value
        return unwrap()

    def fingerprint_data(self) -> str:
        """Return the string that identifies this descriptor: its ``repr``."""
        return repr(self)
class Value(FeatureType):
    """Scalar feature identified by a dtype name.

    ``dtype`` must be one of the keys of ``_DTYPE_MAP``; the PyArrow type
    stored under that key is what :meth:`to_arrow_type` returns.
    """

    # dtype name -> PyArrow type. Key order matters: it is the order shown in
    # the "Supported: [...]" part of the unknown-dtype error raised below.
    _DTYPE_MAP: dict[str, pa.DataType] = {
        # signed integers
        "int8": pa.int8(),
        "int16": pa.int16(),
        "int32": pa.int32(),
        "int64": pa.int64(),
        # unsigned integers
        "uint8": pa.uint8(),
        "uint16": pa.uint16(),
        "uint32": pa.uint32(),
        "uint64": pa.uint64(),
        # floating point
        "float16": pa.float16(),
        "float32": pa.float32(),
        "float64": pa.float64(),
        # everything else
        "bool": pa.bool_(),
        "string": pa.string(),
        "binary": pa.binary(),
    }

    def __init__(self, dtype: str):
        """Create a scalar descriptor for ``dtype``.

        Raises:
            ValueError: if ``dtype`` is not a key of ``_DTYPE_MAP``.
        """
        known = self._DTYPE_MAP
        if dtype in known:
            self.dtype = dtype
        else:
            raise ValueError(f"Unknown dtype '{dtype}'. Supported: {list(known)}")

    def to_arrow_type(self) -> pa.DataType:
        """Return the PyArrow type registered for ``self.dtype``."""
        arrow_type = self._DTYPE_MAP[self.dtype]
        return arrow_type

    def __repr__(self) -> str:
        return f"Value('{self.dtype}')"
class ClassLabel(FeatureType):
    """Categorical label stored as an integer id, with a name <-> id mapping.

    Id ``i`` corresponds to ``names[i]``; the Arrow storage type is ``int64``.

    Args:
        names: Label names in id order. Any iterable is accepted; it is
            materialized into a list exactly once. Names must be unique.
        num_classes: Number of classes. When ``names`` is omitted, the names
            ``"0"``, ``"1"``, ..., ``str(num_classes - 1)`` are generated.
            When both arguments are given they must agree.

    Raises:
        ValueError: if neither argument is given, if ``names`` and
            ``num_classes`` disagree, if ``num_classes`` is negative, or if
            ``names`` contains duplicates.
    """

    def __init__(self, names: list[str] | None = None, num_classes: int | None = None):
        if names is not None:
            # Measure the materialized list, not the argument: ``len(names)``
            # raises TypeError for a generator, which ``list`` has just
            # exhausted anyway.
            self.names: list[str] = list(names)
            if num_classes is not None and num_classes != len(self.names):
                raise ValueError(
                    f"ClassLabel got num_classes={num_classes} "
                    f"but {len(self.names)} names"
                )
            self.num_classes: int = len(self.names)
        elif num_classes is not None:
            if num_classes < 0:
                raise ValueError(f"ClassLabel num_classes must be >= 0, got {num_classes}")
            self.num_classes = num_classes
            self.names = [str(i) for i in range(num_classes)]
        else:
            raise ValueError("ClassLabel requires either 'names' or 'num_classes'")
        self._str2int: dict[str, int] = {n: i for i, n in enumerate(self.names)}
        if len(self._str2int) != len(self.names):
            # A repeated name would silently map to its *last* index, so
            # str2int(int2str(i)) != i for the earlier copies.
            from collections import Counter

            duplicated = [n for n, count in Counter(self.names).items() if count > 1]
            raise ValueError(f"ClassLabel names must be unique; duplicated: {duplicated}")
        self._int2str: dict[int, str] = dict(enumerate(self.names))

    def str2int(self, name: str) -> int:
        """Return the id of label ``name``.

        Raises:
            KeyError: if ``name`` is not one of ``self.names``.
        """
        return self._str2int[name]

    def int2str(self, idx: int) -> str:
        """Return the name of label id ``idx``.

        Raises:
            KeyError: if ``idx`` is not in ``range(self.num_classes)``.
        """
        return self._int2str[idx]

    def to_arrow_type(self) -> pa.DataType:
        """Labels are stored as their integer ids, so the type is ``int64``."""
        return pa.int64()

    def encode(self, value, *, cache_dir: Path | None = None):
        """Encode a label for storage.

        ``str`` values (including subclasses) are mapped to their id with
        :meth:`str2int` and raise KeyError if unknown. Any other value gets
        the base-class treatment: objects with ``item()`` are unwrapped, and
        everything else (e.g. an int id) is returned unchanged.
        """
        if isinstance(value, str):
            return self.str2int(value)
        return super().encode(value, cache_dir=cache_dir)

    def fingerprint_data(self) -> str:
        """Return a fingerprint string that always includes every label name.

        ``__repr__`` abbreviates label sets of more than 5 names to just
        ``num_classes``, so two different vocabularies of the same size would
        otherwise share a fingerprint. With 5 or fewer names this string is
        identical to ``repr(self)``.
        """
        return f"ClassLabel(names={self.names!r})"

    def __repr__(self) -> str:
        # Short label sets are shown in full; longer ones only by their count.
        if len(self.names) <= 5:
            return f"ClassLabel(names={self.names})"
        return f"ClassLabel(num_classes={self.num_classes})"
class Sequence(FeatureType):
    """Variable-length list whose elements are all described by ``feature``.

    Args:
        feature: Descriptor applied to every element of the list.
    """

    def __init__(self, feature: FeatureType):
        self.feature = feature

    def to_arrow_type(self) -> pa.DataType:
        """Return ``list<inner>``, where ``inner`` is the element feature's Arrow type."""
        return pa.list_(self.feature.to_arrow_type())

    def encode(self, value, *, cache_dir: Path | None = None):
        """Encode a sequence element-wise with ``self.feature``.

        ``None`` is passed through unchanged. Objects exposing ``tolist``
        (e.g. NumPy arrays) are converted with it; any other iterable is
        consumed into a list. Every element is then passed to
        ``self.feature.encode`` (forwarding ``cache_dir``), so nested
        features are encoded too. For example, ``Sequence(ClassLabel(...))``
        maps label strings to ids, matching its ``list<int64>`` Arrow type.
        """
        if value is None:
            return None
        if hasattr(value, "tolist"):
            items = value.tolist()
            # For NumPy-style containers ``tolist`` already yields native
            # Python scalars, on which the base ``FeatureType.encode`` is a
            # no-op. Skip the per-element pass in that common numeric case.
            if type(self.feature).encode is FeatureType.encode:
                return items
        else:
            items = list(value)
        encode_item = self.feature.encode
        return [encode_item(item, cache_dir=cache_dir) for item in items]

    def __repr__(self) -> str:
        return f"Sequence({self.feature!r})"