Runs a policy against a pool of vectorized environments, with trajectory collection to pluggable dataset formats (Lance, HDF5) and dataset-driven evaluation.
World is the main entry point for rolling out policies in stable_worldmodel. It bundles:
- A batched simulator (`EnvPool`) that steps `num_envs` envs in parallel and can skip terminated envs via a mask.
- A preprocessing pipeline (`MegaWrapper`) that resizes pixels, lifts everything into the info dict, and applies optional transforms.
- A rollout loop that drives `policy.get_action(infos)` and handles resets, per-env termination, and episode accounting.
```python
import stable_worldmodel as swm
from stable_worldmodel.policy import RandomPolicy

world = swm.World(
    env_name="swm/PushT-v1",
    num_envs=4,
    image_shape=(64, 64),
)
world.set_policy(RandomPolicy())
world.reset(seed=0)

# All stacked tensors in world.infos have shape (num_envs, 1, ...).
# world.infos["pixels"] -> (4, 1, 64, 64, 3)
```
```python
import stable_worldmodel as swm

world = swm.World("swm/PushT-v1", num_envs=8, image_shape=(64, 64))
world.set_policy(expert_policy)

# Roll out 500 episodes in parallel and dump them to an HDF5 file.
# The default writer is 'lance', so pass the matching format name
# (assuming an 'hdf5' writer is registered; see collect() below).
world.collect("data/pusht_expert.h5", episodes=500, seed=0, format="hdf5")

results = world.evaluate(
    episodes=100,
    seed=42,
    video="videos/",  # optional: one mp4 per episode
)
print(f"Success rate: {results['success_rate']:.1f}%")
```
```python
# One env per target episode. Each env starts at the chosen step and aims
# for the state `goal_offset` timesteps later. The run is capped at
# `eval_budget` steps.
results = world.evaluate(
    dataset=dataset,
    episodes_idx=[0, 1, 2, 3],
    start_steps=[0, 10, 20, 30],
    goal_offset=30,
    eval_budget=50,
    video="videos/",
)
```
`reset(options=...)` accepts a list of per-env dicts to seed domain randomization or variations:

```python
per_env = [
    {"variation": ["agent.color"], "variation_values": {"agent.color": [255, 0, 0]}},
    {"variation": ["agent.color"], "variation_values": {"agent.color": [0, 255, 0]}},
    {"variation": ["agent.color"], "variation_values": {"agent.color": [0, 0, 255]}},
]
world.reset(options=per_env)
```
Info convention
Every tensor / array value in `world.infos` carries a leading time dim of 1 after the env dim:

```python
world.infos["pixels"].shape  # (num_envs, 1, H, W, C)
world.infos["state"].shape   # (num_envs, 1, state_dim)
```

Non-array values (strings, nested objects) stay as a Python list of length `num_envs`. `rewards`, `terminateds`, and `truncateds` are returned from the last `step()` separately and have shape `(num_envs,)`; they do not carry the time dim.
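A common pattern is to drop the length-1 time dim before handing observations to a model. A minimal sketch, assuming the stacked values are NumPy arrays (as the shapes above suggest):

```python
import numpy as np

# (num_envs, 1, H, W, C) -> (num_envs, H, W, C)
pixels = np.asarray(world.infos["pixels"])[:, 0]

# Step outputs are already (num_envs,); no squeeze needed.
done = world.terminateds | world.truncateds
```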
Reset modes
`evaluate` (and internally `_run`) supports two termination policies:

- `reset_mode='auto'` – terminated envs are reset immediately. The run continues until `episodes` episodes have finished (or `max_steps` is reached). This is the default for episodic eval.
- `reset_mode='wait'` – terminated envs are frozen (`step` is skipped for them via the env-pool mask). The run stops when all envs are done. This is the default for dataset eval, so every env gets to complete its specific start→goal task.
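Either default can be overridden via the `reset_mode` argument; a short sketch:

```python
# Episodic eval, but freeze each env once it terminates instead of
# auto-resetting it.
results = world.evaluate(episodes=world.envs.num_envs, seed=0, reset_mode="wait")
```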
World
```python
World(
    env_name: str,
    num_envs: int,
    image_shape: tuple[int, int],
    max_episode_steps: int = 100,
    goal_conditioned: bool = True,
    extra_wrappers: list | None = None,
    image_transform: Callable | None = None,
    goal_transform: Callable | None = None,
    image_resample: str | int | None = None,
    **kwargs: Any,
)
```
Drive a policy through a pool of preprocessed envs.
After construction, world.envs is an EnvPool of num_envs
environments, each wrapped by MegaWrapper (and any extra_wrappers
you pass). Attach a policy with set_policy(...) and then call
collect() or evaluate() to run rollouts.
Attributes populated during a run
- `infos` – Stacked info dict from the last reset/step. Tensor/array values have shape `(num_envs, 1, ...)`.
- `rewards`, `terminateds`, `truncateds` – Per-env step outputs from the last `step()`. Shape `(num_envs,)`.
Parameters:

- `env_name` (str) – Gymnasium id registered for the target env (e.g. `'swm/PushT-v1'`).
- `num_envs` (int) – Number of parallel envs in the pool.
- `image_shape` (tuple[int, int]) – `(H, W)` that pixels/goal are resized to.
- `max_episode_steps` (int, default: 100) – Per-env step cap before truncation.
- `goal_conditioned` (bool, default: True) – If True, the goal key is kept separate from regular observations (controls `MegaWrapper.separate_goal`).
- `extra_wrappers` (list | None, default: None) – Additional `gym.Wrapper` factories applied after `MegaWrapper`.
- `image_transform` (Callable | None, default: None) – Optional callable applied to pixels inside `MegaWrapper`.
- `goal_transform` (Callable | None, default: None) – Optional callable applied to the goal inside `MegaWrapper`.
- `image_resample` (str | int | None, default: None) – PIL resample mode for pixel/goal resizing (`'nearest'`, `'bilinear'`, ...). Defaults to bilinear; use `'nearest'` for crisp pixel-art envs (e.g. Craftax).
- `**kwargs` (Any, default: {}) – Forwarded to `gym.make` (e.g. `render_mode`).
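As an illustration of the wrapper-related arguments, here is a sketch of a pixel-art-friendly construction; the `to_float` transform is a stand-in, and `render_mode="rgb_array"` is only an example of a `gym.make` kwarg:

```python
import stable_worldmodel as swm

def to_float(img):
    # Stand-in transform: scale uint8 pixels to [0, 1].
    return img / 255.0

world = swm.World(
    "swm/PushT-v1",
    num_envs=2,
    image_shape=(64, 64),
    image_resample="nearest",   # crisp resizing for pixel-art envs
    image_transform=to_float,   # applied to pixels inside MegaWrapper
    render_mode="rgb_array",    # forwarded to gym.make via **kwargs
)
```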
[ Rollouts ]
collect
```python
collect(
    path: str | Path,
    episodes: int,
    seed: int | None = None,
    options: dict | None = None,
    format: str = 'lance',
) -> None
```
Roll out episodes and dump their trajectories using the writer registered for `format`.

Each info key becomes a column. Leading length-1 time dims are squeezed. Columns starting with `_` (e.g. `_needs_flush`) are skipped.
Parameters:

- `path` (str | Path) – Output path (file or directory, depending on the format). Parent dirs are auto-created.
- `episodes` (int) – Number of episodes to record.
- `seed` (int | None, default: None) – Base seed for env resets.
- `options` (dict | None, default: None) – Reset options forwarded to `envs.reset`.
- `format` (str, default: 'lance') – Registered format name. See `stable_worldmodel.data.list_formats` for available writers; new formats can be added via `stable_worldmodel.data.register_format`.
evaluate
```python
evaluate(
    episodes: int | None = None,
    seed: int | None = None,
    options: dict | None = None,
    video: str | Path | None = None,
    reset_mode: str | None = None,
    dataset: Any = None,
    episodes_idx: list[int] | None = None,
    start_steps: list[int] | None = None,
    goal_offset: int | None = None,
    eval_budget: int | None = None,
    callables: list[dict] | None = None,
) -> dict
```
Run the attached policy and return aggregated metrics.
Two modes of operation:

- Episodic (default): set `episodes` to the number of episodes to roll out. Terminated envs are auto-reset until the target count is reached.
- Dataset-driven: pass `dataset` with `episodes_idx` / `start_steps` / `goal_offset` / `eval_budget`. Each env is seeded from one dataset episode, starts at `start_steps[i]`, and targets the state at `start_steps[i] + goal_offset`. Run length is capped at `eval_budget` steps. Requires `num_envs == len(episodes_idx)`.
Parameters:

- `episodes` (int | None, default: None) – Total episodes to roll out (episodic mode).
- `seed` (int | None, default: None) – Base seed. Per-env seeds are derived by offsetting it.
- `options` (dict | None, default: None) – Reset options forwarded to `envs.reset`.
- `video` (str | Path | None, default: None) – Directory to write one mp4 per episode/env (optional).
- `reset_mode` (str | None, default: None) – `'auto'` (reset terminated envs) or `'wait'` (freeze terminated envs and stop when all are done). Defaults to `'auto'` for episodic eval and `'wait'` for dataset eval.
- `dataset` (Any, default: None) – Source dataset for dataset-driven eval.
- `episodes_idx` (list[int] | None, default: None) – Dataset episode indices, one per env.
- `start_steps` (list[int] | None, default: None) – Starting step within each dataset episode.
- `goal_offset` (int | None, default: None) – Offset from each start step that defines the goal.
- `eval_budget` (int | None, default: None) – Max env steps per episode in dataset mode.
- `callables` (list[dict] | None, default: None) – Per-env setup calls applied on the unwrapped env after reset. Each spec is `{'method': name, 'args': {arg_name: {'value': ..., 'in_dataset': bool}}}`; if `in_dataset` is True, the `value` names a key in the sliced dataset state and the per-env value is deep-copied in. See the sketch below.
Returns:

- dict – Aggregated metrics for the run (e.g. `success_rate`).
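A minimal `callables` sketch; `set_obstacle_pose` and the `obstacle_pose` dataset key are hypothetical, purely to show the spec shape:

```python
spec = {
    "method": "set_obstacle_pose",  # hypothetical method on the unwrapped env
    "args": {
        # Filled from the sliced dataset state under the key 'obstacle_pose'.
        "pose": {"value": "obstacle_pose", "in_dataset": True},
        # Passed through literally.
        "friction": {"value": 0.5, "in_dataset": False},
    },
}
results = world.evaluate(
    dataset=dataset,
    episodes_idx=[0, 1, 2, 3],
    start_steps=[0, 10, 20, 30],
    goal_offset=30,
    eval_budget=50,
    callables=[spec],
)
```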
[ Environment ]
reset
```python
reset(seed=None, options=None) -> None
```
Reset every env and refresh self.infos.
Clears terminateds/truncateds back to all-False.
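So immediately after a fresh reset, assuming the flags are NumPy arrays (as their `(num_envs,)` shapes suggest):

```python
world.reset(seed=0)
assert not world.terminateds.any()
assert not world.truncateds.any()
```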
set_policy
```python
set_policy(policy: Policy) -> None
```
Attach a policy and configure it for this world's envs.
Calls policy.set_env(self.envs). If the policy exposes a
seed attribute and set_seed method, the seed is applied.
close
```python
close() -> None
```
Close all envs and release their resources.
[ Properties ]
EnvPool
The underlying batched simulator. You rarely touch it directly — World builds one for you — but its action and observation spaces are what the policy sees.
EnvPool
```python
EnvPool(env_fns: list)
```
Batched env runner with selective stepping.
Parameters:

- `env_fns` (list) – List of zero-arg factories, one per env. Each is called once and the result is kept for the lifetime of the pool.
Methods:

- `reset` – Reset envs and return the stacked info dict.
- `step` – Step envs and return `(None, rewards, terminateds, truncateds, infos)`.
- `close` – Close every env in the pool.
Attributes:

- `num_envs` (int) – Number of envs in the pool.
- `action_space` (Space) – Batched action space (`batch_space(single_action_space, num_envs)`).
- `single_action_space` (Space) – Action space of a single env.
- `observation_space` (Space) – Batched observation space.
- `single_observation_space` (Space) – Observation space of a single env.
- `variation_space` – Variation space from the unwrapped env, or `None` if not defined.
- `single_variation_space` – Variation space for a single env (alias of `variation_space`).
action_space (property)

```python
action_space: Space
```

Batched action space (`batch_space(single_action_space, num_envs)`).

single_action_space (property)

```python
single_action_space: Space
```

Action space of a single env.

observation_space (property)

```python
observation_space: Space
```

Batched observation space.

single_observation_space (property)

```python
single_observation_space: Space
```

Observation space of a single env.

variation_space (property)

```python
variation_space
```

Variation space from the unwrapped env, or `None` if not defined.

single_variation_space (property)

```python
single_variation_space
```

Variation space for a single env (alias of `variation_space`).
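These pair with the per-env variation options shown earlier; a small sketch for guarding against envs that define no variations:

```python
if world.envs.variation_space is not None:
    # The env defines variations, so variation/variation_values
    # can be passed through reset(options=...).
    print(world.envs.single_variation_space)
```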
reset
```python
reset(
    seed: int | list[int | None] | None = None,
    options: dict | list[dict | None] | None = None,
    mask: ndarray | None = None,
) -> tuple[None, dict]
```
Reset envs and return the stacked info dict.
Parameters:

- `seed` (int | list[int | None] | None, default: None) – Base int (each env gets `seed + i`), a per-env list, or `None`.
- `options` (dict | list[dict | None] | None, default: None) – Shared dict or per-env list.
- `mask` (ndarray | None, default: None) – If provided, only envs where `mask[i]` is truthy are reset. Others keep their current state in the stacked info buffer.
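A short masked-reset sketch for a four-env pool:

```python
import numpy as np

# Reset only envs 0 and 2; envs 1 and 3 keep their current state
# (and their slots in the stacked info buffer).
mask = np.array([True, False, True, False])
_, infos = world.envs.reset(seed=0, mask=mask)
```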
step
```python
step(
    actions: ndarray, mask: ndarray | None = None
) -> tuple[None, ndarray, ndarray, ndarray, dict]
```
Step envs and return (None, rewards, terminateds, truncateds, infos).
Parameters:

- `actions` (ndarray) – Array of shape `(num_envs, ...)`, one action per env.
- `mask` (ndarray | None, default: None) – If provided, only envs where `mask[i]` is truthy are stepped. Masked envs contribute zero reward and `False` termination/truncation, and their slot in the stacked info buffer is left unchanged.
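A sketch of selective stepping, assuming the batched action space supports `sample()` and that the termination flags are NumPy arrays:

```python
# Step only envs that are still running; finished envs contribute zero
# reward and False flags, and keep their last slot in the info buffer.
alive = ~(world.terminateds | world.truncateds)
actions = world.envs.action_space.sample()
_, rewards, terms, truncs, infos = world.envs.step(actions, mask=alive)
```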
close
```python
close()
```
Close every env in the pool.