# Source code for stable_datasets.samplers

"""Backend-aware samplers for :class:`StableDataset`.

PyTorch's :class:`~torch.utils.data.DataLoader` constructs a
:class:`~torch.utils.data.RandomSampler` when ``shuffle=True`` is
passed. That sampler yields indices in a full-random permutation
regardless of the underlying storage backend. For file-backed
storage formats partitioned into shards (Arrow) or fragments
(Lance), full-random access destroys any per-shard I/O locality
the format was designed to exploit.

This module provides samplers that yield indices in shard-aware
orderings. They plug into the standard PyTorch API (``DataLoader(ds,
sampler=...)``) while following the backend's preferred access
pattern::

    from torch.utils.data import DataLoader
    from stable_datasets.samplers import ShardShuffleSampler

    ds = CIFAR10(split="train", storage_format="lance")
    sampler = ShardShuffleSampler(ds, seed=42)
    loader = DataLoader(ds, batch_size=128, sampler=sampler,
                        num_workers=8, persistent_workers=True,
                        multiprocessing_context="spawn")

``DataLoader(ds, shuffle=True)`` continues to work unchanged for
users who need PyTorch's standard full-random ordering (e.g. when
reproducing classification baselines exactly). Samplers here are
strictly opt-in.

See also
--------
``torch.utils.data.Sampler`` : base class.
``lance.sampler.ShardedFragmentSampler`` : Lance's own fragment
   sampler for its native :class:`lance.torch.data.LanceDataset`
   integration. ``ShardShuffleSampler`` is the nearest equivalent
   exposed through the StableDataset backend protocol.
"""

from __future__ import annotations

from collections.abc import Iterator
from typing import Literal

import numpy as np
from torch.utils.data import Sampler


