Create a small PushT dataset, inspect it, load it for training, and convert it to another format.


This tutorial walks through the first stage of a world-model experiment: recording trajectories into a dataset that can be loaded by PyTorch training code.

You will:

  • create a vectorized World,
  • collect PushT trajectories with a weak expert policy,
  • randomize selected factors of variation,
  • inspect the dataset from the CLI,
  • load fixed-length clips for training,
  • convert the dataset to another storage format.

Install

Install the all extra for this tutorial. It pulls in both the environment dependencies and the dataset-format dependencies:

pip install 'stable-worldmodel[all]'

On macOS arm64, decord may not ship a compatible wheel, so installing the all extra can fail. The PushT collection steps below also run with this narrower set of runtime dependencies:

pip install stable-worldmodel pygame pymunk shapely opencv-python-headless

Datasets are stored under $STABLEWM_HOME/datasets. If the variable is not set, the default root is ~/.stable_worldmodel.

export STABLEWM_HOME=$PWD/.stablewm
mkdir -p "$STABLEWM_HOME/datasets"
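The collection script later in this tutorial falls back to ~/.stable_worldmodel in the same way. As a minimal sketch of that lookup, you can check where a dataset name will land before collecting anything. The helper name dataset_path is hypothetical and is not part of stable_worldmodel:

```python
import os
from pathlib import Path


def dataset_path(name: str) -> Path:
    """Hypothetical helper: resolve a dataset name the way this tutorial does."""
    # Use $STABLEWM_HOME when set, otherwise fall back to ~/.stable_worldmodel.
    root = Path(os.environ.get('STABLEWM_HOME', Path.home() / '.stable_worldmodel'))
    return root / 'datasets' / name


if __name__ == '__main__':
    print(dataset_path('tutorial_pusht.lance'))
```

If this prints a path outside the directory you created with mkdir, STABLEWM_HOME is not exported in the current shell.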

Inspect the environment

List the environments registered by stable_worldmodel:

swm envs

Inspect PushT's factors of variation:

swm fovs swm/PushT-v1

Factors of variation are reset-time controls such as object colors, object positions, sizes, shapes, and physics parameters. Listing a factor name in options={"variation": [...]} samples a new value of that factor for each episode.
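Conceptually, each reset draws fresh values only for the factors you name. The plain-Python sketch below illustrates that idea. The factor table, its defaults and ranges, and the function reset_factors are all hypothetical and do not reflect how stable_worldmodel represents factors internally; the real PushT factors are the ones listed by swm fovs:

```python
import random

# Hypothetical factors: name -> (default value, sampler that takes an RNG).
# Placeholder values only; query the real ones with `swm fovs swm/PushT-v1`.
FACTORS = {
    'block.angle': (0.0, lambda rng: rng.uniform(-3.14159, 3.14159)),
    'agent.color': ((0, 0, 255), lambda rng: tuple(rng.randrange(256) for _ in range(3))),
    'block.color': ((128, 128, 128), lambda rng: tuple(rng.randrange(256) for _ in range(3))),
}


def reset_factors(variation, rng):
    """Return one episode's factor values: sample requested factors, keep defaults otherwise."""
    unknown = set(variation) - set(FACTORS)
    if unknown:
        raise KeyError(f'unknown factors: {sorted(unknown)}')
    return {
        name: sampler(rng) if name in variation else default
        for name, (default, sampler) in FACTORS.items()
    }


rng = random.Random(0)
for episode in range(3):
    # Analogous to options={'variation': ['block.angle']}: only the angle changes.
    print(episode, reset_factors(['block.angle'], rng))
```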

Collect trajectories

Create collect_pusht_tutorial.py:

from pathlib import Path
import os

import stable_worldmodel as swm
from stable_worldmodel.envs.pusht import WeakPolicy


root = Path(os.environ.get('STABLEWM_HOME', Path.home() / '.stable_worldmodel'))
dataset_path = root / 'datasets' / 'tutorial_pusht.lance'

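# Eight PushT environments stepped in parallel; frames are resized to 96x96
# and each episode is cut off after 100 steps.
world = swm.World(
    'swm/PushT-v1',
    num_envs=8,
    image_shape=(96, 96),
    max_episode_steps=100,
)

# Weak scripted expert that drives every environment in the world.
world.set_policy(WeakPolicy(dist_constraint=100, seed=0))

# Record 64 episodes; every reset resamples the listed factors of variation.
world.collect(
    dataset_path,
    episodes=64,
    seed=0,
    options={
        'variation': [
            'agent.start_position',
            'block.start_position',
            'block.angle',
            'agent.color',
            'block.color',
        ],
    },
)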

world.close()
print(f'wrote {dataset_path}')

Run it:

python collect_pusht_tutorial.py

The default dataset format is lance, which is recommended for training workloads that read image trajectories in shuffled order. To write HDF5 instead, pass format='hdf5' together with a .h5 path. The call below belongs in the same script, before world.close():

world.collect(
    root / 'datasets' / 'tutorial_pusht.h5',
    episodes=64,
    seed=0,
    format='hdf5',
)
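Because the format and the path suffix are chosen separately, it is easy to write one format under the other's extension. A minimal pre-flight check, assuming only the two suffix conventions used in this tutorial (check_dataset_path and EXPECTED_SUFFIX are hypothetical, not stable_worldmodel APIs):

```python
from pathlib import Path

# Suffix conventions used in this tutorial. Video datasets are written to a
# plain directory name (tutorial_pusht_video), so they are deliberately absent.
EXPECTED_SUFFIX = {'lance': '.lance', 'hdf5': '.h5'}


def check_dataset_path(path, fmt='lance'):
    """Raise ValueError if the path suffix does not match the chosen format."""
    path = Path(path)
    expected = EXPECTED_SUFFIX.get(fmt)
    if expected is None:
        raise ValueError(f'no suffix convention recorded for format={fmt!r}')
    if path.suffix != expected:
        raise ValueError(f'{path.name}: expected a {expected} path for format={fmt!r}')
    return path


print(check_dataset_path('datasets/tutorial_pusht.h5', fmt='hdf5'))
```

Call it on the path you are about to pass to world.collect() so a mismatch fails before any episodes are recorded.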

Inspect the dataset

Datasets in $STABLEWM_HOME/datasets can be inspected by name:

swm inspect tutorial_pusht

The output includes the detected format, episode count, step count, and stored columns. All numeric entries from world.infos become columns. The standard columns include:

Column                    Meaning
pixels                    Rendered RGB frame, resized to World(image_shape=...).
action                    Action applied at the current step.
reward                    Environment reward.
terminated / truncated    Episode-end flags.
state / proprio           Low-dimensional state, when the environment provides it.
step_idx                  Step index within the episode.
variation.*               Values of the factors of variation requested in reset options.

For Lance datasets, dots in field names are normalized to underscores on disk because Lance reserves . for struct paths.
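When you look for a variation column in a Lance dataset, search for the underscored name. A sketch of the rule, assuming it is a plain character replacement (the library may treat edge cases, such as names that already contain underscores, differently); lance_field_name is a hypothetical helper:

```python
def lance_field_name(name: str) -> str:
    """Map a logical column name to its assumed on-disk Lance field name."""
    return name.replace('.', '_')


for col in ['pixels', 'variation.agent.color', 'variation.block.angle']:
    print(f'{col:24} -> {lance_field_name(col)}')
```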

Load training clips

Load fixed-length clips with swm.data.load_dataset():

import stable_worldmodel as swm

dataset = swm.data.load_dataset(
    'tutorial_pusht.lance',
    num_steps=8,
    frameskip=1,
    keys_to_load=['pixels', 'action', 'state'],
)

sample = dataset[0]
print(sample.keys())
print(sample['pixels'].shape)  # (T, C, H, W)
print(sample['action'].shape)  # (T, action_dim)

Then use the dataset with a standard PyTorch DataLoader:

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

batch = next(iter(loader))
print(batch['pixels'].shape)  # (B, T, C, H, W)

Lance datasets switch multiprocessing to the spawn start method for fork-safety. If you use num_workers > 0, run the DataLoader code from a normal Python file under an if __name__ == '__main__': guard. For notebooks, REPLs, or heredoc snippets, set num_workers=0.
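The extra B axis in the batch comes from the DataLoader's default collate function, which stacks each key's per-sample tensors along a new leading dimension. The NumPy sketch below mimics that behavior for dict samples. collate_clips is a hypothetical stand-in, and the shapes (T=8, 96x96 frames, action_dim=2) are placeholders rather than values read from the dataset:

```python
import numpy as np


def collate_clips(samples):
    """Stack a list of per-clip dicts into one batch dict along a new leading B axis.

    Simplified stand-in for torch's default collate on dict samples.
    """
    return {key: np.stack([s[key] for s in samples]) for key in samples[0]}


T, C, H, W, action_dim = 8, 3, 96, 96, 2  # placeholder sizes
samples = [
    {'pixels': np.zeros((T, C, H, W)), 'action': np.zeros((T, action_dim))}
    for _ in range(4)
]
batch = collate_clips(samples)
print(batch['pixels'].shape)  # (4, 8, 3, 96, 96)
print(batch['action'].shape)  # (4, 8, 2)
```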

num_steps is the number of time steps in each clip returned by __getitem__ (the T axis above). frameskip sets the stride between observation frames; the actions taken between kept frames are still returned, so action sequences stay dense.
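To make the frameskip arithmetic concrete, here is one plausible index layout for a single clip: observations every frameskip raw steps, and every raw action in the covered span grouped into one chunk per observation. This is an assumption for illustration only; how the library actually packs the action chunk (for example, concatenated along the action axis) is not specified here. clip_indices is a hypothetical helper:

```python
def clip_indices(start, num_steps, frameskip):
    """Hypothetical index layout for one clip starting at raw step `start`."""
    # One observation every `frameskip` raw steps.
    obs_idx = [start + k * frameskip for k in range(num_steps)]
    # Every raw action is kept: one chunk of `frameskip` actions per observation.
    action_idx = [list(range(t, t + frameskip)) for t in obs_idx]
    span = num_steps * frameskip  # raw steps the clip consumes
    return obs_idx, action_idx, span


obs, acts, span = clip_indices(start=10, num_steps=4, frameskip=3)
print(obs)   # [10, 13, 16, 19]
print(acts)  # [[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21]]
print(span)  # 12
```

With frameskip=1 every chunk holds a single action, which matches the (T, action_dim) shape shown earlier.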

Convert formats

Convert the Lance dataset to a video-backed dataset:

swm convert tutorial_pusht tutorial_pusht_video --dest-format video

Or convert in Python:

from pathlib import Path
import os

import stable_worldmodel as swm

root = Path(os.environ.get('STABLEWM_HOME', Path.home() / '.stable_worldmodel'))

swm.data.convert(
    root / 'datasets' / 'tutorial_pusht.lance',
    root / 'datasets' / 'tutorial_pusht_video',
    dest_format='video',
    fps=30,
)

Use video when you want compact, easy-to-watch rollouts. Use lance when you want fast random access during training.

Common checks

  • If swm inspect tutorial_pusht cannot find the dataset, confirm that $STABLEWM_HOME has the same value in the shell that ran the collection script and in the shell running swm inspect.
  • If a policy returns the wrong shape, check world.envs.action_space.shape. Vectorized worlds expect actions shaped like the batched action space.
  • If you need exact reproducibility for out-of-distribution (OOD) experiments, list every factor you randomize in options['variation']; only the variation values listed there are written to the dataset.
  • If you collect outside $STABLEWM_HOME/datasets, pass an absolute path to swm.data.load_dataset().
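For the action-shape check, here is a minimal sketch of a random policy that returns one action per environment. It assumes a box-shaped action space described by per-dimension low/high bounds; the function name and the [-1, 1] bounds are hypothetical placeholders, so read the real bounds from the environment's action space:

```python
import numpy as np


def random_batched_actions(low, high, num_envs, rng):
    """Sample one action per environment from per-dimension bounds (hypothetical helper)."""
    low = np.asarray(low, dtype=np.float32)
    high = np.asarray(high, dtype=np.float32)
    # Bounds broadcast across the leading num_envs axis.
    actions = rng.uniform(low, high, size=(num_envs, *low.shape)).astype(np.float32)
    # A vectorized world expects (num_envs, *single_action_shape).
    assert actions.shape == (num_envs, *low.shape), actions.shape
    return actions


rng = np.random.default_rng(0)
actions = random_batched_actions(low=[-1.0, -1.0], high=[1.0, 1.0], num_envs=8, rng=rng)
print(actions.shape)  # (8, 2)
```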