Dataset handling


stable_worldmodel provides a flexible dataset API that supports both HDF5-based storage (fast and compact) and folder-based storage (individual files that are easy to inspect).

[ Storage Formats ]

The HDF5Dataset stores all data in a single .h5 file. This is the default format for recording rollouts using World.record_dataset.

File Structure:

dataset_name.h5
├── pixels          # (Total_Steps, C, H, W) or (Total_Steps, H, W, C)
├── action          # (Total_Steps, Action_Dim)
├── reward          # (Total_Steps,)
├── terminated      # (Total_Steps,)
├── ep_len          # (Num_Episodes,) - Length of each episode
└── ep_offset       # (Num_Episodes,) - Start index of each episode

Usage:

from stable_worldmodel.data import HDF5Dataset

dataset = HDF5Dataset(
    name="my_dataset",
    frameskip=1,
    num_steps=50  # Sequence length for training
)
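The ep_len and ep_offset arrays are the only episode bookkeeping: episode i occupies rows ep_offset[i] : ep_offset[i] + ep_len[i] of every flat per-step array. A minimal NumPy sketch of that layout (illustrative only, not library code):

```python
import numpy as np

# Three episodes of lengths 4, 2, and 3 stored back to back.
ep_len = np.array([4, 2, 3])
ep_offset = np.concatenate(([0], np.cumsum(ep_len)[:-1]))  # [0, 4, 6]

# A flat per-step array, as stored in the .h5 file.
reward = np.arange(ep_len.sum(), dtype=np.float32)

def episode_slice(i):
    """Rows of the flat arrays belonging to episode i."""
    start = ep_offset[i]
    return slice(start, start + ep_len[i])

print(reward[episode_slice(1)])  # [4. 5.]
```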

The FolderDataset stores metadata in .npz files and heavy media (images) as individual files.

File Structure:

dataset_name/
├── ep_len.npz      # Contains 'arr_0': Array of episode lengths
├── ep_offset.npz   # Contains 'arr_0': Array of episode start offsets
├── action.npz      # Contains 'arr_0': Full array of actions
├── reward.npz      # Contains 'arr_0': Full array of rewards
└── pixels/         # Folder for image data
    ├── ep_0_step_0.jpg
    ├── ep_0_step_1.jpg
    └── ...

Usage:

from stable_worldmodel.data import FolderDataset

dataset = FolderDataset(
    name="my_image_dataset",
    folder_keys=["pixels"]  # Keys to load from folders instead of .npz
)
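Each .npz metadata file holds a single unnamed array under the default 'arr_0' key, which is what np.savez produces for a positional argument. A sketch of writing and reading that layout (illustrative, not library code):

```python
import tempfile
from pathlib import Path

import numpy as np

root = Path(tempfile.mkdtemp()) / "my_image_dataset"
root.mkdir(parents=True)

ep_len = np.array([3, 5])
# Positional argument -> stored under the default key 'arr_0'.
np.savez(root / "ep_len.npz", ep_len)

loaded = np.load(root / "ep_len.npz")["arr_0"]
print(loaded)  # [3 5]
```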

The VideoDataset is a specialized FolderDataset that reads frames directly from MP4 files using decord. This saves significant disk space compared to storing individual images.

File Structure:

dataset_name/
├── ep_len.npz
├── ep_offset.npz
├── action.npz
└── video/          # Folder for video files
    ├── ep_0.mp4
    ├── ep_1.mp4
    └── ...

Usage:

from stable_worldmodel.data import VideoDataset

dataset = VideoDataset(
    name="my_video_dataset",
    video_keys=["video"]
)
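Reading a step from a VideoDataset amounts to translating a flat step index into an episode file plus a frame offset within it. A sketch of that arithmetic for the ep_<i>.mp4 layout above (illustrative, not the library's internals):

```python
import numpy as np

ep_len = np.array([100, 80, 120])
ep_offset = np.concatenate(([0], np.cumsum(ep_len)[:-1]))  # [0, 100, 180]

def locate_frame(step):
    """Map a flat step index to (video filename, frame index within it)."""
    ep = int(np.searchsorted(ep_offset, step, side="right")) - 1
    return f"ep_{ep}.mp4", int(step - ep_offset[ep])

print(locate_frame(150))  # ('ep_1.mp4', 50)
```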

The ImageDataset is a convenience alias for FolderDataset with image defaults. It assumes 'pixels' is stored as individual image files.

File Structure:

dataset_name/
├── ep_len.npz
├── ep_offset.npz
├── action.npz
└── pixels/         # Folder for image files
    ├── ep_0_step_0.jpeg
    ├── ep_0_step_1.jpeg
    └── ...

Usage:

from stable_worldmodel.data import ImageDataset

dataset = ImageDataset(
    name="my_image_dataset",
    image_keys=["pixels"]  # Default
)

The GoalDataset wraps any dataset to add goal observations for goal-conditioned learning. Goals are sampled from random states, future states in the same episode (geometric or uniform), or the current state.

Usage:

from stable_worldmodel.data import HDF5Dataset, GoalDataset

# Wrap any base dataset
base_dataset = HDF5Dataset(name="my_dataset", num_steps=50)
goal_dataset = GoalDataset(
    base_dataset,
    goal_probabilities=(0.3, 0.5, 0.0, 0.2),  # (random, geometric future, uniform future, current)
    gamma=0.99,  # Discount for future sampling
    seed=42
)

# Items now include 'goal' and 'goal_proprio' keys (see goal_keys)
item = goal_dataset[0]

[ Base Class ]

Dataset

Dataset(
    lengths: ndarray,
    offsets: ndarray,
    frameskip: int = 1,
    num_steps: int = 1,
    transform: Callable[[dict], dict] | None = None,
)

Base class for episode-based datasets.

Parameters:

  • lengths (ndarray) –

    Array of episode lengths.

  • offsets (ndarray) –

    Array of episode start offsets in the data.

  • frameskip (int, default: 1 ) –

    Number of frames to skip between samples.

  • num_steps (int, default: 1 ) –

    Number of steps per sample.

  • transform (Callable[[dict], dict] | None, default: None ) –

    Optional transform to apply to loaded data.

__getitem__

__getitem__(idx: int) -> dict

load_episode

load_episode(episode_idx: int) -> dict

Load a full episode by index.

load_chunk

load_chunk(
    episodes_idx: ndarray, start: ndarray, end: ndarray
) -> list[dict]
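A clip of num_steps frames sampled with a frameskip stride spans num_steps * frameskip raw steps. The sketch below shows one plausible version of that index arithmetic (illustrative only; the exact windowing is defined by the library):

```python
import numpy as np

def clip_indices(start, num_steps, frameskip):
    """Frame indices of one training clip: num_steps frames, frameskip apart."""
    return start + frameskip * np.arange(num_steps)

# A 4-frame clip starting at raw step 10, taking every 2nd frame.
print(clip_indices(10, num_steps=4, frameskip=2))  # [10 12 14 16]
```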

[ Implementations ]

HDF5Dataset

HDF5Dataset(
    name: str,
    frameskip: int = 1,
    num_steps: int = 1,
    transform: Callable[[dict], dict] | None = None,
    keys_to_load: list[str] | None = None,
    keys_to_cache: list[str] | None = None,
    keys_to_merge: dict[str, list[str] | str] | None = None,
    cache_dir: str | Path | None = None,
)

Bases: Dataset

Dataset loading from HDF5 file.

Reads data from a single .h5 file containing all episode data. Uses SWMR mode for robust reading while writing.

Parameters:

  • name (str) –

    Name of the dataset (filename without extension).

  • frameskip (int, default: 1 ) –

    Number of frames to skip between samples.

  • num_steps (int, default: 1 ) –

    Number of steps per sample sequence.

  • transform (Callable[[dict], dict] | None, default: None ) –

    Optional data transform callable.

  • keys_to_load (list[str] | None, default: None ) –

    Specific keys to load (defaults to all except metadata).

  • keys_to_cache (list[str] | None, default: None ) –

    Keys to load entirely into memory for faster access.

  • keys_to_merge (dict[str, list[str] | str] | None, default: None ) –

    Mapping from an output key to the source key(s) to combine under it.

  • cache_dir (str | Path | None, default: None ) –

    Directory containing the dataset file.

FolderDataset

FolderDataset(
    name: str,
    frameskip: int = 1,
    num_steps: int = 1,
    transform: Callable[[dict], dict] | None = None,
    keys_to_load: list[str] | None = None,
    folder_keys: list[str] | None = None,
    cache_dir: str | Path | None = None,
)

Bases: Dataset

Dataset loading from folder structure.

Metadata is stored in .npz files, while heavy media (images) can be stored as individual files.

Parameters:

  • name (str) –

    Name of the dataset folder.

  • frameskip (int, default: 1 ) –

    Number of frames to skip.

  • num_steps (int, default: 1 ) –

    Sequence length.

  • transform (Callable[[dict], dict] | None, default: None ) –

    Optional transform.

  • keys_to_load (list[str] | None, default: None ) –

    Specific keys to load.

  • folder_keys (list[str] | None, default: None ) –

    Keys that correspond to folders of image files.

  • cache_dir (str | Path | None, default: None ) –

    Base directory containing the dataset folder.

VideoDataset

VideoDataset(
    name: str,
    video_keys: list[str] | None = None,
    **kw: Any,
)

Bases: FolderDataset

Dataset loading video frames from MP4 files using decord.

Assumes video files are stored in a folder structure.

ImageDataset

ImageDataset(
    name: str,
    image_keys: list[str] | None = None,
    **kw: Any,
)

Bases: FolderDataset

Convenience alias for FolderDataset with image defaults.

Assumes 'pixels' is a folder of images.

[ Wrappers ]

GoalDataset

GoalDataset(
    dataset: Dataset,
    goal_probabilities: tuple[
        float, float, float, float
    ] = (0.3, 0.5, 0.0, 0.2),
    gamma: float = 0.99,
    current_goal_offset: int | None = None,
    goal_keys: dict[str, str] | None = None,
    seed: int | None = None,
)

Dataset wrapper that samples an additional goal observation per item.

Works with any dataset type (HDF5Dataset, FolderDataset, VideoDataset, etc.).

Goals are sampled from:

  • random state (uniform over dataset steps)
  • geometric future state in the same episode (Geom(1 - gamma))
  • uniform future state in the same episode (uniform over future steps)
  • current state

with probabilities (0.3, 0.5, 0.0, 0.2) by default.

Parameters:

  • goal_probabilities (tuple[float, float, float, float], default: (0.3, 0.5, 0.0, 0.2) ) –

    Tuple of (p_random, p_geometric_future, p_uniform_future, p_current) for goal sampling.

  • gamma (float, default: 0.99 ) –

    Discount factor for geometric future goal sampling.

  • current_goal_offset (int | None, default: None ) –

    Number of frames from the clip start used for "current" goal sampling. If None, defaults to num_steps, i.e. the last frame of the clip. When training with history, set this to history_size so "current" means the last frame of the history.

  • goal_keys (dict[str, str] | None, default: None ) –

    Mapping of source observation keys to goal observation keys. If None, defaults to {"pixels": "goal", "proprio": "goal_proprio"}.

  • seed (int | None, default: None ) –

    Random seed for goal sampling.
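The four sampling modes can be sketched with NumPy as follows (a simplified illustration of the stated probabilities, not the library's implementation; the helper name is hypothetical):

```python
import numpy as np

def sample_goal_index(rng, step, ep_end, probs=(0.3, 0.5, 0.0, 0.2),
                      gamma=0.99, total_steps=10_000):
    """Pick a goal step for the clip at `step` in an episode ending at ep_end."""
    mode = rng.choice(4, p=probs)
    if mode == 0:                          # random: uniform over all dataset steps
        return int(rng.integers(total_steps))
    if mode == 1:                          # geometric future: offset ~ Geom(1 - gamma)
        offset = rng.geometric(1.0 - gamma)
        return int(min(step + offset, ep_end - 1))
    if mode == 2:                          # uniform future within the same episode
        return int(rng.integers(step, ep_end))
    return step                            # current state

rng = np.random.default_rng(0)
goals = [sample_goal_index(rng, step=42, ep_end=100) for _ in range(5)]
```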

[ Utilities ]

MergeDataset

MergeDataset(
    datasets: list[Any],
    keys_from_dataset: list[list[str]] | None = None,
)

Merges multiple datasets of the same length (horizontal join).

Combines columns from different datasets (e.g. one dataset has 'pixels', another has 'reward') into a single view.

Parameters:

  • datasets (list[Any]) –

    List of dataset instances to merge.

  • keys_from_dataset (list[list[str]] | None, default: None ) –

    Optional list of keys to take from each dataset.
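The horizontal join can be pictured as merging per-index dicts from equal-length datasets; a minimal sketch in plain Python (illustrative, not the library's class):

```python
class MergedView:
    """Horizontal join: one item = union of the items of each source dataset."""

    def __init__(self, datasets, keys_from_dataset=None):
        assert len({len(d) for d in datasets}) == 1, "datasets must have equal length"
        self.datasets = datasets
        self.keys_from_dataset = keys_from_dataset

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, idx):
        item = {}
        for i, d in enumerate(self.datasets):
            sample = d[idx]
            # Take either the requested keys or everything the dataset returns.
            keys = self.keys_from_dataset[i] if self.keys_from_dataset else sample
            item.update({k: sample[k] for k in keys})
        return item

pixels_ds = [{"pixels": f"img_{i}"} for i in range(3)]
reward_ds = [{"reward": float(i)} for i in range(3)]
merged = MergedView([pixels_ds, reward_ds])
print(merged[1])  # {'pixels': 'img_1', 'reward': 1.0}
```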

ConcatDataset

ConcatDataset(datasets: list[Any])

Concatenates multiple datasets (vertical join).

Combines datasets sequentially to increase the total number of episodes/samples.

Parameters:

  • datasets (list[Any]) –

    List of datasets to concatenate.
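Vertical concatenation just reindexes: a global index is routed to the dataset whose range contains it. A sketch in plain Python (illustrative, not the library's class):

```python
import bisect
import itertools

class ConcatView:
    """Vertical join: datasets are chained; a global idx maps to (dataset, local idx)."""

    def __init__(self, datasets):
        self.datasets = datasets
        # Cumulative end index of each dataset, e.g. lengths [3, 2] -> [3, 5].
        self.cum = list(itertools.accumulate(len(d) for d in datasets))

    def __len__(self):
        return self.cum[-1]

    def __getitem__(self, idx):
        ds = bisect.bisect_right(self.cum, idx)
        local = idx - (self.cum[ds - 1] if ds else 0)
        return self.datasets[ds][local]

a = ["a0", "a1", "a2"]
b = ["b0", "b1"]
cat = ConcatView([a, b])
print(cat[3], len(cat))  # b0 5
```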