class ShardShuffleSampler(Sampler[int]):
    """Yield indices in shard-shuffled order.

    The shard (or Lance fragment) order is randomized each epoch. Within
    each shard, indices are yielded in an order controlled by
    ``within_shard``:

    * ``"random"`` (default): indices inside a shard are themselves
      permuted. Shuffle quality is closer to full-random while still
      preserving per-shard I/O locality: every sample of one shard is
      emitted before any sample of the next shard in the (shuffled)
      shard order. Recommended for scientific training where shuffle
      quality matters.
    * ``"sequential"``: indices inside a shard are yielded in on-disk
      (ascending) order. Maximally I/O-friendly, but shuffling happens
      only at shard granularity. Comparable to
      ``lance.sampler.ShardedFragmentSampler``.

    Parameters
    ----------
    dataset : StableDataset
        Must support ``len()``. If it exposes a ``_backend`` whose
        ``is_file_backed`` attribute is true, per-shard row counts are
        read from ``backend.shard_row_counts``, then
        ``backend._shard_row_counts``, then Lance fragments
        (``backend._dataset.get_fragments()``). If no counts can be
        obtained, if any count is negative, or if they do not sum to
        ``len(dataset)``, the whole dataset is treated as one shard.
        Datasets without a file-backed ``_backend`` are always treated
        as one shard.
    seed : int, default 0
        Base seed. Each epoch's generator is seeded with
        ``seed ^ (epoch * 0x9E3779B1)`` (see :meth:`set_epoch`). A
        negative result is reinterpreted as a non-negative 128-bit
        two's-complement value, because NumPy rejects negative seeds.
    within_shard : {"random", "sequential"}, default "random"
        Within-shard row ordering.

    Raises
    ------
    ValueError
        If ``within_shard`` is neither ``"random"`` nor ``"sequential"``.

    Notes
    -----
    *Epoch handling*: the order depends only on ``seed`` and the current
    epoch. Iterating twice without calling :meth:`set_epoch` in between
    therefore yields the same order. Call ``set_epoch(epoch)`` before each
    epoch to get a fresh permutation. This is the same convention used by
    :class:`~torch.utils.data.distributed.DistributedSampler`.

    *Pickling*: the sampler stores only plain Python values and keeps no
    reference to the dataset. Those values are ints, the ``within_shard``
    string, and a list of ``(start, end)`` tuples. It therefore pickles
    cheaply, including with ``num_workers>0`` and
    ``multiprocessing_context="spawn"``.
    """

    def __init__(
        self,
        dataset,
        *,
        seed: int = 0,
        within_shard: Literal["random", "sequential"] = "random",
    ):
        if within_shard not in ("random", "sequential"):
            raise ValueError(f"within_shard must be 'random' or 'sequential', got {within_shard!r}")
        self._n = len(dataset)
        self._seed = int(seed)
        self._within_shard = within_shard
        self._epoch = 0
        self._shard_ranges = self._compute_shard_ranges(dataset)

    @staticmethod
    def _is_missing(counts) -> bool:
        """Return True if ``counts`` is None or an empty sized container.

        Deliberately avoids ``bool(counts)``, which raises ``ValueError``
        for multi-element NumPy arrays.
        """
        return counts is None or (hasattr(counts, "__len__") and len(counts) == 0)

    @staticmethod
    def _compute_shard_ranges(dataset) -> list[tuple[int, int]]:
        """Return ``[(start, end_exclusive), ...]``, one entry per shard.

        File-backed backends with per-shard row counts are partitioned into
        contiguous ranges. In every other case the whole dataset becomes a
        single shard ``[(0, len(dataset))]``. "Every other case" means an
        in-memory or non-file-backed backend, missing counts, a negative
        count, or counts that do not sum to ``len(dataset)``.
        """
        n = len(dataset)
        backend = getattr(dataset, "_backend", None)
        if backend is None or not getattr(backend, "is_file_backed", False):
            return [(0, n)]

        # Prefer the public attribute, then the private one. Explicit checks
        # are used instead of ``a or b`` so array-valued counts don't raise.
        shard_row_counts = getattr(backend, "shard_row_counts", None)
        if ShardShuffleSampler._is_missing(shard_row_counts):
            shard_row_counts = getattr(backend, "_shard_row_counts", None)

        if shard_row_counts is None:
            # Lance datasets expose per-fragment row counts.
            try:
                fragments = backend._dataset.get_fragments()
                shard_row_counts = [f.count_rows() for f in fragments]
            except Exception:
                # Best effort: a backend without Lance-style fragments, or a
                # failing fragment query, degrades to a single shard.
                return [(0, n)]

        ranges: list[tuple[int, int]] = []
        start = 0
        for raw_count in shard_row_counts:
            count = int(raw_count)
            if count < 0:
                # Corrupt metadata: the layout cannot be trusted.
                return [(0, n)]
            ranges.append((start, start + count))
            start += count
        if start != n:
            # Inconsistent row counts are treated as a single shard.
            return [(0, n)]
        return ranges

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch that is mixed into the seed on the next iteration."""
        self._epoch = int(epoch)

    def __len__(self) -> int:
        """Return the number of indices yielded per epoch (``len(dataset)``)."""
        return self._n

    def _epoch_seed(self) -> int:
        """Return the non-negative RNG seed for the current epoch."""
        # The epoch is multiplied by a large odd constant so consecutive
        # epochs map to widely separated seeds before being XOR-ed in.
        mixed = self._seed ^ (self._epoch * 0x9E3779B1)
        if mixed < 0:
            # numpy's SeedSequence rejects negative integers. Reinterpret the
            # value as 128-bit two's complement. Non-negative values are left
            # unchanged, so existing seeds keep reproducing the same orders.
            mixed &= (1 << 128) - 1
        return mixed

    def __iter__(self) -> Iterator[int]:
        """Yield every index in ``range(len(dataset))`` exactly once, shard by shard."""
        rng = np.random.default_rng(self._epoch_seed())
        shard_order = np.arange(len(self._shard_ranges))
        rng.shuffle(shard_order)
        for shard_idx in shard_order:
            start, end = self._shard_ranges[shard_idx]
            shard_indices = np.arange(start, end)
            if self._within_shard == "random":
                rng.shuffle(shard_indices)
            # tolist() converts to Python ints in one C-level pass.
            yield from shard_indices.tolist()