HFDataset

class stable_pretraining.data.HFDataset(*args, transform=None, gpu_transform=None, rename_columns=None, remove_columns=None, **kwargs)

Create a HuggingFace dataset wrapper.

Automatically selects map-style or streaming mode from the streaming keyword argument (True or False) passed through kwargs.

The returned object is either an HFMapDataset (subclass of torch.utils.data.Dataset) or an HFIterableDataset (subclass of torch.utils.data.IterableDataset), so PyTorch DataLoader and Lightning Trainer handle both correctly out of the box.
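
The dispatch described above can be pictured with a small, dependency-free sketch. It is not the library's implementation: make_dataset, MapWrapper, IterableWrapper and fake_load are hypothetical names, and the wrappers only mimic the access protocols of torch.utils.data.Dataset (__getitem__ and __len__) and torch.utils.data.IterableDataset (__iter__) instead of subclassing them, so the example runs without torch or network access.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence

Sample = Dict[str, Any]
Transform = Optional[Callable[[Sample], Sample]]


class MapWrapper:
    """Map-style stand-in: random access through __getitem__, size through __len__."""

    def __init__(self, data: Sequence[Sample], transform: Transform = None):
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Sample:
        sample = dict(self.data[idx])  # copy so the source rows are never mutated
        return self.transform(sample) if self.transform is not None else sample


class IterableWrapper:
    """Streaming stand-in: sequential access through __iter__ only (no __len__)."""

    def __init__(self, data: Iterable[Sample], transform: Transform = None):
        self.data = data
        self.transform = transform

    def __iter__(self) -> Iterator[Sample]:
        for row in self.data:
            sample = dict(row)
            yield self.transform(sample) if self.transform is not None else sample


def make_dataset(
    load_fn: Callable[..., Any], *args: Any, transform: Transform = None, **kwargs: Any
):
    """Forward args/kwargs to load_fn and pick the wrapper from the `streaming` flag."""
    streaming = bool(kwargs.get("streaming", False))
    raw = load_fn(*args, **kwargs)  # `streaming` is forwarded to the loader as well
    wrapper_cls = IterableWrapper if streaming else MapWrapper
    return wrapper_cls(raw, transform=transform)


def fake_load(name: str, split: str = "train", streaming: bool = False) -> list:
    """Hypothetical stand-in for datasets.load_dataset returning three toy rows."""
    return [{"x": i} for i in range(3)]


# Usage: the same call site yields a map-style or a streaming wrapper.
ds = make_dataset(fake_load, "toy", split="train", transform=lambda s: {**s, "y": 2 * s["x"]})
print(type(ds).__name__, len(ds), ds[1])  # MapWrapper 3 {'x': 1, 'y': 2}

stream = make_dataset(fake_load, "toy", split="train", streaming=True)
print(type(stream).__name__, [s["x"] for s in stream])  # IterableWrapper [0, 1, 2]
```

The decision is made once, from the same streaming flag that is forwarded to the loader, so the wrapper type always matches the kind of object the loader returned.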

Parameters:
  • *args – Positional arguments forwarded to datasets.load_dataset (typically the dataset name/path).

  • transform – Optional transform applied to every sample dict.

  • gpu_transform – Optional batch-level GPU transform stored on the returned dataset (see DatasetMixin). It is discovered by Module.on_after_batch_transfer() and applied to the collated batch after the batch has been moved to the device.

  • rename_columns – Optional {old: new} mapping of columns to rename.

  • remove_columns – Optional list of column names to drop.

  • **kwargs – Keyword arguments forwarded to datasets.load_dataset (e.g. split, streaming, data_dir).
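
The column and transform parameters can be sketched as plain functions on sample dicts. This is a hypothetical illustration: process_sample, ToyDataset and after_batch_transfer are invented names, the rename-then-remove-then-transform order is an assumption (the real wrapper may instead rename and drop columns once on the underlying HuggingFace dataset), and after_batch_transfer only imitates, as an assumption, how Module.on_after_batch_transfer() might discover gpu_transform (a getattr on the dataset), without Lightning or a GPU.

```python
from typing import Any, Callable, Dict, Iterable, Mapping, Optional

Sample = Dict[str, Any]


def process_sample(
    sample: Mapping[str, Any],
    rename_columns: Optional[Mapping[str, str]] = None,
    remove_columns: Optional[Iterable[str]] = None,
    transform: Optional[Callable[[Sample], Sample]] = None,
) -> Sample:
    """Hypothetical per-sample pipeline: rename, then drop, then transform."""
    out = dict(sample)  # work on a copy
    for old, new in (rename_columns or {}).items():
        out[new] = out.pop(old)  # {old: new} mapping, as in rename_columns
    for col in remove_columns or ():
        out.pop(col, None)  # ignore columns that are already absent
    return transform(out) if transform is not None else out


class ToyDataset:
    """Stand-in for a dataset object that carries a batch-level gpu_transform."""

    def __init__(self, gpu_transform: Optional[Callable[[Sample], Sample]] = None):
        self.gpu_transform = gpu_transform


def after_batch_transfer(batch: Sample, dataset: Any) -> Sample:
    """Look up gpu_transform on the dataset and apply it to the collated batch, if set."""
    fn = getattr(dataset, "gpu_transform", None)
    return fn(batch) if fn is not None else batch


# Per-sample step (runs once per example, before collation).
sample = process_sample(
    {"img": 1, "label": 3, "meta": "x"},
    rename_columns={"img": "image"},
    remove_columns=["meta"],
    transform=lambda d: {**d, "label": d["label"] + 1},
)
print(sample)  # {'label': 4, 'image': 1}

# Batch step (runs once per collated batch, after it has been moved to the device).
batch = {"image": [1, 2]}
doubler = ToyDataset(lambda b: {"image": [2 * v for v in b["image"]]})
print(after_batch_transfer(batch, doubler))  # {'image': [2, 4]}
print(after_batch_transfer(batch, ToyDataset()))  # {'image': [1, 2]}
```

Keeping the per-sample transform and the batch-level gpu_transform separate lets cheap decoding run in DataLoader workers while heavy tensor operations run once per batch on the device.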

Returns:

An HFMapDataset or HFIterableDataset instance.

Example:

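from stable_pretraining.data import HFDataset

# Map-style: random access, so len() and indexing work
ds = HFDataset("imagenet-1k", split="train")
print(len(ds))

# Streaming: samples are read sequentially by iterating
ds = HFDataset("imagenet-1k", split="train", streaming=True)
# datasets.IterableDataset.shuffle returns a new dataset; if the wrapper
# follows that convention, reassign the result (ds = ds.shuffle(...)).
ds.shuffle(seed=42, buffer_size=10_000)
for sample in ds:
    ...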
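
In the streaming example, shuffle(seed=..., buffer_size=...) cannot permute the whole dataset, because samples arrive one at a time. The HuggingFace datasets documentation describes IterableDataset.shuffle as filling a buffer of buffer_size examples and sampling randomly from it, replacing each emitted example with the next incoming one (it also shuffles the order of shards). The sketch below implements only that buffer step in plain Python; buffered_shuffle is a hypothetical helper, not part of stable_pretraining or datasets.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most buffer_size items."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)  # seeded so the order is reproducible
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # fill phase: no output yet
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]  # emit a random buffered example...
        buffer[idx] = item  # ...and put the incoming one in its slot
    rng.shuffle(buffer)  # stream exhausted: drain the rest in random order
    yield from buffer


out = list(buffered_shuffle(range(20), buffer_size=5, seed=42))
print(sorted(out) == list(range(20)))  # True: every example appears exactly once
```

A larger buffer_size approximates a full shuffle more closely at the cost of memory: with buffer_size=1 the stream comes out in its original order, and the first emitted example is always one of the first buffer_size inputs.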

Examples using HFDataset:

Supervised ImageNet-1k ViT training — SOTA (DeiT/AugReg) recipe, FSDP2, GPU-fast.

SimCLR on ImageNette with the JAX / Flax-NNX backend, data-parallel over GPUs.

Multi-layer Probe for Vision Models

Supervised Learning Example