Runs a policy against a pool of vectorized environments, with on-disk trajectory collection and dataset-driven evaluation.


World is the main entry point for rolling out policies in stable_worldmodel. It bundles:

  1. A batched simulator (EnvPool) that steps num_envs envs in parallel and can skip terminated envs via a mask.
  2. A preprocessing pipeline (MegaWrapper) that resizes pixels, lifts everything into the info dict, and applies optional transforms.
  3. A rollout loop that drives policy.get_action(infos) and handles resets, per-env termination, and episode accounting.

A minimal quick start:

import stable_worldmodel as swm
from stable_worldmodel.policy import RandomPolicy

world = swm.World(
    env_name="swm/PushT-v1",
    num_envs=4,
    image_shape=(64, 64),
)
world.set_policy(RandomPolicy())

# All stacked tensors in world.infos have shape (num_envs, 1, ...).
world.reset(seed=0)
# world.infos["pixels"]  -> (4, 1, 64, 64, 3)

Collect a dataset with a trained or scripted policy:

import stable_worldmodel as swm

world = swm.World("swm/PushT-v1", num_envs=8, image_shape=(64, 64))
world.set_policy(expert_policy)

# Roll out 500 episodes in parallel and write them with the default
# 'lance' writer (pass format=... to use another registered writer).
world.collect("data/pusht_expert", episodes=500, seed=0)

Evaluate the attached policy episodically:

results = world.evaluate(
    episodes=100,
    seed=42,
    video="videos/",          # optional: mp4 per episode
)

print(f"Success rate: {results['success_rate']:.1f}%")

Dataset-driven evaluation seeds each env from a recorded episode:

# One env per target episode. Each env starts at the chosen step and aims
# for the state `goal_offset` timesteps later. Each run is capped at
# `eval_budget` steps.
results = world.evaluate(
    dataset=dataset,
    episodes_idx=[0, 1, 2, 3],
    start_steps=[0, 10, 20, 30],
    goal_offset=30,
    eval_budget=50,
    video="videos/",
)

reset(options=...) also accepts a list of per-env dicts, e.g. to seed domain randomization or per-env variations:

per_env = [
    {"variation": ["agent.color"], "variation_values": {"agent.color": [255, 0, 0]}},
    {"variation": ["agent.color"], "variation_values": {"agent.color": [0, 255, 0]}},
    {"variation": ["agent.color"], "variation_values": {"agent.color": [0, 0, 255]}},
]
world.reset(options=per_env)

Info convention

Every tensor / array value in world.infos carries a leading time dim of 1 after the env dim:

world.infos["pixels"].shape  # (num_envs, 1, H, W, C)
world.infos["state"].shape   # (num_envs, 1, state_dim)

Non-array values (strings, nested objects) stay as a Python list of length num_envs. rewards, terminateds, and truncateds are returned separately from the last step() and have shape (num_envs,); they do not carry the time dim.
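
As a concrete sketch (assuming the World constructed in the quick start above):

world.reset(seed=0)
pixels = world.infos["pixels"]    # (num_envs, 1, H, W, C)
frames = pixels[:, 0]             # drop the time dim -> (num_envs, H, W, C)
print(world.terminateds.shape)    # (num_envs,), all False right after reset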

Reset modes

evaluate() (and the internal _run()) supports two termination policies:

  • reset_mode='auto' — terminated envs are reset immediately. The run continues until episodes episodes have finished (or max_steps is reached). This is the default for episodic eval.
  • reset_mode='wait' — terminated envs are frozen (step is skipped for them via the env-pool mask). The run stops when all envs are done. This is the default for dataset eval, so every env gets to complete its specific start→goal task.
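
For instance, a hedged sketch that overrides the default for episodic eval:

# freeze each env once it finishes instead of auto-resetting
results = world.evaluate(episodes=world.num_envs, seed=0, reset_mode="wait")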

World

World(
    env_name: str,
    num_envs: int,
    image_shape: tuple[int, int],
    max_episode_steps: int = 100,
    goal_conditioned: bool = True,
    extra_wrappers: list | None = None,
    image_transform: Callable | None = None,
    goal_transform: Callable | None = None,
    image_resample: str | int | None = None,
    **kwargs: Any,
)

Drive a policy through a pool of preprocessed envs.

After construction, world.envs is an EnvPool of num_envs environments, each wrapped by MegaWrapper (and any extra_wrappers you pass). Attach a policy with set_policy(...) and then call collect() or evaluate() to run rollouts.
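
A hedged construction sketch exercising the optional knobs (the identity transform is a placeholder; extra keyword arguments are forwarded to gym.make):

import stable_worldmodel as swm

world = swm.World(
    "swm/PushT-v1",
    num_envs=2,
    image_shape=(96, 96),
    max_episode_steps=200,
    image_resample="nearest",         # crisp resizing, e.g. for pixel-art envs
    image_transform=lambda img: img,  # placeholder: applied to pixels in MegaWrapper
    render_mode="rgb_array",          # forwarded to gym.make via **kwargs
)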

Attributes populated during a run

  • infos – Stacked info dict from the last reset/step. Tensor/array values have shape (num_envs, 1, ...).

  • rewards, terminateds, truncateds – Per-env step outputs from the last step(). Shape (num_envs,).

Parameters:

  • env_name (str) –

    Gymnasium id registered for the target env (e.g. 'swm/PushT-v1').

  • num_envs (int) –

    Number of parallel envs in the pool.

  • image_shape (tuple[int, int]) –

    (H, W) that pixels/goal are resized to.

  • max_episode_steps (int, default: 100 ) –

    Per-env step cap before truncation.

  • goal_conditioned (bool, default: True ) –

    If True, the goal key is kept separate from regular observations (controls MegaWrapper.separate_goal).

  • extra_wrappers (list | None, default: None ) –

    Additional gym.Wrapper factories applied after MegaWrapper.

  • image_transform (Callable | None, default: None ) –

    Optional callable applied to pixels inside MegaWrapper.

  • goal_transform (Callable | None, default: None ) –

    Optional callable applied to the goal inside MegaWrapper.

  • image_resample (str | int | None, default: None ) –

    PIL resample mode for pixel/goal resizing ('nearest', 'bilinear', ...). Defaults to bilinear; use 'nearest' for crisp pixel-art envs (e.g. Craftax).

  • **kwargs (Any, default: {} ) –

    Forwarded to gym.make (e.g. render_mode).

[ Rollouts ]

collect

collect(
    path: str | Path,
    episodes: int,
    seed: int | None = None,
    options: dict | None = None,
    format: str = 'lance',
) -> None

Roll out episodes and dump their trajectories using the writer registered for format.

Each info key becomes a column. Leading length-1 time dims are squeezed. Columns starting with _ (e.g. _needs_flush) are skipped.

Parameters:

  • path (str | Path) –

    Output path (file or directory, depending on the format). Parent dirs are auto-created.

  • episodes (int) –

    Number of episodes to record.

  • seed (int | None, default: None ) –

    Base seed for env resets.

  • options (dict | None, default: None ) –

    Reset options forwarded to envs.reset.

  • format (str, default: 'lance' ) –

    Registered format name (default 'lance'). See :func:stable_worldmodel.data.list_formats for available writers; new formats can be added via :func:stable_worldmodel.data.register_format.
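
A short sketch of choosing a writer; this assumes list_formats() takes no arguments and that the data module is importable as shown:

from stable_worldmodel import data

print(data.list_formats())   # registered writer names, e.g. ['lance', ...]
world.collect("data/pusht_expert", episodes=10, seed=0, format="lance")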

evaluate

evaluate(
    episodes: int | None = None,
    seed: int | None = None,
    options: dict | None = None,
    video: str | Path | None = None,
    reset_mode: str | None = None,
    dataset: Any = None,
    episodes_idx: list[int] | None = None,
    start_steps: list[int] | None = None,
    goal_offset: int | None = None,
    eval_budget: int | None = None,
    callables: list[dict] | None = None,
) -> dict

Run the attached policy and return aggregated metrics.

Two modes of operation:

  • Episodic (default): set episodes to the number of episodes to roll out. Terminated envs are auto-reset until the target count is reached.

  • Dataset-driven: pass dataset with episodes_idx / start_steps / goal_offset / eval_budget. Each env is seeded from one dataset episode, starts at start_steps[i] and targets the state at start_steps[i] + goal_offset. Run length is capped at eval_budget steps. Requires num_envs == len(episodes_idx).

Parameters:

  • episodes (int | None, default: None ) –

    Total episodes to roll out (episodic mode).

  • seed (int | None, default: None ) –

    Base seed. Per-env seeds are derived by offsetting it.

  • options (dict | None, default: None ) –

    Reset options forwarded to envs.reset.

  • video (str | Path | None, default: None ) –

    Directory to write one mp4 per episode/env (optional).

  • reset_mode (str | None, default: None ) –

    'auto' (reset terminated envs) or 'wait' (freeze terminated envs and stop when all are done). Defaults to 'auto' for episodic eval and 'wait' for dataset eval.

  • dataset (Any, default: None ) –

    Source dataset for dataset-driven eval.

  • episodes_idx (list[int] | None, default: None ) –

    Dataset episode indices, one per env.

  • start_steps (list[int] | None, default: None ) –

    Starting step within each dataset episode.

  • goal_offset (int | None, default: None ) –

    Offset from each start step that defines the goal.

  • eval_budget (int | None, default: None ) –

    Max env steps per episode in dataset mode.

  • callables (list[dict] | None, default: None ) –

    Per-env setup calls applied on the unwrapped env after reset. Each spec is {'method': name, 'args': {arg_name: {'value': ..., 'in_dataset': bool}}}; if in_dataset is True, the value names a key in the sliced dataset state and the per-env value is deep-copied in.
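
For illustration, a hedged callables spec; set_damping and its argument names are hypothetical, not part of any shipped env:

callables = [
    {
        "method": "set_damping",                             # hypothetical env method
        "args": {
            "value": {"value": 0.8, "in_dataset": False},    # literal, passed as-is
            "state": {"value": "state", "in_dataset": True}, # looked up in the dataset slice
        },
    }
]
results = world.evaluate(
    dataset=dataset,
    episodes_idx=[0, 1, 2, 3],
    start_steps=[0, 10, 20, 30],
    goal_offset=30,
    eval_budget=50,
    callables=callables,
)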

Returns:

  • dict –

    A dict with 'success_rate' (percent), 'episode_successes' (a per-episode bool/uint array), and 'seeds' (the seeds used for resets).

[ Environment ]

reset

reset(seed=None, options=None) -> None

Reset every env and refresh self.infos.

Clears terminateds/truncateds back to all-False.

set_policy

set_policy(policy: Policy) -> None

Attach a policy and configure it for this world's envs.

Calls policy.set_env(self.envs). If the policy exposes a seed attribute and set_seed method, the seed is applied.
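
A minimal duck-typed sketch (in practice you would subclass the library's Policy base; get_action receives the stacked infos and must return one action per env):

class NoOpPolicy:
    seed = 0                       # picked up by set_policy via set_seed

    def set_seed(self, seed):
        self.seed = seed

    def set_env(self, envs):
        self.envs = envs           # called by set_policy

    def get_action(self, infos):
        # a batched sample from the pool's action space: shape (num_envs, ...)
        return self.envs.action_space.sample()

world.set_policy(NoOpPolicy())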

close

close() -> None

Close all envs and release their resources.

[ Properties ]

num_envs property

num_envs: int

Number of envs in the pool.

EnvPool

The underlying batched simulator. You rarely touch it directly — World builds one for you — but its action and observation spaces are what the policy sees.

EnvPool

EnvPool(env_fns: list)

Batched env runner with selective stepping.

Parameters:

  • env_fns (list) –

    List of zero-arg factories, one per env. Each is called once and the result is kept for the lifetime of the pool.
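
A hedged construction sketch; the EnvPool import path is an assumption:

import gymnasium as gym
import stable_worldmodel as swm              # registers the swm/ env ids

from stable_worldmodel.world import EnvPool  # import path assumed

pool = EnvPool([lambda: gym.make("swm/PushT-v1") for _ in range(4)])
_, infos = pool.reset(seed=0)                # reset returns (None, stacked infos)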

Methods:

  • reset

    Reset envs and return the stacked info dict.

  • step

    Step envs and return (None, rewards, terminateds, truncateds, infos).

  • close

    Close every env in the pool.

Attributes:

num_envs property

num_envs: int

Number of envs in the pool.

action_space property

action_space: Space

Batched action space (batch_space(single_action_space, num_envs)).

single_action_space property

single_action_space: Space

Action space of a single env.

observation_space property

observation_space: Space

Batched observation space.

single_observation_space property

single_observation_space: Space

Observation space of a single env.

variation_space property

variation_space

Variation space from the unwrapped env, or None if not defined.

single_variation_space property

single_variation_space

Variation space for a single env (alias of variation_space).

reset

reset(
    seed: int | list[int | None] | None = None,
    options: dict | list[dict | None] | None = None,
    mask: ndarray | None = None,
) -> tuple[None, dict]

Reset envs and return the stacked info dict.

Parameters:

  • seed (int | list[int | None] | None, default: None ) –

    Base int (each env gets seed + i), a per-env list, or None.

  • options (dict | list[dict | None] | None, default: None ) –

    Shared dict or per-env list.

  • mask (ndarray | None, default: None ) –

    If provided, only envs where mask[i] is truthy are reset. Others keep their current state in the stacked info buffer.
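
For example, a masked reset that reseeds only the even-indexed envs (continuing the pool sketch above):

import numpy as np

mask = np.array([True, False, True, False])
_, infos = pool.reset(seed=[100, None, 102, None], mask=mask)
# envs 1 and 3 keep their slots in the stacked info buffer unchanged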

step

step(
    actions: ndarray, mask: ndarray | None = None
) -> tuple[None, ndarray, ndarray, ndarray, dict]

Step envs and return (None, rewards, terminateds, truncateds, infos).

Parameters:

  • actions (ndarray) –

    Array of shape (num_envs, ...) — one action per env.

  • mask (ndarray | None, default: None ) –

    If provided, only envs where mask[i] is truthy are stepped. Masked envs contribute zero reward and False termination/truncation, and their slot in the stacked info buffer is left unchanged.
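
A minimal 'wait'-style loop built on the step mask (continuing the pool sketch above):

import numpy as np

done = np.zeros(pool.num_envs, dtype=bool)
pool.reset(seed=0)
while not done.all():
    actions = pool.action_space.sample()             # one action per env
    _, rewards, term, trunc, infos = pool.step(actions, mask=~done)
    done |= term | trunc                             # freeze finished envs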

close

close()

Close every env in the pool.