stable_worldmodel ships a small, pluggable data layer built around a
format registry. A format is a recipe for reading and writing a
particular on-disk layout (HDF5, a folder of frames, MP4 episodes, …). All
built-in datasets — and any custom one you write — go through the same
registry, so the rest of the library (e.g. World.collect, swm.data.load_dataset,
swm.data.convert) doesn't care which backend you pick.
┌──────────────────────────────────────────────┐
│               FORMAT REGISTRY                │
│  lance │ hdf5 │ folder │ video │ lerobot │ … │
└──────────────────────────────────────────────┘
               ▲                 ▲
detect ────────┘                 └───── @register_format
            open_reader / open_writer
                        │
          ┌─────────────┴─────────────┐
          ▼                           ▼
  load_dataset(...)      World.collect(..., format=)
  convert(src, dst)
[ Quick Tour ]
import stable_worldmodel as swm
# 1) Record some episodes — World picks the writer registered under `format`.
world = swm.World('swm/PushT-v1', num_envs=4, image_shape=(64, 64))
world.set_policy(swm.policy.RandomPolicy(seed=0))
world.collect('data/pusht.lance', episodes=20, seed=0) # lance (default)
world.collect('data/pusht_video', episodes=20, format='video') # mp4 + npz
# 2) Load any dataset — format is autodetected from the path.
ds = swm.data.load_dataset('data/pusht.lance', num_steps=4, frameskip=1)
sample = ds[0]
print(sample['pixels'].shape, sample['action'].shape) # (4, 3, H, W) (4, A)
# 3) Switch backends without changing your collection code.
swm.data.convert('data/pusht.lance', 'data/pusht_from_lance', dest_format='video', fps=30)
load_dataset accepts:
- a local path (file or directory),
- a HuggingFace dataset repo (<user>/<repo>), auto-downloaded and cached under $STABLEWM_HOME/datasets/,
- a scheme-prefixed identifier (e.g. lerobot://lerobot/pusht), handed straight to the matching format.
ds = swm.data.load_dataset('lerobot://lerobot/pusht',
primary_camera_key='observation.images.top',
num_steps=8)
[ Storage Formats ]
All built-in formats expose the same Dataset API: each item is a dict of
tensors stacked across num_steps, with image columns transposed to
(T, C, H, W).
The lance format stores episodes as a single LanceDB table laid out
in episode-contiguous flat rows. Image columns are kept as JPEG blobs in
pa.binary columns; tabular columns become fixed-size lists of float32.
Two index columns (episode_idx, step_idx) let the reader recover
episode boundaries by scanning a single column. It's the default for
World.collect.
Why this is the default for WM training. World-model and VLA workloads mix large image trajectories with shuffled, random-access reads — a pattern that suffers under one-file-per-episode formats. This backend was chosen for two properties:
- Column-projected reads. keys_to_load=['pixels', 'action'] fetches only those columns; unused fields (state, reward, additional cameras) are never read or decoded.
- Streaming from object storage. path accepts an s3://, gs://, or az:// URI (credentials via connect_kwargs); the reader pulls byte ranges on demand, so a training run does not need a local copy of a multi-TB dataset.
Combined with row-level random access via episode_idx / step_idx,
shuffled DataLoader workers stay fed without per-episode file-open
overhead — which is the bottleneck at the dataset sizes typical of
modern VLA training.
Layout:
dataset.lance/ # LanceDB table directory
├── _versions/ # transaction log
├── data/ # column fragments (Arrow IPC)
└── _indices/ # secondary indexes (none by default)
Read:
import stable_worldmodel as swm
ds = swm.data.load_dataset('data/pusht.lance', num_steps=16, frameskip=1,
keys_to_load=['pixels', 'action'])
Write:
from stable_worldmodel.data import LanceWriter
with LanceWriter('data/pusht.lance') as w: # mode='append' default
    for ep in episodes:
        w.write_episode(ep)
World.collect(path, episodes=...) defaults to this format.
Field names containing . are renamed to _ on disk because Lance reserves
. as a struct-path separator (pixels.top → pixels_top). The reader only
knows the renamed columns, so request pixels_top (not pixels.top) in keys_to_load.
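As a quick illustration, the rename can be reproduced with a one-line mapping. This is a minimal sketch assuming the only transformation is replacing every `.` with `_`; `lance_column_name` is a hypothetical helper, not a stable_worldmodel API.

```python
def lance_column_name(field: str) -> str:
    """Hypothetical helper mirroring the on-disk rename.

    Lance uses '.' as a struct-path separator, so every '.' in a field
    name is assumed to be stored as '_'.
    """
    return field.replace(".", "_")


# Translate collection-time keys into the column names to pass as keys_to_load.
requested = ["pixels.top", "action"]
keys_to_load = [lance_column_name(k) for k in requested]
print(keys_to_load)  # ['pixels_top', 'action']
```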
The hdf5 format stores everything in a single .h5 file with one
dataset per column plus the ep_len/ep_offset index. Useful when you
want a portable single-file artifact.
Layout:
dataset.h5
├── pixels # (Total_Steps, H, W, C) uint8
├── action # (Total_Steps, Action_Dim) float32
├── reward # (Total_Steps,) float32
├── ep_len # (Num_Episodes,) int32
└── ep_offset # (Num_Episodes,) int64
Read:
import stable_worldmodel as swm
ds = swm.data.load_dataset('data/pusht.h5', num_steps=16, frameskip=1,
keys_to_load=['pixels', 'action'])
Write:
from stable_worldmodel.data import HDF5Writer
with HDF5Writer('data/pusht.h5') as w:
    for ep in episodes:  # ep = {col: [step_arr, ...]}
        w.write_episode(ep)
To use this format with World.collect, pass format='hdf5' explicitly.
The folder format keeps tabular columns as .npz arrays and image
columns as one JPEG per step. It's great when you want to inspect frames
on disk or stream a few keys without paying HDF5's open cost.
Layout:
dataset/
├── ep_len.npz # (N,) int32
├── ep_offset.npz # (N,) int64
├── action.npz # (Total_Steps, A)
├── reward.npz # (Total_Steps,)
└── pixels/ # one image per step
├── ep_0_step_0.jpeg
├── ep_0_step_1.jpeg
└── ...
Read: image columns are inferred from subdirectories, so folder_keys
is rarely needed.
ds = swm.data.load_dataset('data/pusht_folder/', num_steps=4)
Write: any uint8 (H, W, 3) or (H, W, 1) array is auto-detected
as an image column and saved as JPEG.
from stable_worldmodel.data import FolderWriter
with FolderWriter('data/pusht_folder') as w:
    w.write_episode({'pixels': frames, 'action': actions})
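The image auto-detection rule can be sketched as a per-step check on each column. `is_image_step` is a hypothetical helper that restates the documented rule (uint8 with shape (H, W, 3) or (H, W, 1)); FolderWriter's own check may handle further cases.

```python
import numpy as np


def is_image_step(step) -> bool:
    """True if one step of a column matches the documented image rule."""
    arr = np.asarray(step)
    return arr.dtype == np.uint8 and arr.ndim == 3 and arr.shape[-1] in (1, 3)


# One episode: 5 RGB frames and 5 two-dimensional actions.
episode = {
    "pixels": np.zeros((5, 64, 64, 3), dtype=np.uint8),  # would become JPEGs
    "action": np.zeros((5, 2), dtype=np.float32),        # would go to action.npz
}
# Decide per column by inspecting its first step.
image_cols = [col for col, steps in episode.items() if is_image_step(steps[0])]
print(image_cols)  # ['pixels']
```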
The video format is identical to folder for tabular columns, but
encodes each image column as one MP4 per episode. Because MP4 compresses
across consecutive frames, it is a good fit for long episodes where storing
raw JPEGs is wasteful; frames are decoded with decord (or PyAV where decord ships no wheel).
Layout:
dataset/
├── ep_len.npz, ep_offset.npz, action.npz, ...
└── video/
├── ep_0.mp4
└── ep_1.mp4
Read / Write:
ds = swm.data.load_dataset('data/pusht_video/', num_steps=8)
# direct write
from stable_worldmodel.data import VideoWriter
with VideoWriter('data/pusht_video', fps=30, codec='libx264') as w:
    w.write_episode(episode)
# or via World
world.collect('data/pusht_video', episodes=100, format='video')
video requires an optional decoding backend for reading (decord, or PyAV on
platforms such as macOS arm64 where decord has no wheel) and imageio (with an
FFmpeg backend) for writing.
The lerobot format is a read-only adapter over
lerobot.datasets.LeRobotDataset.
It's identified by the lerobot:// scheme and exposes the same episode-based
API as the native SWM datasets: by default the primary camera is mapped to
pixels, action to action, and observation.state to proprio.
ds = swm.data.load_dataset(
    'lerobot://lerobot/pusht',
    primary_camera_key='observation.images.top',  # → 'pixels'
    num_steps=8,
    keys_to_load=['pixels', 'action', 'proprio', 'ep_idx', 'step_idx'],
    keys_to_cache=['action', 'proprio', 'ep_idx', 'step_idx'],
)
LeRobot support is feature-gated to Python 3.12+ because the upstream
lerobot package requires it. Install with
pip install 'stable-worldmodel[lerobot]'. There is no lerobot writer —
mapping arbitrary World info dicts onto LeRobot's schema is not supported.
GoalDataset wraps any of the formats above to add a sampled goal
observation per item, for goal-conditioned learning. Goals are drawn from
one of four buckets (random, geometric future, uniform future, current)
according to a probability vector.
from stable_worldmodel.data import GoalDataset
base = swm.data.load_dataset('data/pusht.h5', num_steps=4)
goal = GoalDataset(
    base,
    goal_probabilities=(0.3, 0.5, 0.0, 0.2),  # random, geom. future, uniform future, current
    gamma=0.99,
    seed=42,
)
item = goal[0] # adds 'goal_pixels', 'goal_proprio'
[ Write Modes ]
Every built-in writer accepts a standard mode kwarg with three values
(default: 'append'):
| mode | when target exists | when target is missing |
|---|---|---|
| 'append' | extend with new episodes (validates schema) | create from scratch |
| 'overwrite' | drop existing data, then write fresh | create from scratch |
| 'error' | raise FileExistsError | create from scratch |
'append' is the default so that re-running a collection script naturally
extends the dataset rather than failing or wiping prior work. Lengths,
offsets, and episode indexes are resumed from the on-disk state. Image
columns continue to use the next available ep_idx (or row offset) so
appended frames don't clobber existing files.
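To make the resume behavior concrete, here is a sketch against the ep_len / ep_offset index used by the hdf5 and folder layouts. It assumes episodes are stored contiguously in flat storage, so the next episode starts right after the last stored row; `next_episode_slot` and `append_episode` are hypothetical helpers, not writer methods.

```python
import numpy as np


def next_episode_slot(ep_len: np.ndarray, ep_offset: np.ndarray) -> tuple:
    """Return (next_ep_idx, next_row_offset) for an append."""
    if len(ep_len) == 0:
        return 0, 0  # empty or missing dataset: start from scratch
    return len(ep_len), int(ep_offset[-1] + ep_len[-1])


def append_episode(ep_len, ep_offset, new_len):
    """Extend the index arrays with one episode of `new_len` steps."""
    ep_idx, offset = next_episode_slot(ep_len, ep_offset)
    ep_len = np.append(ep_len, np.int32(new_len))
    ep_offset = np.append(ep_offset, np.int64(offset))
    return ep_idx, ep_len, ep_offset


ep_len = np.array([10, 7], dtype=np.int32)
ep_offset = np.array([0, 10], dtype=np.int64)
ep_idx, ep_len, ep_offset = append_episode(ep_len, ep_offset, 12)
print(ep_idx, ep_len.tolist(), ep_offset.tolist())  # 2 [10, 7, 12] [0, 10, 17]
```

Here the appended episode would get ep_idx 2, so its frames land under names like ep_2_step_0.jpeg instead of overwriting episode 0 or 1.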
Schema is validated against the on-disk dataset before any new bytes are
written: adding, removing, or retyping a column (image vs. tabular), or
changing its per-step shape, raises a ValueError. The three modes in practice:
from stable_worldmodel.data import HDF5Writer
with HDF5Writer('data/pusht.h5') as w:  # extends existing
    w.write_episode(ep)
with HDF5Writer('data/pusht.h5', mode='overwrite') as w:  # starts fresh
    w.write_episode(ep)
with HDF5Writer('data/pusht.h5', mode='error') as w:  # raises if it exists
    w.write_episode(ep)
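The schema check described above can be sketched as follows, assuming the on-disk schema is summarized as a mapping from column to (kind, per-step shape). `episode_schema` and `check_append_schema` are hypothetical names, and the image/tabular split reuses the uint8 (H, W, 3) / (H, W, 1) rule from the folder format; the writers' internal representation may differ.

```python
import numpy as np


def episode_schema(episode: dict) -> dict:
    """Summarize an episode as {column: (kind, per_step_shape)}."""
    schema = {}
    for col, steps in episode.items():
        first = np.asarray(steps[0])
        is_image = first.dtype == np.uint8 and first.ndim == 3 and first.shape[-1] in (1, 3)
        schema[col] = ("image" if is_image else "tabular", first.shape)
    return schema


def check_append_schema(on_disk: dict, episode: dict) -> None:
    """Raise ValueError if `episode` does not match the stored schema."""
    new = episode_schema(episode)
    added = sorted(set(new) - set(on_disk))
    removed = sorted(set(on_disk) - set(new))
    if added or removed:
        raise ValueError(f"column mismatch: added={added}, removed={removed}")
    for col, (kind, shape) in new.items():
        old_kind, old_shape = on_disk[col]
        if kind != old_kind:
            raise ValueError(f"column {col!r} retyped: {old_kind} -> {kind}")
        if shape != tuple(old_shape):
            raise ValueError(f"column {col!r}: per-step shape {shape} != {tuple(old_shape)}")


on_disk = {"pixels": ("image", (64, 64, 3)), "action": ("tabular", (2,))}
ep = {
    "pixels": np.zeros((3, 64, 64, 3), dtype=np.uint8),
    "action": np.zeros((3, 2), dtype=np.float32),
}
check_append_schema(on_disk, ep)  # matches, so nothing is raised
```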
The same mode kwarg works for FolderWriter, VideoWriter, and
LanceWriter, and is forwarded by World.collect(...) and
swm.data.convert(...) via their writer-kwarg passthrough.
[ Converting Between Formats ]
convert() walks each episode of a source dataset and writes it through the
writer of dest_format. Source format is autodetected unless you pass
source_format=.
from stable_worldmodel.data import convert
# HDF5 → MP4 directory (fps forwarded to VideoWriter)
convert('data/pusht.h5', 'data/pusht_video',
dest_format='video', fps=30)
# Folder → HDF5 (good for shrinking many JPEGs into one file)
convert('data/pusht_folder', 'data/pusht.h5', dest_format='hdf5')
This composes with load_dataset's resolution rules, so you can convert
straight from a HuggingFace repo or a lerobot:// URL:
convert('lerobot://lerobot/pusht', 'data/pusht_local',
source_format='lerobot', dest_format='video',
primary_camera_key='observation.images.top')
[ Registering a Custom Format ]
A format is a class with a name attribute and up to three classmethods:
detect, plus open_reader and/or open_writer. Decorate it with
@register_format and the rest of the stack picks it up.
from pathlib import Path
from stable_worldmodel.data import Format, register_format
from stable_worldmodel.data.dataset import Dataset
@register_format
class Parquet(Format):
    name = 'parquet'
    @classmethod
    def detect(cls, path):
        return Path(path).suffix == '.parquet'
    @classmethod
    def open_reader(cls, path, **kw):
        return ParquetDataset(path, **kw)  # your own Dataset subclass
    @classmethod
    def open_writer(cls, path, **kw):
        return ParquetWriter(path, **kw)  # your own writer: __enter__/__exit__/write_episode
Once imported, your format is usable everywhere:
swm.data.load_dataset('foo.parquet') # reader
world.collect('foo.parquet', episodes=10, format='parquet') # writer
swm.data.list_formats()  # e.g. ['lance', 'hdf5', 'folder', 'video', 'lerobot', 'parquet']
Read-only formats simply omit open_writer; write-only formats omit
open_reader. Calling a method that a format does not implement raises a
clear error by default. If your writer should participate in the standard
mode contract, accept a mode kwarg and delegate validation to validate_write_mode:
from stable_worldmodel.data.format import validate_write_mode, WRITE_MODES
class ParquetWriter:
    def __init__(self, path, *, mode: str = 'append'):
        validate_write_mode(mode)  # rejects values outside WRITE_MODES
        self.mode = mode
        ...
[ Base Class ]
Dataset
Dataset(
lengths: ndarray,
offsets: ndarray,
frameskip: int = 1,
num_steps: int = 1,
transform: Callable[[dict], dict] | None = None,
)
Base class for episode-based datasets.
Subclasses fill in column_names and _load_slice; everything else
(clip indexing, __getitem__, load_chunk, load_episode) is
derived here.
Parameters:
- lengths (ndarray) – Episode lengths.
- offsets (ndarray) – Episode start offsets in the underlying flat storage.
- frameskip (int, default: 1) – Stride between observation samples.
- num_steps (int, default: 1) – Number of observation steps per sample.
- transform (Callable[[dict], dict] | None, default: None) – Optional dict-in / dict-out transform applied per sample.
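The derived clip indexing can be illustrated with a small sketch. It assumes a clip covers num_steps * frameskip consecutive raw steps inside one episode (the span implied by the dense (history_len, frameskip * action_dim) action layout described for ReplayBuffer) and that observations are read every frameskip steps. `build_clip_index` and `observation_rows` are hypothetical helpers, not methods of Dataset.

```python
import numpy as np


def build_clip_index(lengths, num_steps=1, frameskip=1):
    """Return an (N, 2) int64 array of (episode, start_step), one row per valid clip."""
    span = num_steps * frameskip  # raw steps one clip covers (assumption)
    pairs = [
        (ep, start)
        for ep, length in enumerate(lengths)
        for start in range(max(int(length) - span + 1, 0))
    ]
    return np.asarray(pairs, dtype=np.int64).reshape(-1, 2)


def observation_rows(clip_index, offsets, i, num_steps=1, frameskip=1):
    """Flat-storage rows holding the observations of clip i."""
    ep, start = clip_index[i]
    first = int(offsets[ep]) + int(start)
    return first + frameskip * np.arange(num_steps)


lengths, offsets = [5, 3], [0, 5]
index = build_clip_index(lengths, num_steps=2, frameskip=2)
print(index.tolist())  # [[0, 0], [0, 1]]  (episode 1 is too short for a 4-step span)
print(observation_rows(index, offsets, 1, num_steps=2, frameskip=2).tolist())  # [1, 3]
```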
[ Implementations ]
LanceDataset
LanceDataset(
path: str | Path | None = None,
table_name: str | None = None,
*,
uri: str | None = None,
frameskip: int = 1,
num_steps: int = 1,
transform: Callable[[dict], dict] | None = None,
keys_to_load: list[str] | None = None,
keys_to_cache: list[str] | None = None,
keys_to_merge: dict[str, list[str] | str] | None = None,
image_columns: list[str] | None = None,
episode_index_column: str = 'episode_idx',
step_index_column: str = 'step_idx',
connect_kwargs: dict[str, Any] | None = None,
)
Bases: Dataset
Reader for a LanceDB table written by LanceWriter.
Parameters:
- path (str | Path | None, default: None) – Either a .lance directory path or a database URI.
- table_name (str | None, default: None) – Table inside the database; inferred from a .lance path when omitted.
- image_columns (list[str] | None, default: None) – Override image-column auto-detection (any pa.binary column is treated as an encoded image by default).
- episode_index_column, step_index_column – Index column names.
- connect_kwargs (dict[str, Any] | None, default: None) – Forwarded to lancedb.connect (e.g. S3 credentials).
HDF5Dataset
HDF5Dataset(
name: str | None = None,
frameskip: int = 1,
num_steps: int = 1,
transform: Callable[[dict], dict] | None = None,
keys_to_load: list[str] | None = None,
keys_to_cache: list[str] | None = None,
keys_to_merge: dict[str, list[str] | str] | None = None,
cache_dir: str | Path | None = None,
path: str | Path | None = None,
storage_options: dict | None = None,
)
Bases: Dataset
Dataset loading from a single HDF5 file (SWMR mode for safe reads).
For remote paths (s3://, gs://, etc.), pass storage_options
that fsspec recognises for the chosen scheme. The file handle is opened
lazily per-worker, so DataLoader multiprocessing is supported.
FolderDataset
FolderDataset(
name: str | None = None,
frameskip: int = 1,
num_steps: int = 1,
transform: Callable[[dict], dict] | None = None,
keys_to_load: list[str] | None = None,
folder_keys: list[str] | None = None,
cache_dir: str | Path | None = None,
path: str | Path | None = None,
)
Bases: Dataset
Dataset loading from a folder structure.
Tabular columns are stored as .npz files; image columns are stored as
one image file per step under <key>/ep_<i>_step_<j>.jpeg.
VideoDataset
Bases: FolderDataset
Loads frames from MP4 files (one per episode).
Frames are decoded with decord where it ships wheels (Linux, Windows)
and with PyAV (av) elsewhere, notably macOS arm64, where decord has
no wheel. Both backends share the same reader API, so only the backend
selection differs across platforms.
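A sketch of platform-dependent backend selection, approximating the rule above by probing which package is installed rather than checking the platform. `pick_video_backend` is hypothetical; the library's actual selection logic may differ.

```python
import importlib.util


def pick_video_backend() -> str:
    """Return 'decord' if importable, else 'av' (PyAV); raise if neither is.

    find_spec only locates the package, so neither decoder is imported here.
    """
    for module in ("decord", "av"):
        if importlib.util.find_spec(module) is not None:
            return module
    raise ImportError("reading MP4 episodes needs decord or PyAV (pip install av)")
```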
ImageDataset
LeRobotAdapter
LeRobotAdapter(
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
frameskip: int = 1,
num_steps: int = 1,
transform: Callable[[dict], dict] | None = None,
keys_to_load: list[str] | None = None,
keys_to_cache: list[str] | None = None,
primary_camera_key: str | None = None,
key_aliases: dict[str, str] | None = None,
**lerobot_kwargs: Any,
)
[ Replay Buffer ]
ReplayBuffer is an in-memory ring-storage buffer that subclasses Dataset
and implements the Writer protocol — the same object can be filled by a
rollout (Writer side) and iterated by a DataLoader (Dataset side), so
collection and training can interleave without copying data. It evicts
whole oldest episodes FIFO when the next write would exceed max_steps,
and dump(path, format=...) persists current contents through any
registered format writer.
See the online-learning guide for fill/sample/dump examples and a step-conditioned sampler walkthrough.
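The eviction policy can be sketched with a deque of episodes and a running step count. `FifoEpisodeStore` is a hypothetical, storage-free stand-in for ReplayBuffer's bookkeeping, and rejecting an episode longer than max_steps is an assumption, since that case is not documented here.

```python
from collections import deque


class FifoEpisodeStore:
    """Hypothetical sketch of whole-episode FIFO eviction under a step budget."""

    def __init__(self, max_steps: int):
        self.max_steps = max_steps
        self.episodes = deque()  # oldest episode on the left
        self.num_steps = 0

    def write_episode(self, episode: list) -> None:
        if len(episode) > self.max_steps:
            # Assumption: an episode that can never fit is rejected.
            raise ValueError("episode is longer than max_steps")
        # Evict whole oldest episodes until the new one fits.
        while self.num_steps + len(episode) > self.max_steps:
            self.num_steps -= len(self.episodes.popleft())
        self.episodes.append(episode)
        self.num_steps += len(episode)


store = FifoEpisodeStore(max_steps=10)
store.write_episode(["a"] * 4)
store.write_episode(["b"] * 4)
store.write_episode(["c"] * 5)  # 13 > 10, so episode "a" is evicted
print([ep[0] for ep in store.episodes], store.num_steps)  # ['b', 'c'] 9
```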
ReplayBuffer
ReplayBuffer(
max_steps: int,
history_len: int = 1,
frameskip: int = 1,
sampler: Sampler | None = None,
transform: Callable[[dict], dict] | None = None,
key_filter: Callable[[dict], dict] | None = None,
)
Bases: Dataset
In-memory ring-storage replay buffer.
Parameters:
- max_steps (int) – Capacity in steps. Whole episodes are evicted FIFO when adding a new episode would exceed this.
- history_len (int, default: 1) – Steps per clip returned by __getitem__ (the Dataset path) and the default for sample(...). Equivalent to Dataset's num_steps.
- frameskip (int, default: 1) – Stride between observation samples within a clip. Action columns are kept dense and reshaped to (history_len, frameskip * action_dim), matching FolderDataset.
- sampler (Sampler | None, default: None) – fn(step, buffer, batch_size, history_len) -> indices returning flat clip indices in [0, buffer.num_valid_ends(history_len)). Default is uniform.
- transform (Callable[[dict], dict] | None, default: None) – Optional dict-in / dict-out transform applied per clip in the Dataset path (__getitem__).
- key_filter (Callable[[dict], dict] | None, default: None) – Optional fn(ep_data) -> ep_data applied to each episode in write_episode, returning the subset of columns to store. None (default) stores every column.
[ Wrappers ]
GoalDataset
GoalDataset(
dataset: Dataset,
goal_probabilities: tuple[
float, float, float, float
] = (0.3, 0.5, 0.0, 0.2),
gamma: float = 0.99,
current_goal_offset: int | None = None,
goal_keys: dict[str, str] | None = None,
seed: int | None = None,
)
Wrap any dataset to return a sampled goal observation per item.
Goals are sampled from one of
- random state (uniform over all dataset steps)
- geometric future state in same episode (Geom(1-gamma))
- uniform future state in same episode
- current state
with probabilities (0.3, 0.5, 0.0, 0.2) by default.
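The four-bucket goal sampling can be sketched on flat step indices. `sample_goal_step` is hypothetical; it assumes future goals are clamped to the last step of the current episode and that the uniform-future bucket draws from the steps strictly after t (falling back to t at the episode end). GoalDataset's exact boundary handling may differ.

```python
import numpy as np


def sample_goal_step(rng, t, ep_start, ep_len, total_steps,
                     probs=(0.3, 0.5, 0.0, 0.2), gamma=0.99):
    """Return a flat step index to use as the goal for flat step t.

    Buckets: 0 random, 1 geometric future, 2 uniform future, 3 current.
    """
    last = ep_start + ep_len - 1  # last flat step of t's episode
    bucket = rng.choice(4, p=probs)
    if bucket == 0:  # uniform over every step in the dataset
        return int(rng.integers(0, total_steps))
    if bucket == 1:  # t + k with k ~ Geom(1 - gamma), k >= 1, clamped
        return int(min(t + rng.geometric(1.0 - gamma), last))
    if bucket == 2:  # uniform over the remaining steps of this episode
        return int(rng.integers(min(t + 1, last), last + 1))
    return int(t)  # current state


rng = np.random.default_rng(0)
goals = [sample_goal_step(rng, t=5, ep_start=0, ep_len=10, total_steps=50) for _ in range(200)]
```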
MergeDataset
ConcatDataset
Concatenate datasets sequentially (vertical join, more episodes).
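Sequential concatenation usually reduces to mapping a global index onto (dataset, local index) with cumulative sizes and a binary search, the same technique PyTorch's torch.utils.data.ConcatDataset uses. `ConcatIndex` is a hypothetical sketch of that bookkeeping, not ConcatDataset's implementation here.

```python
import bisect
from itertools import accumulate


class ConcatIndex:
    """Map a global item index onto (dataset position, local index)."""

    def __init__(self, sizes):
        self.cumulative = list(accumulate(sizes))  # sizes [3, 5] -> [3, 8]

    def __len__(self):
        return self.cumulative[-1] if self.cumulative else 0

    def locate(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First dataset whose cumulative size exceeds idx (skips empty ones).
        ds = bisect.bisect_right(self.cumulative, idx)
        start = self.cumulative[ds - 1] if ds > 0 else 0
        return ds, idx - start


index = ConcatIndex([3, 5])
print(index.locate(2), index.locate(3), index.locate(7))  # (0, 2) (1, 0) (1, 4)
```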
[ Normalization ]
Per-column scalers for normalizing heterogeneous numeric columns (actions, proprio, state) before training. Each one implements three interfaces at once:
- sklearn: fit, transform, inverse_transform, fit_transform.
- Transformable protocol (see stable_worldmodel.policy): transform / inverse_transform.
- callable: scaler(x) returns transform(x), cast to float32 when x is a tensor, which is exactly what WrapTorchTransform consumes.
| Method | Class | Behavior |
|---|---|---|
| 'zscore' | ZScoreScaler | (x - mean) / std. Default. Sensitive to outliers. |
| 'percentile' | PercentileScaler | Per-dim [q_low, q_high] → [-1, 1], clipped. Robust. |
| 'none' | IdentityScaler | Pass-through; for columns already in a usable range. |
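The two non-trivial mappings in the table can be written out directly. This NumPy sketch fits per-dimension statistics over the step axis; the function names are hypothetical, and where eps enters each formula is an assumption (both scaler signatures expose an eps, but its placement is not documented).

```python
import numpy as np


def zscore_fit(x):
    """Per-dim mean and std over the step axis."""
    return x.mean(axis=0), x.std(axis=0)


def zscore_transform(x, mean, std, eps=1e-8):
    return (x - mean) / (std + eps)


def percentile_fit(x, low=1.0, high=99.0):
    """Per-dim low/high percentiles over the step axis."""
    return np.percentile(x, low, axis=0), np.percentile(x, high, axis=0)


def percentile_transform(x, q_low, q_high, eps=1e-8):
    # Linear map q_low -> -1 and q_high -> +1, then clip anything outside.
    y = 2.0 * (x - q_low) / (q_high - q_low + eps) - 1.0
    return np.clip(y, -1.0, 1.0)


def percentile_inverse(y, q_low, q_high, eps=1e-8):
    # Exact inverse only for values that were not clipped.
    return (y + 1.0) / 2.0 * (q_high - q_low + eps) + q_low
```

Clipping is what makes the percentile variant robust: an outlier far beyond q_high maps to 1.0 instead of stretching the range for every other sample, at the cost of an inverse that cannot recover clipped values.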
Plug into a transform pipeline. column_normalizer fetches the column,
fits the chosen scaler, and wraps it as a WrapTorchTransform:
import stable_worldmodel as swm
from stable_worldmodel.data import column_normalizer
dataset = swm.data.load_dataset('data/pusht.lance')
transforms = [
column_normalizer(dataset, 'action', 'action', method='zscore'),
column_normalizer(dataset, 'proprio', 'proprio', method='percentile'),
column_normalizer(dataset, 'state', 'state', method='none'),
]
Drive it from a Hydra config with a per-column mapping and 'zscore'
as the default for unlisted columns:
dataset:
  normalizers:
    action: zscore
    proprio: percentile
    state: none
normalizers_cfg = cfg.dataset.get('normalizers', {}) or {}
for col in cols_to_normalize:
    method = normalizers_cfg.get(col, 'zscore')
    transforms.append(column_normalizer(dataset, col, col, method=method))
Use a scaler standalone — handy for evaluation where you need the inverse to recover raw units:
from stable_worldmodel.data import PercentileScaler, get_scaler
scaler = PercentileScaler(low=1.0, high=99.0).fit(train_actions)
a_norm = scaler.transform(action) # → [-1, 1]
a_raw = scaler.inverse_transform(pred) # back to original units
# Or fetch by name, useful when method comes from a config string.
scaler = get_scaler('zscore').fit(train_actions)
column_normalizer
Build a per-column normalizer WrapTorchTransform from dataset stats.
Parameters:
- dataset – A dataset exposing get_col_data(col) returning an array.
- source (str) – Column name to read.
- target (str) – Column name to write.
- method (str, default: 'zscore') – One of 'zscore' (default), 'percentile', or 'none'. 'none' returns a pass-through identity transform so call sites can stay uniform.
Returns:
- A picklable WrapTorchTransform wrapping a fitted scaler.
get_scaler
get_scaler(method: str = 'zscore', **kwargs)
Return an unfitted scaler by method name.
Parameters:
- method (str, default: 'zscore') – One of 'zscore', 'percentile', 'none'.
- **kwargs – Forwarded to the scaler constructor.
Raises:
- ValueError – If method is not registered.
ZScoreScaler
ZScoreScaler(mean=None, std=None, eps: float = 1e-08)
Per-dim z-score scaler: (x - mean) / std.
Stats are stored as numpy arrays so the scaler pickles without dragging torch tensors across process boundaries.
PercentileScaler
PercentileScaler(
low: float = 1.0,
high: float = 99.0,
q_low=None,
q_high=None,
eps: float = 1e-08,
)
Per-dim percentile scaler: maps to [-1, 1] using q_low/q_high
and clips. Robust to outliers compared to z-score.
IdentityScaler
No-op scaler. Use for columns that should pass through unchanged.
[ Format Registry ]
Format
Declarative format spec.
Subclasses set name and implement detect. They typically also
implement open_reader and/or open_writer.
detect_format
Return the first registered format whose detect() matches, else None.
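A minimal sketch of how a decorator-based registry and first-match detection can fit together. `_REGISTRY`, the duplicate-name check, and raising NotImplementedError for missing reader/writer methods are assumptions; this is not stable_worldmodel's implementation, and `H5` is a toy format.

```python
from pathlib import Path

_REGISTRY: dict = {}  # name -> format class, in registration order


def register_format(cls):
    """Decorator: add a format class to the registry under cls.name."""
    if cls.name in _REGISTRY:
        raise ValueError(f"format {cls.name!r} already registered")
    _REGISTRY[cls.name] = cls
    return cls


def detect_format(path):
    """Return the first registered format whose detect() matches, else None."""
    for fmt in _REGISTRY.values():  # dicts preserve insertion order
        if fmt.detect(path):
            return fmt
    return None


class Format:
    name: str = ""

    @classmethod
    def detect(cls, path) -> bool:
        return False

    @classmethod
    def open_reader(cls, path, **kw):
        raise NotImplementedError(f"format {cls.name!r} has no reader")

    @classmethod
    def open_writer(cls, path, **kw):
        raise NotImplementedError(f"format {cls.name!r} has no writer")


@register_format
class H5(Format):
    name = "hdf5"

    @classmethod
    def detect(cls, path):
        return Path(path).suffix in (".h5", ".hdf5")
```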
[ Top-Level Helpers ]
load_dataset
Resolve a dataset name to a local path and dispatch to the matching format reader from the registry.
Supported names:
- Local path: file or directory.
- HuggingFace repo (<user>/<repo>): downloaded and cached under <cache_dir>/datasets/<user>--<repo>/.
- Format scheme (e.g. lerobot://lerobot/pusht): passed through to the matching format unchanged.
The format is auto-detected via detect_format unless format is
provided explicitly. To register a new format, decorate a
stable_worldmodel.data.format.Format subclass with
stable_worldmodel.data.format.register_format.
Parameters:
- name (str) – Local path, HF repo id, or scheme-prefixed identifier.
- cache_dir (str, default: None) – Root cache directory. Defaults to STABLEWM_HOME or ~/.stable_worldmodel.
- format (str | None, default: None) – Explicit format name (skips detection).
- **kwargs – Forwarded to the format's reader.
Returns:
- A reader instance (typically a stable_worldmodel.data.dataset.Dataset subclass).
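The resolution rules can be sketched as a small classifier. `classify_dataset_name` is hypothetical: the check order (scheme, then existing local path, then HF repo id), the repo-id pattern, and raising FileNotFoundError for anything else are assumptions. Only the cache layout <cache_dir>/datasets/<user>--<repo>/ and the STABLEWM_HOME / ~/.stable_worldmodel default come from the text above.

```python
from __future__ import annotations

import os
import re
from pathlib import Path

_HF_REPO = re.compile(r"^[\w.-]+/[\w.-]+$")  # assumed shape of "<user>/<repo>"


def classify_dataset_name(name: str, cache_dir: str | None = None):
    """Return (kind, location) for a load_dataset-style name."""
    if "://" in name:
        return "scheme", name  # handed to the matching format unchanged
    if Path(name).exists():
        return "local", name
    if _HF_REPO.match(name):
        root = (cache_dir or os.environ.get("STABLEWM_HOME")
                or str(Path.home() / ".stable_worldmodel"))
        user, repo = name.split("/")
        return "hf", str(Path(root) / "datasets" / f"{user}--{repo}")
    raise FileNotFoundError(name)
```

Under these assumptions, a relative path such as data/pusht that does not exist yet would be read as a repo id, which is why the sketch checks local paths first.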
convert
convert(
source,
dest,
*,
source_format: str | None = None,
dest_format: str = 'lance',
cache_dir: str | None = None,
progress: bool = True,
**dest_kwargs,
) -> None
Convert a dataset from one registered format to another.
Reads each episode from source and writes it through the writer of
dest_format. Format detection follows the same rules as
load_dataset: autodetect by default, or pass source_format
explicitly.
Parameters:
- source – Path or identifier accepted by load_dataset.
- dest – Output path for the destination writer.
- source_format (str | None, default: None) – Force a source format (skips detection).
- dest_format (str, default: 'lance') – Registered writer name.
- cache_dir (str | None, default: None) – Forwarded to the source loader for HF/local resolution.
- progress (bool, default: True) – Show a progress bar over episodes.
- **dest_kwargs – Forwarded to the destination writer.
Example:
from stable_worldmodel.data import convert
convert('data.lance', 'data_video', dest_format='video')