Agent policies for interacting with environments


Policies determine the actions taken by agents in the environment. stable_worldmodel provides base classes and implementations for random, expert, and model-based policies.

RandomPolicy is a simple policy that samples actions uniformly from the environment's action space:

from stable_worldmodel.policy import RandomPolicy

# Create a random policy
policy = RandomPolicy(seed=42)

# Attach to a world/env later
# world.set_policy(policy)

WorldModelPolicy uses a solver (such as CEM or MPPI) together with a world model to plan actions:

from stable_worldmodel.policy import WorldModelPolicy, PlanConfig
from stable_worldmodel.solver.random import RandomSolver

# 1. Define Planning Configuration
cfg = PlanConfig(
    horizon=10,
    receding_horizon=1,
    action_block=1
)

# 2. Instantiate a Solver
solver = RandomSolver()  # or CEMSolver, MPPI, etc.

# 3. Create the Policy
policy = WorldModelPolicy(
    solver=solver,
    config=cfg
)

FeedForwardPolicy predicts actions directly with a single forward pass through a neural network model. It is useful for imitation-learning policies such as Goal-Conditioned Behavioral Cloning (GCBC):

from stable_worldmodel.policy import FeedForwardPolicy, AutoActionableModel

# 1. Load a pre-trained model with a get_action method
model = AutoActionableModel("path/to/checkpoint")

# 2. Create the Policy (action_scaler and image_transform are user-defined
#    callables for preprocessing and tensor transforms)
policy = FeedForwardPolicy(
    model=model,
    process={"action": action_scaler},  # Optional preprocessors
    transform={"pixels": image_transform}  # Optional transforms
)

Protocol

All policies must implement the get_action(obs, **kwargs) method. The World class automatically calls set_env() when a policy is attached.
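For illustration, a custom policy only needs to subclass BasePolicy and implement get_action. The constant zero action below is a hypothetical placeholder, and the gym-style action_space attribute on the environment is an assumption:

import numpy as np

from stable_worldmodel.policy import BasePolicy


class ZeroPolicy(BasePolicy):
    """Hypothetical policy that always returns the zero action."""

    def get_action(self, obs, **kwargs):
        # self.env is populated by set_env(), which World calls automatically
        return np.zeros(self.env.action_space.shape, dtype=np.float32)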

PlanConfig dataclass

PlanConfig(
    horizon: int,
    receding_horizon: int,
    history_len: int = 1,
    action_block: int = 1,
    warm_start: bool = True,
)

Configuration for the MPC planning loop.

Attributes:

  • horizon (int) –

    Planning horizon in number of steps.

  • receding_horizon (int) –

    Number of steps to execute before re-planning.

  • history_len (int) –

    Number of past observations to consider.

  • action_block (int) –

    Number of times each action is repeated (frameskip).

  • warm_start (bool) –

    Whether to use the previous plan to initialize the next one.
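A short configuration sketch with every field set explicitly; the comments restate the attribute descriptions above, and the specific values are arbitrary:

from stable_worldmodel.policy import PlanConfig

cfg = PlanConfig(
    horizon=10,           # plan 10 steps ahead
    receding_horizon=1,   # execute 1 step, then re-plan
    history_len=2,        # condition on the 2 most recent observations
    action_block=2,       # repeat each action twice (frameskip)
    warm_start=True,      # initialize the next plan from the previous one
)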

BasePolicy

BasePolicy(**kwargs: Any)

Base class for agent policies.

Attributes:

  • env (Any) –

    The environment the policy is associated with.

  • type (str) –

    A string identifier for the policy type.

Parameters:

  • **kwargs (Any, default: {} ) –

    Additional configuration parameters.

get_action

get_action(obs: Any, **kwargs: Any) -> ndarray

Get action from the policy given the observation.

Parameters:

  • obs (Any) –

    The current observation from the environment.

  • **kwargs (Any, default: {} ) –

    Additional parameters for action selection.

Returns:

  • ndarray

    Selected action as a numpy array.

set_env

set_env(env: Any) -> None

Associate this policy with an environment.

Parameters:

  • env (Any) –

    The environment to associate.

_prepare_info

_prepare_info(info_dict: dict) -> dict[str, Tensor]

Pre-process and transform observations.

Applies preprocessing (via self.process) and transformations (via self.transform) to observation data. Used by subclasses like FeedForwardPolicy and WorldModelPolicy.

Parameters:

  • info_dict (dict) –

    Raw observation dictionary from the environment.

Returns:

  • dict[str, Tensor]

    A dictionary of processed tensors.

Raises:

  • ValueError

    If an expected numpy array is missing for processing.
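For context, self.process and self.transform come from the process and transform arguments of FeedForwardPolicy and WorldModelPolicy. A hypothetical configuration for a 'pixels' key might look like the following (the torchvision transform is an assumption, not a requirement):

import torch
from torchvision import transforms

# Hypothetical transform applied to the 'pixels' entry of the observation dict
image_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ConvertImageDtype(torch.float32),
])

transform = {"pixels": image_transform}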

RandomPolicy

RandomPolicy(seed: int | None = None, **kwargs: Any)

Bases: BasePolicy

Policy that samples random actions from the action space.

Parameters:

  • seed (int | None, default: None ) –

    Optional random seed for the action space.

  • **kwargs (Any, default: {} ) –

    Additional configuration parameters.

get_action

get_action(obs: Any, **kwargs: Any) -> ndarray

Get a random action from the environment's action space.

Parameters:

  • obs (Any) –

    The current observation (ignored).

  • **kwargs (Any, default: {} ) –

    Additional parameters (ignored).

Returns:

  • ndarray

    A randomly sampled action.
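A short usage sketch, assuming a Gymnasium-style environment with an action_space attribute; here set_env is called manually instead of through a World:

import gymnasium as gym

from stable_worldmodel.policy import RandomPolicy

env = gym.make("CartPole-v1")
policy = RandomPolicy(seed=42)
policy.set_env(env)

obs, info = env.reset(seed=42)
action = policy.get_action(obs)  # sampled uniformly from env.action_space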

ExpertPolicy

ExpertPolicy(**kwargs: Any)

Bases: BasePolicy

Policy using expert demonstrations or heuristics.

Parameters:

  • **kwargs (Any, default: {} ) –

    Additional configuration parameters.

get_action

get_action(
    obs: Any, goal_obs: Any, **kwargs: Any
) -> ndarray | None

Get action from the expert policy.

Parameters:

  • obs (Any) –

    The current observation.

  • goal_obs (Any) –

    The goal observation.

  • **kwargs (Any, default: {} ) –

    Additional parameters.

Returns:

  • ndarray | None

    The expert action, or None if not available.
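ExpertPolicy is meant to be subclassed per task; the proportional controller below is a purely hypothetical heuristic, shown only to illustrate the goal-conditioned signature:

import numpy as np

from stable_worldmodel.policy import ExpertPolicy


class PointMassExpert(ExpertPolicy):
    """Hypothetical expert that steps straight toward the goal position."""

    def get_action(self, obs, goal_obs, **kwargs):
        direction = np.asarray(goal_obs) - np.asarray(obs)
        distance = np.linalg.norm(direction)
        if distance < 1e-6:
            return None  # already at the goal, no expert action to give
        return (direction / distance).astype(np.float32)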

FeedForwardPolicy

FeedForwardPolicy(
    model: Actionable,
    process: dict[str, Transformable] | None = None,
    transform: dict[str, Callable[[Tensor], Tensor]]
    | None = None,
    **kwargs: Any,
)

Bases: BasePolicy

Feed-Forward Policy using a neural network model.

Actions are computed via a single forward pass through the model. Useful for imitation learning policies like Goal-Conditioned Behavioral Cloning (GCBC).

Attributes:

  • model

    Neural network model implementing the Actionable protocol.

  • process

    Dictionary of data preprocessors for specific keys.

  • transform

    Dictionary of tensor transformations (e.g., image transforms).

Parameters:

  • model (Actionable) –

    Neural network model with a get_action method.

  • process (dict[str, Transformable] | None, default: None ) –

    Dictionary of data preprocessors for specific keys.

  • transform (dict[str, Callable[[Tensor], Tensor]] | None, default: None ) –

    Dictionary of tensor transformations (e.g., image transforms).

  • **kwargs (Any, default: {} ) –

    Additional configuration parameters.

get_action

get_action(info_dict: dict, **kwargs: Any) -> ndarray

Get action via a forward pass through the neural network model.

Parameters:

  • info_dict (dict) –

    Current state information containing at minimum a 'goal' key.

  • **kwargs (Any, default: {} ) –

    Additional parameters (unused).

Returns:

  • ndarray

    The selected action as a numpy array.

Raises:

  • AssertionError

    If environment not set or 'goal' not in info_dict.
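For illustration, a minimal model satisfying the Actionable protocol only needs a get_action method. The tiny goal-conditioned MLP below, including its get_action signature, is an assumption made for the sketch; the real call convention is whatever your trained model exposes:

import torch
from torch import nn


class TinyGCBC(nn.Module):
    """Hypothetical goal-conditioned behavioral cloning model."""

    def __init__(self, obs_dim=8, goal_dim=8, action_dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + goal_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
        )

    def get_action(self, obs, goal):
        # Single forward pass from observation and goal to action
        return self.net(torch.cat([obs, goal], dim=-1))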

WorldModelPolicy

WorldModelPolicy(
    solver: Solver,
    config: PlanConfig,
    process: dict[str, Transformable] | None = None,
    transform: dict[str, Callable[[Tensor], Tensor]]
    | None = None,
    **kwargs: Any,
)

Bases: BasePolicy

Policy using a world model and planning solver for action selection.

Parameters:

  • solver (Solver) –

    The planning solver to use.

  • config (PlanConfig) –

    MPC planning configuration.

  • process (dict[str, Transformable] | None, default: None ) –

    Dictionary of data preprocessors for specific keys.

  • transform (dict[str, Callable[[Tensor], Tensor]] | None, default: None ) –

    Dictionary of tensor transformations (e.g., image transforms).

  • **kwargs (Any, default: {} ) –

    Additional configuration parameters.

get_action

get_action(info_dict: dict, **kwargs: Any) -> ndarray

Get action via planning with the world model.

Parameters:

  • info_dict (dict) –

    Current state information from the environment.

  • **kwargs (Any, default: {} ) –

    Additional parameters for planning.

Returns:

  • ndarray

    The selected action(s) as a numpy array.
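A usage sketch of the resulting control loop; the world object and the way info_dict is produced are assumptions about the surrounding code, so that part is left as comments:

from stable_worldmodel.policy import PlanConfig, WorldModelPolicy
from stable_worldmodel.solver.random import RandomSolver

policy = WorldModelPolicy(
    solver=RandomSolver(),
    config=PlanConfig(horizon=10, receding_horizon=1, action_block=1),
)

# Typical use: attach the policy to a world, which calls set_env() and then
# queries get_action(info_dict) each control step; re-planning and warm
# starting happen inside get_action according to PlanConfig.
# world.set_policy(policy)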

Utils

AutoActionableModel

AutoActionableModel(
    run_name: str, cache_dir: str | Path | None = None
) -> Module

Load a model checkpoint and return the module with a get_action method.

Automatically scans the checkpoint for a module implementing the Actionable protocol (i.e., has a get_action method).

Parameters:

  • run_name (str) –

    Path or name of the model run/checkpoint.

  • cache_dir (str | Path | None, default: None ) –

    Optional cache directory path. Defaults to STABLEWM_HOME.

Returns:

  • Module

    The module with a get_action method, set to eval mode.

Raises:

  • RuntimeError

    If no module with get_action is found in the checkpoint.

AutoCostModel

AutoCostModel(
    run_name: str, cache_dir: str | Path | None = None
) -> Module

Load a model checkpoint and return the module with a get_cost method.

Automatically scans the checkpoint for a module implementing a cost function (i.e., has a get_cost method) for use with planning solvers.

Parameters:

  • run_name (str) –

    Path or name of the model run/checkpoint.

  • cache_dir (str | Path | None, default: None ) –

    Optional cache directory path. Defaults to STABLEWM_HOME.

Returns:

  • Module

    The module with a get_cost method, set to eval mode.

Raises:

  • RuntimeError

    If no module with get_cost is found in the checkpoint.
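A combined loading sketch; the run names are placeholders, importing AutoCostModel from stable_worldmodel.policy alongside AutoActionableModel is an assumption, and how the cost module is handed to a planning solver depends on that solver's own constructor, so it is left as a comment:

from stable_worldmodel.policy import (
    AutoActionableModel,
    AutoCostModel,
    FeedForwardPolicy,
)

# Run names are placeholders; both helpers return the matching module in eval mode.
actor = AutoActionableModel("my_gcbc_run")
cost_model = AutoCostModel("my_worldmodel_run")

# The actionable module can back a FeedForwardPolicy directly; the cost module
# is typically passed to a planning solver (solver-specific, not shown here).
policy = FeedForwardPolicy(model=actor)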