A unified interface for orchestrating vectorized environments, managing policy interactions, and handling data collection (HDF5 datasets and videos) and evaluation pipelines.

The World class is the central entry point for managing vectorized environments in stable_worldmodel. It handles synchronization, preprocessing (resizing, frame stacking), and interaction with policies.

from stable_worldmodel import World
from stable_worldmodel.policy import RandomPolicy

# 1. Initialize the World with 4 parallel environments
world = World(
    env_name="swm/PushT-v1",
    num_envs=4,
    image_shape=(64, 64),
    history_size=1
)

# 2. Set a policy (e.g., Random)
world.set_policy(RandomPolicy())

# 3. Reset and step
world.reset()
for _ in range(100):
    world.step()
    # Access current states/infos
    # world.infos["pixels"] -> (4, 3, 64, 64)

Recording Videos

from stable_worldmodel import World
from stable_worldmodel.policy import RandomPolicy

world = World(
    env_name="swm/PushT-v1",
    num_envs=1,
    image_shape=(64, 64)
)
world.set_policy(RandomPolicy())

# Record a 500-step video
world.record_video(
    video_path="./videos",
    max_steps=500,
    viewname="pixels"
)

Recording a Dataset

from stable_worldmodel import World
from stable_worldmodel.policy import RandomPolicy

world = World(
    env_name="swm/PushT-v1",
    num_envs=4,  # Collect 4 episodes in parallel
    image_shape=(64, 64)
)
world.set_policy(RandomPolicy())

# Record 50 episodes to a .h5 dataset
world.record_dataset(
    dataset_name="pusht_random",
    episodes=50,
    cache_dir="./data"
)
# Result: ./data/pusht_random.h5

Evaluating from a Dataset

from stable_worldmodel import World
from stable_worldmodel.data import HDF5Dataset
from stable_worldmodel.policy import RandomPolicy # or your trained policy

# 1. Load a dataset for initial states
dataset = HDF5Dataset("pusht_random", cache_dir="./data")

# 2. Setup World
world = World(env_name="swm/PushT-v1", num_envs=4, image_shape=(64, 64))
world.set_policy(RandomPolicy())

# 3. Evaluate starting from dataset states
results = world.evaluate_from_dataset(
    dataset=dataset,
    episodes_idx=[0, 1, 2, 3],  # Episodes to test on
    start_steps=[0, 0, 0, 0],   # Start from beginning
    goal_offset_steps=50,       # Goal is state at t=50
    eval_budget=100             # Max steps to reach goal
)

print(f"Success Rate: {results['success_rate']}%")

Performance

The World class uses a custom SyncWorld vectorized environment for synchronized execution, ensuring deterministic and batched stepping across multiple environments.

Per-Environment Reset Options

The World class supports passing different options to each environment during reset, enabling per-environment variations and configurations:

from stable_worldmodel import World

world = World(
    env_name="swm/PushT-v1",
    num_envs=3,
    image_shape=(64, 64)
)

# Different variations for each environment
per_env_options = [
    {"variation": ["agent.color"], "variation_values": {"agent.color": [255, 0, 0]}},
    {"variation": ["agent.color"], "variation_values": {"agent.color": [0, 255, 0]}},
    {"variation": ["agent.color"], "variation_values": {"agent.color": [0, 0, 255]}},
]

world.reset(options=per_env_options)

This is useful for:

  • Domain randomization: Different visual variations per environment
  • Curriculum learning: Different difficulty levels per environment
  • Parallel evaluation: Testing multiple configurations simultaneously
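The per-environment option dicts above can also be generated programmatically. A minimal sketch (the `make_color_options` helper is hypothetical, not part of stable_worldmodel):

```python
def make_color_options(colors):
    """Build one reset-options dict per environment, varying agent.color.

    Hypothetical helper; mirrors the per_env_options structure shown above.
    """
    return [
        {"variation": ["agent.color"], "variation_values": {"agent.color": c}}
        for c in colors
    ]

# One options dict per environment, matching num_envs=3 above
per_env_options = make_color_options([[255, 0, 0], [0, 255, 0], [0, 0, 255]])
```

The resulting list can be passed directly as world.reset(options=per_env_options).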

World

World(
    env_name: str,
    num_envs: int,
    image_shape: tuple[int, int],
    goal_transform: Callable[[Any], Any] | None = None,
    image_transform: Callable[[Any], Any] | None = None,
    seed: int = 2349867,
    history_size: int = 1,
    frame_skip: int = 1,
    max_episode_steps: int = 100,
    verbose: int = 1,
    extra_wrappers: list[Callable] | None = None,
    goal_conditioned: bool = True,
    **kwargs: Any,
)

High-level manager for vectorized Gymnasium environments.

Manages a set of synchronized vectorized environments with automatic preprocessing (resizing, frame stacking, goal conditioning).

Parameters:

  • env_name (str) –

    Name of the Gymnasium environment to create.

  • num_envs (int) –

    Number of parallel environments.

  • image_shape (tuple[int, int]) –

    Target shape for image observations (H, W).

  • goal_transform (Callable[[Any], Any] | None, default: None ) –

    Optional callable to transform goal observations.

  • image_transform (Callable[[Any], Any] | None, default: None ) –

    Optional callable to transform image observations.

  • seed (int, default: 2349867 ) –

    Random seed for reproducibility.

  • history_size (int, default: 1 ) –

    Number of frames to stack.

  • frame_skip (int, default: 1 ) –

    Number of frames to skip per step.

  • max_episode_steps (int, default: 100 ) –

    Maximum steps per episode before truncation.

  • verbose (int, default: 1 ) –

    Verbosity level (0: silent, >0: info).

  • extra_wrappers (list[Callable] | None, default: None ) –

    List of additional wrappers to apply to each env.

  • goal_conditioned (bool, default: True ) –

    Whether to separate goal from observation.

  • **kwargs (Any, default: {} ) –

    Additional keyword arguments passed to gym.make_vec.
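Together, num_envs, history_size, and image_shape determine the shape of batched pixel observations: the first example above shows world.infos["pixels"] with shape (4, 3, 64, 64) for num_envs=4 and history_size=1. Assuming stacked frames are concatenated along the channel axis (an assumption, not confirmed by the source), the general shape can be sketched as:

```python
def stacked_pixels_shape(num_envs, history_size, image_shape, channels=3):
    """Expected shape of batched pixel observations.

    Hypothetical helper; assumes frame stacking along the channel axis.
    """
    height, width = image_shape
    return (num_envs, channels * history_size, height, width)

# Matches the (4, 3, 64, 64) shape from the first example (history_size=1)
print(stacked_pixels_shape(4, 1, (64, 64)))  # -> (4, 3, 64, 64)
```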

[ Recording ]

record_dataset

record_dataset(
    dataset_name: str,
    episodes: int = 10,
    seed: int | None = None,
    cache_dir: PathLike | str | None = None,
    options: dict | None = None,
) -> None

Records episodes from the environment into an HDF5 dataset.

Parameters:

  • dataset_name (str) –

    Name of the dataset file (without extension).

  • episodes (int, default: 10 ) –

    Total number of episodes to record.

  • seed (int | None, default: None ) –

    Initial random seed.

  • cache_dir (PathLike | str | None, default: None ) –

    Directory to save the dataset. Defaults to standard cache.

  • options (dict | None, default: None ) –

    Reset options passed to environments.

record_video

record_video(
    video_path: str | Path,
    max_steps: int = 500,
    fps: int = 30,
    viewname: str | list[str] = 'pixels',
    seed: int | None = None,
    options: dict | None = None,
) -> None

Record rollout videos for each environment under the current policy.

Parameters:

  • video_path (str | Path) –

    Directory path to save the videos.

  • max_steps (int, default: 500 ) –

    Maximum steps to record per environment.

  • fps (int, default: 30 ) –

    Frames per second for the output video.

  • viewname (str | list[str], default: 'pixels' ) –

    Key(s) in infos containing image data to render.

  • seed (int | None, default: None ) –

    Random seed for reset.

  • options (dict | None, default: None ) –

    Options for reset.

[ Evaluation ]

evaluate_from_dataset

evaluate_from_dataset(
    dataset: Any,
    episodes_idx: Sequence[int],
    start_steps: Sequence[int],
    goal_offset_steps: int,
    eval_budget: int,
    callables: list[dict] | None = None,
    save_video: bool = True,
    video_path: str | Path = './',
) -> dict

Evaluate the policy starting from states sampled from a dataset.

Parameters:

  • dataset (Any) –

    The source dataset to sample initial states/goals from.

  • episodes_idx (Sequence[int]) –

    Indices of episodes to sample from.

  • start_steps (Sequence[int]) –

    Step indices within those episodes to start from.

  • goal_offset_steps (int) –

    Number of steps ahead to look for the goal.

  • eval_budget (int) –

    Maximum steps allowed for the agent to reach the goal.

  • callables (list[dict] | None, default: None ) –

    Optional list of method calls used to set up the environment.

  • save_video (bool, default: True ) –

    Whether to save rollout videos.

  • video_path (str | Path, default: './' ) –

    Path to save videos.

Returns:

  • dict

    Dictionary containing success rates and other metrics.

Raises:

  • ValueError

    If input sequence lengths mismatch or don't match num_envs.
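A sketch of the length validation implied by the ValueError above (hypothetical code, not the library's actual implementation):

```python
def check_eval_args(num_envs, episodes_idx, start_steps):
    """Validate that per-environment sequences line up with num_envs.

    Hypothetical sketch of the check described in the Raises section.
    """
    if len(episodes_idx) != len(start_steps):
        raise ValueError("episodes_idx and start_steps must have equal length")
    if len(episodes_idx) != num_envs:
        raise ValueError("sequence lengths must match num_envs")

check_eval_args(4, [0, 1, 2, 3], [0, 0, 0, 0])  # OK for num_envs=4
```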

evaluate

evaluate(
    episodes: int = 10,
    eval_keys: list[str] | None = None,
    seed: int | None = None,
    options: dict | None = None,
    dump_every: int = -1,
) -> dict

Evaluate the current policy over multiple episodes.

Parameters:

  • episodes (int, default: 10 ) –

    Number of episodes to evaluate.

  • eval_keys (list[str] | None, default: None ) –

    List of keys in infos to collect and return.

  • seed (int | None, default: None ) –

    Random seed for evaluation.

  • options (dict | None, default: None ) –

    Reset options.

  • dump_every (int, default: -1 ) –

    Interval at which to save intermediate results during long evaluations.

Returns:

  • dict

    Dictionary containing success rates, seeds, and collected keys.
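The success_rate printed in the earlier example is a percentage. Assuming it aggregates per-episode success flags (an assumption about the metric, not confirmed by the source), the computation can be sketched as:

```python
def success_rate(successes):
    """Percentage of successful episodes (illustrative sketch only)."""
    return 100.0 * sum(bool(s) for s in successes) / len(successes)

print(success_rate([True, False, True, True]))  # -> 75.0
```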

[ Environment ]

reset

reset(
    seed: int | list[int] | None = None,
    options: dict | None = None,
) -> None

Reset all environments to initial states.

Parameters:

  • seed (int | list[int] | None, default: None ) –

    Random seed(s) for the environments.

  • options (dict | None, default: None ) –

    Additional options passed to the environment reset.
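Since reset accepts either a single int or one seed per environment, a common pattern (a convention assumed here, not mandated by the API) is to derive distinct, reproducible per-environment seeds from a base seed:

```python
base_seed = 2349867  # matches World's default seed
num_envs = 4

# One distinct, reproducible seed per environment
per_env_seeds = [base_seed + i for i in range(num_envs)]
# world.reset(seed=per_env_seeds)
```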

step

step() -> None

Advance all environments by one step using the current policy.

close

close(**kwargs: Any) -> None

Close all environments and clean up resources.

set_policy

set_policy(policy: Policy) -> None

Attach a policy to the world.

Parameters:

  • policy (Policy) –

    The policy instance to use for determining actions.

[ Properties ]

num_envs property

num_envs: int

Number of parallel environment instances.

observation_space property

observation_space: Space

Batched observation space for all environments.

action_space property

action_space: Space

Batched action space for all environments.

variation_space property

variation_space: Space | None

Batched variation space for domain randomization.

single_variation_space property

single_variation_space: Space | None

Variation space for a single environment instance.

single_action_space property

single_action_space: Space

Action space for a single environment instance.

single_observation_space property

single_observation_space: Space

Observation space for a single environment instance.