Source code for stable_datasets.backends.lance_video_frames

"""Lance-backed random-access video frame segment layout."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import numpy as np
import pyarrow as pa


# Physical schema of the Lance dataset: one row per encoded frame.
_FRAME_SCHEMA = pa.schema(
    [
        pa.field("video_id", pa.int32()),
        pa.field("frame_idx", pa.int32()),
        pa.field("bytes", pa.large_binary()),
    ]
)

# Process-local cache of opened Lance datasets, keyed by dataset path.
_LANCE_CACHE: dict[str, Any] = {}


def _open_dataset(path: str):
    """Return the Lance dataset for ``path``, opening and caching it on first use.

    Handles are memoised in the module-level ``_LANCE_CACHE`` so repeated
    lookups of the same path within a process reuse one dataset object.
    """
    cached = _LANCE_CACHE.get(path)
    if cached is not None:
        return cached

    # Imported lazily so the module can be imported without touching lance.
    import lance

    handle = lance.dataset(path)
    _LANCE_CACHE[path] = handle
    return handle


def reset_worker_state() -> None:
    """Reset process-local decoder/backend state after a DataLoader fork.

    Rebinds the module-level Lance dataset cache to a fresh, empty dict and
    then makes a best-effort request for OpenCV to use a single thread. Any
    failure to import or configure ``cv2`` is ignored.
    """
    global _LANCE_CACHE
    _LANCE_CACHE = dict()

    # OpenCV is optional here: bail out quietly if it cannot be imported.
    try:
        import cv2
    except Exception:
        return

    try:
        cv2.setNumThreads(1)
    except Exception:
        pass
def _compute_segment_plan(
    videos: list[dict],
    window_length: int,
    frame_skip: int,
    hop_size: int,
    min_video_frames: int | None,
) -> tuple[np.ndarray, np.ndarray, list[int]]:
    """Enumerate every frame window (segment) across ``videos``.

    A window takes ``window_length`` frames spaced ``frame_skip + 1`` apart,
    so it spans ``(window_length - 1) * (frame_skip + 1) + 1`` source frames.
    Within each video, window starts advance by ``hop_size`` from frame 0 up
    to the last start that still fits the full span.

    Args:
        videos: Per-video records; each must provide its frame count under ``"T"``.
        window_length: Number of frames per window (>= 1).
        frame_skip: Number of frames skipped between consecutive window frames (>= 0).
        hop_size: Distance in frames between consecutive window starts (>= 1).
        min_video_frames: Optional minimum frame count for a video to
            contribute windows; the window span is always enforced as well.

    Returns:
        ``(seg_vid, seg_start, per_video)`` where ``seg_vid[i]`` is the index
        into ``videos`` of segment ``i``, ``seg_start[i]`` is its first frame
        (both ``int64`` arrays), and ``per_video[v]`` is the number of
        segments contributed by video ``v`` (0 for skipped videos).

    Raises:
        ValueError: If ``window_length``, ``frame_skip`` or ``hop_size`` is
            out of range.
    """
    if window_length < 1:
        raise ValueError(f"window_length must be >= 1, got {window_length}")
    if frame_skip < 0:
        raise ValueError(f"frame_skip must be >= 0, got {frame_skip}")
    if hop_size < 1:
        raise ValueError(f"hop_size must be >= 1, got {hop_size}")

    stride = frame_skip + 1
    span = (window_length - 1) * stride + 1
    # A video must hold at least one full window, and at least
    # ``min_video_frames`` frames when that stricter bound is given.
    min_t = span if min_video_frames is None else max(span, int(min_video_frames))

    # Build per-video int64 chunks instead of two Python lists with one boxed
    # int per segment; the concatenated result is identical.
    vid_chunks: list[np.ndarray] = []
    start_chunks: list[np.ndarray] = []
    per_video: list[int] = []
    for vi, video in enumerate(videos):
        frames = int(video["T"])
        if frames < min_t:
            per_video.append(0)
            continue
        # Last valid start is ``frames - span``; arange's stop is exclusive.
        starts = np.arange(0, frames - span + 1, hop_size, dtype=np.int64)
        per_video.append(int(starts.size))
        start_chunks.append(starts)
        vid_chunks.append(np.full(starts.size, vi, dtype=np.int64))

    if not start_chunks:
        # np.concatenate rejects an empty sequence; return empty plans directly.
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64), per_video
    return np.concatenate(vid_chunks), np.concatenate(start_chunks), per_video
class LanceVideoFramesBackend:
    """StorageBackend for the ``lance-video-frames`` layout.

    The physical Lance dataset stores one WebP-encoded frame per row. The logical
    dataset exposes deterministic frame windows as samples.

    Logical sample ``i`` is the window starting at frame ``_seg_start[i]`` of
    video ``_seg_vid[i]``, holding ``window_length`` frames spaced
    ``frame_skip + 1`` apart. Per-video information (``T``, ``H``, ``W``,
    ``path``, ``start_row`` and optional ``metadata``) is read from
    ``<uri>/_metadata.json``.
    """

    prefer_batched_take: bool = False

    def __init__(
        self,
        *,
        uri: str | Path,
        window_length: int = 1,
        frame_skip: int = 0,
        hop_size: int = 1,
        min_video_frames: int | None = None,
        batch_readahead: int = 8,
    ):
        """Load the dataset metadata and precompute the segment plan.

        Args:
            uri: Directory holding the Lance dataset and ``_metadata.json``.
            window_length: Number of frames per sample.
            frame_skip: Frames skipped between consecutive frames of a sample.
            hop_size: Distance in frames between consecutive window starts.
            min_video_frames: Optional minimum frame count for a video to
                contribute samples.
            batch_readahead: Stored on the instance and carried through pickling.

        Raises:
            FileNotFoundError: If ``_metadata.json`` is missing.
            ValueError: If the metadata lists no videos, or no video is long
                enough to yield a segment.
        """
        self._uri = Path(uri)
        self._batch_readahead = int(batch_readahead)
        self.window_length = int(window_length)
        self.frame_skip = int(frame_skip)
        self.hop_size = int(hop_size)
        # Kept so that pickling can rebuild exactly the same segment plan.
        self._min_video_frames = None if min_video_frames is None else int(min_video_frames)
        self._stride = self.frame_skip + 1
        self._span = (self.window_length - 1) * self._stride + 1

        meta_path = self._uri / "_metadata.json"
        if not meta_path.exists():
            raise FileNotFoundError(f"No metadata file at {meta_path}")
        # JSON is UTF-8; do not depend on the platform's default encoding.
        self.metadata = json.loads(meta_path.read_text(encoding="utf-8"))
        self._videos = list(self.metadata.get("videos", []))
        if not self._videos:
            raise ValueError(f"no videos recorded in {meta_path}")

        self._seg_vid, self._seg_start, self._per_video = _compute_segment_plan(
            self._videos,
            self.window_length,
            self.frame_skip,
            self.hop_size,
            self._min_video_frames,
        )
        if self._seg_vid.size == 0:
            raise ValueError(f"no valid segments: every video has fewer than span={self._span} frames")
        # Opened lazily on first access (see ``_dataset``).
        self._ds = None

    @property
    def _dataset(self):
        """Lance dataset handle, opened on first access via the module cache."""
        if self._ds is None:
            self._ds = _open_dataset(str(self._uri))
        return self._ds

    @property
    def num_rows(self) -> int:
        """Number of logical samples (frame windows), not physical frame rows."""
        return int(self._seg_vid.shape[0])

    @property
    def num_shards(self) -> int:
        """This layout is always reported as a single shard."""
        return 1

    @property
    def is_file_backed(self) -> bool:
        """Always ``True`` for this backend."""
        return True

    @property
    def cache_dir(self) -> Path:
        """Directory containing the Lance dataset and its metadata file."""
        return self._uri

    @property
    def schema(self) -> pa.Schema:
        """Schema of the physical per-frame table."""
        return _FRAME_SCHEMA

    @property
    def table(self) -> pa.Table:
        """The whole physical frame dataset materialised as an Arrow table."""
        return self._dataset.to_table()

    @property
    def video_paths(self) -> list[str]:
        """``path`` of every video in metadata order."""
        return [video["path"] for video in self._videos]

    @property
    def segment_filenames(self) -> list[str]:
        """Video ``path`` for every logical sample, in sample order."""
        paths = self.video_paths
        return [paths[int(video_id)] for video_id in self._seg_vid]

    def segment_filename(self, idx: int) -> str:
        """Return the video ``path`` that sample ``idx`` is drawn from."""
        return self._videos[int(self._seg_vid[idx])]["path"]

    def segment_info(self, idx: int) -> dict:
        """Describe sample ``idx`` (video, start frame, frame indices) without decoding."""
        vi = int(self._seg_vid[idx])
        start = int(self._seg_start[idx])
        return {
            "video_idx": vi,
            "filename": self._videos[vi]["path"],
            "start_frame": start,
            "frame_indices": [start + j * self._stride for j in range(self.window_length)],
        }

    def get_row(self, idx: int) -> dict:
        """Fetch and decode the frame window for logical sample ``idx``.

        Negative indices count from the end.

        Returns:
            A dict with the decoded ``video`` array plus ``video_idx``,
            ``filename``, ``start_frame``, ``frame_indices`` and
            ``sample_idx``. If the video record carries a ``metadata`` dict,
            its entries are merged in last and therefore win on key clashes.

        Raises:
            IndexError: If ``idx`` is out of range.
            ValueError: If a frame cannot be decoded.
        """
        if idx < 0:
            idx += self.num_rows
        if idx < 0 or idx >= self.num_rows:
            raise IndexError(idx)

        vi = int(self._seg_vid[idx])
        start = int(self._seg_start[idx])
        video = self._videos[vi]
        frame_indices = [start + j * self._stride for j in range(self.window_length)]

        # ``start_row`` is the physical row of the video's first frame; frames
        # are addressed as consecutive rows after it.
        row0 = int(video["start_row"])
        rows = [row0 + frame_idx for frame_idx in frame_indices]
        blobs = self._dataset.take(rows, columns=["bytes"]).column("bytes").to_pylist()
        frames = self._decode_blobs(blobs, int(video["H"]), int(video["W"]))

        sample = {
            "video": frames,
            "video_idx": vi,
            "filename": video["path"],
            "start_frame": start,
            "frame_indices": frame_indices,
            "sample_idx": int(idx),
        }
        metadata = video.get("metadata")
        if isinstance(metadata, dict):
            sample.update(metadata)
        return sample

    def take(self, indices: np.ndarray | list[int]) -> pa.Table:
        """Return the samples at ``indices`` as an Arrow table.

        Decoded ``video`` arrays are converted to nested Python lists so that
        Arrow can ingest them. An empty ``indices`` yields an empty table.
        """
        serializable = []
        for i in indices:
            converted = self.get_row(int(i))
            video = converted.get("video")
            if hasattr(video, "tolist"):
                converted["video"] = video.tolist()
            serializable.append(converted)
        return pa.Table.from_pylist(serializable) if serializable else pa.table({})

    def slice(self, start: int, length: int) -> pa.Table:
        """Return ``length`` consecutive samples beginning at ``start``."""
        return self.take(list(range(start, start + length)))

    def iter_batches(
        self,
        shard_indices: list[int] | None = None,
        shuffle: bool = False,
        seed: int | None = None,
    ):
        """Yield record batches, one logical sample at a time.

        Args:
            shard_indices: If given and shard ``0`` is not included, nothing
                is yielded (this backend has a single shard).
            shuffle: Visit samples in a random order.
            seed: Seed for the shuffle order.
        """
        if shard_indices is not None and 0 not in shard_indices:
            return
        indices = np.arange(self.num_rows, dtype=np.int64)
        if shuffle:
            rng = np.random.default_rng(seed)
            rng.shuffle(indices)
        for idx in indices:
            yield from self.take([int(idx)]).to_batches()

    def __getstate__(self) -> dict:
        """Return the constructor arguments needed to rebuild this backend.

        The dataset handle and segment plan are not pickled; they are
        recreated by ``__setstate__``.
        """
        return {
            "uri": str(self._uri),
            "window_length": self.window_length,
            "frame_skip": self.frame_skip,
            "hop_size": self.hop_size,
            # Required so the rebuilt segment plan matches the original one.
            "min_video_frames": self._min_video_frames,
            "batch_readahead": self._batch_readahead,
        }

    def __setstate__(self, state: dict) -> None:
        """Rebuild the backend by re-running ``__init__`` on the pickled arguments."""
        self.__init__(
            uri=state["uri"],
            window_length=state["window_length"],
            frame_skip=state.get("frame_skip", 0),
            hop_size=state.get("hop_size", 1),
            # ``.get`` keeps states pickled before this key existed loadable.
            min_video_frames=state.get("min_video_frames"),
            batch_readahead=state.get("batch_readahead", 8),
        )

    @staticmethod
    def worker_init(worker_id: int) -> None:
        """Worker-init hook: reset process-local state; ``worker_id`` is unused."""
        del worker_id
        reset_worker_state()

    @staticmethod
    def _decode_blobs(blobs: list[bytes], height: int, width: int) -> np.ndarray:
        """Decode encoded frame blobs into one ``(N, height, width, 3)`` uint8 RGB array.

        Raises:
            ValueError: If a blob is missing/empty, cannot be decoded, or
                decodes to a shape other than ``(height, width, 3)``.
        """
        import cv2

        expected = (height, width, 3)
        out = np.empty((len(blobs), height, width, 3), dtype=np.uint8)
        for idx, blob in enumerate(blobs):
            # A null or empty blob cannot be decoded; report it the same way
            # as an undecodable one instead of failing inside numpy/OpenCV.
            if not blob:
                raise ValueError(f"Failed to decode video frame blob at position {idx}")
            bgr = cv2.imdecode(np.frombuffer(blob, dtype=np.uint8), cv2.IMREAD_COLOR)
            if bgr is None:
                raise ValueError(f"Failed to decode video frame blob at position {idx}")
            # Guard against silent broadcasting of a wrong-sized frame into ``out``.
            if bgr.shape != expected:
                raise ValueError(
                    f"Decoded frame at position {idx} has shape {bgr.shape}, expected {expected}"
                )
            out[idx] = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return out
# Names exported by ``from ... import *``.
__all__ = ["LanceVideoFramesBackend", "reset_worker_state"]