Policies determine the actions taken by agents in the environment. stable_worldmodel provides base classes and implementations for random, expert, and model-based policies.
RandomPolicy is a simple policy that samples actions uniformly from the environment's action space.
```python
from stable_worldmodel.policy import RandomPolicy

# Create a random policy
policy = RandomPolicy(seed=42)

# Attach to a world/env later
# world.set_policy(policy)
```
WorldModelPolicy plans actions with a Solver (such as CEM or MPPI) and a world model.
```python
from stable_worldmodel.policy import WorldModelPolicy, PlanConfig
from stable_worldmodel.solver.random import RandomSolver

# 1. Define the planning configuration
cfg = PlanConfig(
    horizon=10,
    receding_horizon=1,
    action_block=1,
)

# 2. Instantiate a solver
solver = RandomSolver()  # or CEMSolver, MPPI, etc.

# 3. Create the policy
policy = WorldModelPolicy(
    solver=solver,
    config=cfg,
)
```
FeedForwardPolicy uses a neural network model for direct action prediction via a single forward pass. It is useful for imitation learning policies such as Goal-Conditioned Behavioral Cloning (GCBC).
```python
from stable_worldmodel.policy import FeedForwardPolicy, AutoActionableModel

# 1. Load a pre-trained model with a get_action method
model = AutoActionableModel("path/to/checkpoint")

# 2. Create the policy
policy = FeedForwardPolicy(
    model=model,
    process={"action": action_scaler},      # optional preprocessors
    transform={"pixels": image_transform},  # optional transforms
)
```
Protocol
All policies must implement the get_action(obs, **kwargs) method. The World class automatically calls set_env() when a policy is attached.
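For illustration, a custom policy only needs to subclass `BasePolicy` and implement `get_action`. The sketch below is hypothetical (the class name and constant-action behaviour are invented for the example):

```python
import numpy as np

from stable_worldmodel.policy import BasePolicy


class ConstantPolicy(BasePolicy):
    """Toy policy that always returns the same action (illustration only)."""

    def __init__(self, action, **kwargs):
        super().__init__(**kwargs)
        self.action = np.asarray(action)

    def get_action(self, obs, **kwargs):
        # Ignore the observation and return the fixed action as a numpy array.
        return self.action
```

Once defined, it can be attached exactly like the built-in policies, e.g. `world.set_policy(ConstantPolicy(action=[0.0, 0.0]))`.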
PlanConfig
dataclass
```python
PlanConfig(
    horizon: int,
    receding_horizon: int,
    history_len: int = 1,
    action_block: int = 1,
    warm_start: bool = True,
)
```
Configuration for the MPC planning loop.
Attributes:

- `horizon` (int) – Planning horizon in number of steps.
- `receding_horizon` (int) – Number of steps to execute before re-planning.
- `history_len` (int) – Number of past observations to consider.
- `action_block` (int) – Number of times each action is repeated (frameskip).
- `warm_start` (bool) – Whether to use the previous plan to initialize the next one.
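To make the interplay of these fields concrete, the snippet below walks through one receding-horizon iteration under a hypothetical configuration (the arithmetic reflects the standard MPC reading of the attributes above, not library internals):

```python
from stable_worldmodel.policy import PlanConfig

cfg = PlanConfig(horizon=10, receding_horizon=2, action_block=3)

# Each planning call optimizes `horizon` actions.
actions_per_plan = cfg.horizon                                # 10

# Only the first `receding_horizon` actions are executed before re-planning,
# and each executed action is repeated `action_block` times (frameskip),
# so one plan covers receding_horizon * action_block environment steps.
env_steps_per_plan = cfg.receding_horizon * cfg.action_block  # 6

# With warm_start=True (the default), the unused tail of the previous plan
# initializes the next solve instead of starting from scratch.
print(actions_per_plan, env_steps_per_plan)
```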
BasePolicy
```python
BasePolicy(**kwargs: Any)
```
get_action
Get action from the policy given the observation.
Parameters:

- `obs` (Any) – The current observation from the environment.
- `**kwargs` (Any, default: {}) – Additional parameters for action selection.

Returns:

- `ndarray` – Selected action as a numpy array.

Raises:

- `NotImplementedError` – If not implemented by a subclass.
set_env
```python
set_env(env: Any) -> None
```

Attach the environment the policy acts in.
_prepare_info
Pre-process and transform observations.
Applies preprocessing (via `self.process`) and transformations (via `self.transform`) to observation data. Used by subclasses such as FeedForwardPolicy and WorldModelPolicy.

Parameters:

- `info_dict` (dict) – Raw observation dictionary from the environment.

Returns:

- `dict` – The processed observation dictionary.

Raises:

- `ValueError` – If an expected numpy array is missing for processing.
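As a rough sketch of what this amounts to (not the actual implementation; the `Transformable` call shown here is an assumption), each configured key is first preprocessed and then transformed:

```python
import torch


def prepare_info_sketch(info_dict, process, transform):
    """Illustrative per-key flow; the real _prepare_info differs in detail."""
    out = dict(info_dict)
    for key, proc in (process or {}).items():
        if key not in out:
            raise ValueError(f"expected numpy array for key '{key}'")
        out[key] = proc.transform(out[key])        # hypothetical Transformable API
    for key, fn in (transform or {}).items():
        out[key] = fn(torch.as_tensor(out[key]))   # tensor transforms (e.g. image ops)
    return out
```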
RandomPolicy
Bases: BasePolicy
Policy that samples random actions from the action space.
Parameters:

- `seed` (int | None, default: None) – Optional random seed for the action space.
- `**kwargs` (Any, default: {}) – Additional configuration parameters.
get_action
ExpertPolicy
```python
ExpertPolicy(**kwargs: Any)
```
Bases: BasePolicy
Policy using expert demonstrations or heuristics.
Parameters:

- `**kwargs` (Any, default: {}) – Additional configuration parameters.
get_action
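Because the expert logic is task-specific, `ExpertPolicy` is normally subclassed. A purely illustrative example (the observation keys and the heuristic are invented):

```python
import numpy as np

from stable_worldmodel.policy import ExpertPolicy


class GoToGoalExpert(ExpertPolicy):
    """Illustrative heuristic: step straight toward a 2D goal position."""

    def get_action(self, obs, **kwargs):
        # Hypothetical observation layout: dict with 2D agent and goal positions.
        direction = np.asarray(obs["goal_pos"]) - np.asarray(obs["agent_pos"])
        norm = np.linalg.norm(direction)
        return direction / norm if norm > 1e-8 else np.zeros_like(direction, dtype=float)
```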
FeedForwardPolicy
```python
FeedForwardPolicy(
    model: Actionable,
    process: dict[str, Transformable] | None = None,
    transform: dict[str, Callable[[Tensor], Tensor]] | None = None,
    **kwargs: Any,
)
```
Bases: BasePolicy
Feed-Forward Policy using a neural network model.
Actions are computed via a single forward pass through the model. Useful for imitation learning policies like Goal-Conditioned Behavioral Cloning (GCBC).
Attributes:

- `model` – Neural network model implementing the Actionable protocol.
- `process` – Dictionary of data preprocessors for specific keys.
- `transform` – Dictionary of tensor transformations (e.g., image transforms).

Parameters:

- `model` (Actionable) – Neural network model with a `get_action` method.
- `process` (dict[str, Transformable] | None, default: None) – Dictionary of data preprocessors for specific keys.
- `transform` (dict[str, Callable[[Tensor], Tensor]] | None, default: None) – Dictionary of tensor transformations (e.g., image transforms).
- `**kwargs` (Any, default: {}) – Additional configuration parameters.
get_action
Get action via a forward pass through the neural network model.
Parameters:

- `info_dict` (dict) – Current state information containing at minimum a 'goal' key.
- `**kwargs` (Any, default: {}) – Additional parameters (unused).

Returns:

- `ndarray` – The selected action as a numpy array.

Raises:

- `AssertionError` – If the environment is not set or 'goal' is not in info_dict.
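A hypothetical call, reusing the policy constructed in the quick-start example above and assuming its environment has already been set (e.g., by `World`); the observation keys and image shapes are examples only:

```python
import numpy as np

info_dict = {
    "pixels": np.zeros((96, 96, 3), dtype=np.uint8),  # current observation (example key)
    "goal": np.zeros((96, 96, 3), dtype=np.uint8),    # goal image, required by get_action
}

action = policy.get_action(info_dict)  # single forward pass, returns a numpy array
```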
WorldModelPolicy
```python
WorldModelPolicy(
    solver: Solver,
    config: PlanConfig,
    process: dict[str, Transformable] | None = None,
    transform: dict[str, Callable[[Tensor], Tensor]] | None = None,
    **kwargs: Any,
)
```
Bases: BasePolicy
Policy using a world model and planning solver for action selection.
Parameters:

- `solver` (Solver) – The planning solver to use.
- `config` (PlanConfig) – MPC planning configuration.
- `process` (dict[str, Transformable] | None, default: None) – Dictionary of data preprocessors for specific keys.
- `transform` (dict[str, Callable[[Tensor], Tensor]] | None, default: None) – Dictionary of tensor transformations (e.g., image transforms).
- `**kwargs` (Any, default: {}) – Additional configuration parameters.
get_action
Utils
AutoActionableModel
Load a model checkpoint and return the module with a get_action method.
Automatically scans the checkpoint for a module implementing the Actionable
protocol (i.e., has a get_action method).
Parameters:

- `run_name` (str) – Path or name of the model run/checkpoint.
- `cache_dir` (str | Path | None, default: None) – Optional cache directory path. Defaults to STABLEWM_HOME.

Returns:

- `Module` – The module with a `get_action` method, set to eval mode.

Raises:

- `RuntimeError` – If no module with `get_action` is found in the checkpoint.
AutoCostModel
Load a model checkpoint and return the module with a get_cost method.
Automatically scans the checkpoint for a module implementing a cost function
(i.e., has a get_cost method) for use with planning solvers.
Parameters:

- `run_name` (str) – Path or name of the model run/checkpoint.
- `cache_dir` (str | Path | None, default: None) – Optional cache directory path. Defaults to STABLEWM_HOME.

Returns:

- `Module` – The module with a `get_cost` method, set to eval mode.

Raises:

- `RuntimeError` – If no module with `get_cost` is found in the checkpoint.
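For example, a cost module can be loaded analogously to `AutoActionableModel` (the import path mirrors the quick-start above and is an assumption; how the module is handed to a particular solver depends on that solver's API):

```python
from stable_worldmodel.policy import AutoCostModel

# Load the checkpoint module that exposes get_cost; it is returned in eval mode.
cost_model = AutoCostModel("path/to/checkpoint")
```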