# Source code for stable_datasets.iterable

"""Iterable dataset for streaming with worker sharding and buffered shuffle.

Provides ``StableIterableDataset`` for efficient streaming in PyTorch
DataLoader with multiple workers.  Supports shard-level worker partitioning
and reservoir-based row-level shuffle.
"""

from __future__ import annotations

from collections.abc import Callable

import numpy as np


try:
    # Use PyTorch's iterable-dataset base class whenever PyTorch is importable.
    from torch.utils.data import IterableDataset as _IterableBase
except ImportError:
    # PyTorch is optional here: fall back to an empty placeholder so the class
    # statement that inherits from ``_IterableBase`` still works.
    class _IterableBase:
        """Empty stand-in base class used when PyTorch cannot be imported."""


[docs] class StableIterableDataset(_IterableBase): """An iterable-style dataset with worker sharding and buffered shuffle. Wraps a ``StableDataset`` for efficient streaming in PyTorch DataLoader with multiple workers. Shards are partitioned across workers so each worker reads a disjoint subset. Parameters ---------- dataset : StableDataset The underlying map-style dataset (must be shard-backed). shuffle : bool Whether to shuffle shard order and apply buffered row-level shuffle. seed : int Base random seed. buffer_size : int Size of the reservoir buffer for row-level shuffle. transform : callable, optional Transform applied to each yielded row dict. """ def __init__( self, dataset, *, shuffle: bool = False, seed: int = 0, buffer_size: int = 10_000, transform: Callable | None = None, ): self._dataset = dataset self._shuffle = shuffle self._seed = seed self._buffer_size = buffer_size self._transform = transform self._epoch = 0
[docs] def set_epoch(self, epoch: int): """Set the epoch for varying shuffle seed across epochs.""" self._epoch = epoch
def __iter__(self): ds = self._dataset # Non-file-backed fallback if not ds._backend.is_file_backed: for i in range(len(ds)): row = ds[i] if self._transform: row = self._transform(row) yield row return # Determine worker sharding try: from torch.utils.data import get_worker_info worker_info = get_worker_info() except ImportError: worker_info = None all_shards = list(range(ds._backend.num_shards)) num_workers = worker_info.num_workers if worker_info is not None else 1 if worker_info is not None: my_shards = all_shards[worker_info.id :: num_workers] worker_id = worker_info.id else: my_shards = all_shards worker_id = 0 # When fewer shards than workers, fall back to row-level partitioning # so all workers contribute. Each worker mmaps the same file (shared # pages via OS page cache) but yields only its interleaved rows. partition_rows = len(all_shards) < num_workers if partition_rows: my_shards = all_shards # every worker reads the same shard(s) effective_seed = self._seed + self._epoch * 1000 + worker_id rng = np.random.default_rng(effective_seed) if self._shuffle else None if self._shuffle and rng is not None: rng.shuffle(my_shards) formatter = ds._formatter def _row_gen(): if partition_rows: # Row-level partitioning: must check each row index row_idx = 0 for batch in ds._backend.iter_batches(shard_indices=my_shards): batch_dict = batch.to_pydict() n = batch.num_rows for i in range(n): if row_idx % num_workers != worker_id: row_idx += 1 continue row = {k: v[i] for k, v in batch_dict.items()} yield formatter.format_row(row) row_idx += 1 else: # Shard-partitioned: use batch formatting for less overhead for batch in ds._backend.iter_batches(shard_indices=my_shards): yield from formatter.format_batch(batch) if self._shuffle and self._buffer_size > 0: yield from self._buffered_shuffle(_row_gen(), rng) else: for row in _row_gen(): if self._transform: row = self._transform(row) yield row def _buffered_shuffle(self, row_gen, rng): """Reservoir-based buffered shuffle 
(Fisher-Yates). Fills a buffer from the row generator, then yields random elements as new rows arrive to maintain the buffer at capacity. """ buffer = [] for row in row_gen: if len(buffer) < self._buffer_size: buffer.append(row) else: idx = rng.integers(0, len(buffer)) out = buffer[idx] buffer[idx] = row if self._transform: out = self._transform(out) yield out # Flush remaining buffer in random order rng.shuffle(buffer) for row in buffer: if self._transform: row = self._transform(row) yield row