[ Base Class ]
Solver
configure
solve
Solve the planning optimization problem to find optimal actions.
Parameters:

- info_dict (dict) – Dictionary containing environment state information.
- init_action (Tensor | None, default: None) – Optional initial action sequence to warm-start the solver.

Returns:

- dict – Dictionary containing optimized actions and other solver-specific info.
[ Implementations ]
CEMSolver
CEMSolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
device: str | device = 'cpu',
seed: int = 1234,
callbacks: list[Callback] | None = None,
)
Cross Entropy Method solver for action optimization.
Parameters:

- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample per iteration.
- var_scale (float, default: 1) – Initial variance scale for the action distribution.
- n_steps (int, default: 30) – Number of CEM iterations.
- topk (int, default: 30) – Number of elite samples to keep for distribution update.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
- callbacks (list[Callback] | None, default: None) – Optional list of callbacks.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using Cross Entropy Method.
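The CEM inner loop (sample from a per-timestep Gaussian, score with the world model, refit the distribution to the elites) can be sketched in a few lines. This is an illustrative standalone loop, not the library's implementation; `cem_step` and `costs_fn` are hypothetical names.

```python
import torch


def cem_step(mean, var, costs_fn, num_samples=300, topk=30, gen=None):
    """One illustrative CEM iteration: sample, score, refit to elites.

    mean/var: (H, D) per-timestep Gaussian parameters (hypothetical layout).
    costs_fn: maps candidates (N, H, D) -> costs (N,).
    """
    noise = torch.randn(num_samples, *mean.shape, generator=gen)
    candidates = mean + var.sqrt() * noise                # (N, H, D)
    costs = costs_fn(candidates)                          # (N,)
    elite_idx = costs.topk(topk, largest=False).indices   # lowest-cost samples
    elites = candidates[elite_idx]                        # (topk, H, D)
    return elites.mean(dim=0), elites.var(dim=0)          # refit the Gaussian
```

Iterating `cem_step` on a differentiable or black-box cost drives the mean toward low-cost action sequences while the variance contracts around them.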
ICEMSolver
ICEMSolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
noise_beta: float = 2.0,
alpha: float = 0.1,
n_elite_keep: int = 5,
return_mean: bool = True,
device: str | device = 'cpu',
seed: int = 1234,
callbacks: list[Callback] | None = None,
)
Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention. iCEM improves the sample efficiency over standard CEM and was introduced by [1] for real-time planning.
Parameters:

- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample per iteration.
- var_scale (float, default: 1) – Initial variance scale for the action distribution.
- n_steps (int, default: 30) – Number of CEM iterations.
- topk (int, default: 30) – Number of elite samples to keep for distribution update.
- noise_beta (float, default: 2.0) – Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.
- alpha (float, default: 0.1) – Momentum for mean/std EMA update.
- n_elite_keep (int, default: 5) – Number of elites carried over from the previous iteration.
- return_mean (bool, default: True) – If False, return the best single trajectory instead of the mean.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
- callbacks (list[Callback] | None, default: None) – Optional list of callbacks.
[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and G. Martius. "Sample-efficient Cross-Entropy Method for Real-time Planning". Conference on Robot Learning, 2020.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using improved Cross Entropy Method.
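The colored-noise sampling that distinguishes iCEM from standard CEM can be sketched via the frequency domain: draw Gaussian spectral coefficients, scale their amplitude by 1/f^(beta/2) along the horizon axis, and transform back. This is a hypothetical helper illustrating the idea behind noise_beta, not the library's exact routine.

```python
import torch


def colored_noise(beta, num_samples, horizon, dim, gen=None):
    """Sample noise whose power spectrum ~ 1/f^beta along the horizon axis.

    beta=0 recovers white noise (standard CEM); larger beta concentrates
    power in low frequencies, yielding smoother action sequences.
    Returns a tensor of shape (num_samples, horizon, dim).
    """
    freqs = torch.fft.rfftfreq(horizon)          # (H//2 + 1,)
    freqs[0] = freqs[1]                          # avoid div-by-zero at DC
    scale = freqs.pow(-beta / 2.0)               # 1/f^(beta/2) amplitude
    shape = (num_samples, dim, freqs.shape[0])
    spec = torch.complex(
        torch.randn(shape, generator=gen), torch.randn(shape, generator=gen)
    ) * scale
    noise = torch.fft.irfft(spec, n=horizon)     # (N, D, H), real-valued
    noise = noise / noise.std(dim=-1, keepdim=True)  # unit variance per sequence
    return noise.transpose(1, 2)                 # (N, H, D)
```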
MPPISolver
MPPISolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1.0,
n_steps: int = 30,
topk: int = 30,
temperature: float = 0.5,
device: str | device = 'cpu',
seed: int = 1234,
)
Model Predictive Path Integral solver for action optimization.
Parameters:

- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample per iteration.
- var_scale (float, default: 1.0) – Initial variance scale for action noise.
- n_steps (int, default: 30) – Number of MPPI iterations.
- topk (int, default: 30) – Number of elite samples for weighted averaging.
- temperature (float, default: 0.5) – Temperature parameter for softmax weighting.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using MPPI.
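The core MPPI update is a temperature-weighted average of candidate trajectories; lower temperature concentrates weight on the lowest-cost rollouts. An illustrative sketch (`mppi_update` is a hypothetical name, not the library's API):

```python
import torch


def mppi_update(candidates, costs, temperature=0.5):
    """Softmax-weight candidate trajectories by (negative) cost.

    candidates: (N, H, D) sampled action sequences
    costs:      (N,) rollout costs from the world model
    Returns the weighted-average plan of shape (H, D).
    """
    # Subtract the min cost for numerical stability before exponentiating.
    weights = torch.softmax(-(costs - costs.min()) / temperature, dim=0)
    return torch.einsum("n,nhd->hd", weights, candidates)
```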
PredictiveSamplingSolver
PredictiveSamplingSolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
noise_scale: float = 1.0,
device: str | device = 'cpu',
seed: int = 1234,
)
Predictive Sampling solver for action optimization.
Parameters:

- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample.
- noise_scale (float, default: 1.0) – Standard deviation of additive Gaussian noise.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using Predictive Sampling.
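Predictive Sampling is the simplest shooting method: perturb a nominal plan with Gaussian noise, evaluate every candidate once, and keep the best. A minimal sketch under that assumption (`predictive_sampling` is a hypothetical name):

```python
import torch


def predictive_sampling(nominal, cost_fn, num_samples=300, noise_scale=1.0, gen=None):
    """Perturb the nominal plan, evaluate once, return the best candidate.

    nominal: (H, D) current plan; cost_fn maps (N, H, D) -> (N,).
    The nominal plan is kept in the pool, so the result never gets worse.
    """
    noise = noise_scale * torch.randn(num_samples, *nominal.shape, generator=gen)
    candidates = torch.cat([nominal.unsqueeze(0), nominal + noise])
    costs = cost_fn(candidates)
    return candidates[costs.argmin()]
```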
GradientSolver
GradientSolver(
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | device = 'cpu',
seed: int = 1234,
optimizer_cls: type[Optimizer] = SGD,
optimizer_kwargs: dict | None = None,
grad_clip: float | None = None,
callbacks: list[Callback] | None = None,
)
Bases: Module
Gradient-based solver using backpropagation through the world model.
Parameters:

- model (Costable) – World model implementing the Costable protocol.
- n_steps (int) – Number of gradient descent iterations.
- batch_size (int | None, default: None) – Number of environments to process in parallel.
- var_scale (float, default: 1) – Initial variance scale for action perturbations.
- num_samples (int, default: 1) – Number of action samples to optimize in parallel.
- action_noise (float, default: 0.0) – Noise added to actions during optimization.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
- optimizer_cls (type[Optimizer], default: SGD) – PyTorch optimizer class to use.
- optimizer_kwargs (dict | None, default: None) – Keyword arguments for the optimizer.
- grad_clip (float | None, default: None) – Gradient-norm clipping threshold (no clipping if None).
- callbacks (list[Callback] | None, default: None) – Optional list of callbacks.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using gradient descent.
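Gradient-based planning treats the action sequence as a learnable tensor and backpropagates the cost through the (differentiable) world model. A self-contained sketch of that loop, not the library's GradientSolver; `gradient_plan` and `cost_fn` are hypothetical names:

```python
import torch


def gradient_plan(cost_fn, horizon, dim, n_steps=30, lr=0.1, gen=None):
    """Optimize an action sequence (H, D) by gradient descent on its cost.

    cost_fn must be differentiable w.r.t. the action tensor and return a scalar.
    """
    actions = torch.randn(horizon, dim, generator=gen, requires_grad=True)
    opt = torch.optim.SGD([actions], lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        cost_fn(actions).backward()   # backprop through the world model
        opt.step()
    return actions.detach()
```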
PGDSolver
PGDSolver(
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | device = 'cpu',
seed: int = 1234,
)
Bases: Module
Projected Gradient Descent solver for discrete action optimization.
Parameters:

- model (Costable) – World model implementing the Costable protocol.
- n_steps (int) – Number of gradient descent iterations.
- batch_size (int | None, default: None) – Number of environments to process in parallel.
- var_scale (float, default: 1) – Initial variance scale for action perturbations.
- num_samples (int, default: 1) – Number of action samples to optimize in parallel.
- action_noise (float, default: 0.0) – Noise added to actions during optimization.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using projected gradient descent.
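The "projected" step keeps simplex-valued action variables valid (non-negative, summing to 1) after each gradient update. A standard sort-based Euclidean projection onto the probability simplex (Duchi et al., 2008) illustrates the idea; this is a sketch, not necessarily the library's exact routine:

```python
import torch


def project_to_simplex(v):
    """Euclidean projection of each row of v onto the probability simplex."""
    u, _ = torch.sort(v, dim=-1, descending=True)
    css = u.cumsum(dim=-1) - 1.0
    k = torch.arange(1, v.shape[-1] + 1, dtype=v.dtype)
    cond = u - css / k > 0                       # True for a prefix of indices
    rho = cond.sum(dim=-1, keepdim=True) - 1     # index of last valid entry
    theta = css.gather(-1, rho) / (rho + 1).to(v.dtype)
    return torch.clamp(v - theta, min=0.0)       # shift and clip to the simplex
```

After each gradient step on the relaxed (continuous) action logits, projecting every per-timestep row back onto the simplex keeps the iterate a valid categorical distribution.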
CategoricalCEMSolver
CategoricalCEMSolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
n_steps: int = 30,
topk: int = 30,
smoothing: float = 0.0,
alpha: float = 0.0,
device: str | device = 'cpu',
seed: int = 1234,
callbacks: list[Callback] | None = None,
)
Cross Entropy Method solver for discrete action optimization.
Maintains a per-timestep categorical distribution over discrete actions, samples candidate trajectories via Gumbel-max, and refits the distribution from the top-K elites' empirical frequencies.
Parameters:

- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample per iteration.
- n_steps (int, default: 30) – Number of CEM iterations.
- topk (int, default: 30) – Number of elite samples to keep for distribution update.
- smoothing (float, default: 0.0) – Laplace smoothing added to refit probs to avoid collapse.
- alpha (float, default: 0.0) – Momentum for probs EMA update (0 = full overwrite).
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
- callbacks (list[Callback] | None, default: None) – Optional list of callbacks.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using Categorical CEM.
Note: init_action is accepted for API parity with the other solvers but is ignored; probs are always initialized uniformly.
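The elite-frequency refit with Laplace smoothing and EMA momentum described above can be sketched as follows. `refit_probs` is a hypothetical helper; shapes assume E elite trajectories over H timesteps with K categories:

```python
import torch
import torch.nn.functional as F


def refit_probs(elites, K, smoothing=0.0, alpha=0.0, prev_probs=None):
    """Refit per-timestep categorical probs from elite index trajectories.

    elites: (E, H) int64 elite action indices; returns probs of shape (H, K).
    """
    counts = F.one_hot(elites, K).float().sum(dim=0)     # (H, K) frequencies
    probs = counts + smoothing                            # Laplace floor
    probs = probs / probs.sum(dim=-1, keepdim=True)
    if prev_probs is not None and alpha > 0:
        probs = alpha * prev_probs + (1 - alpha) * probs  # EMA momentum
    return probs
```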
LagrangianSolver
LagrangianSolver(
model: Costable,
n_steps: int,
n_outer_steps: int = 5,
batch_size: int | None = None,
num_samples: int = 1,
var_scale: float = 1.0,
action_noise: float = 0.0,
rho_init: float = 1.0,
rho_max: float = 10000.0,
rho_scale: float = 2.0,
persist_multipliers: bool = True,
device: str | device = 'cpu',
seed: int = 1234,
optimizer_cls: type[Optimizer] = Adam,
optimizer_kwargs: dict | None = None,
)
Bases: Module
Augmented Lagrangian solver for constrained world-model planning.
The model's get_cost returns the cost tensor (B, S). If the model also implements get_constraints, that method should return constraint values of shape (B, S, C), where C is the number of constraints; constraint i is satisfied when its value g_i <= 0. The solver optimizes the augmented Lagrangian objective:
L = cost + sum_{i=1}^C lambda_i * g_i + rho * sum_{i=1}^C max(0, g_i)^2
To enforce an equality constraint h(a) = 0, convert it to two inequality constraints: h(a) <= 0 and -h(a) <= 0.
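The objective can be evaluated directly in a few lines. An illustrative sketch using the shapes stated above, with a scalar rho for simplicity (`augmented_lagrangian` is a hypothetical name):

```python
import torch


def augmented_lagrangian(cost, g, lam, rho):
    """Evaluate L = cost + sum_i lam_i * g_i + rho * sum_i max(0, g_i)^2.

    cost: (B, S) plain cost; g: (B, S, C) constraint values (satisfied when <= 0);
    lam: (B, C) dual variables; rho: scalar penalty coefficient.
    """
    dual = (lam.unsqueeze(1) * g).sum(dim=-1)       # (B, S) linear dual term
    penalty = torch.relu(g).pow(2).sum(dim=-1)      # (B, S) quadratic penalty
    return cost + dual + rho * penalty
```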
Parameters:

- model (Costable) – World model implementing the Costable protocol. Its get_cost() returns a plain cost tensor (B, S). If it also has get_constraints(), that method returns constraints of shape (B, S, C).
- n_steps (int) – Number of gradient descent steps per outer iteration.
- n_outer_steps (int, default: 5) – Number of dual ascent (outer) iterations.
- batch_size (int | None, default: None) – Number of environments to process in parallel.
- num_samples (int, default: 1) – Number of action samples to optimize in parallel.
- var_scale (float, default: 1.0) – Initial variance scale for action perturbations.
- action_noise (float, default: 0.0) – Noise added to actions during optimization.
- rho_init (float, default: 1.0) – Initial penalty coefficient for the quadratic constraint term.
- rho_max (float, default: 10000.0) – Maximum value of the penalty coefficient.
- rho_scale (float, default: 2.0) – Multiplicative growth factor for rho after each outer step.
- persist_multipliers (bool, default: True) – Whether to warm-start Lagrange multipliers across solve() calls.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
- optimizer_cls (type[Optimizer], default: Adam) – PyTorch optimizer class to use.
- optimizer_kwargs (dict | None, default: None) – Keyword arguments for the optimizer.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using augmented Lagrangian gradient descent.
[ Callbacks ]
Solvers accept a callbacks=[...] list of Callback objects. Each callback fires once per inner-loop step and accumulates a per-batch buffer; final histories are returned in outputs['callbacks'], keyed by cb.output_key (defaults to the class name).
```python
from stable_worldmodel.solver import GradientSolver
from stable_worldmodel.solver.callbacks import (
    BestCostRecorder,
    GradNormRecorder,
    ActionNormRecorder,
)

solver = GradientSolver(
    model=model,
    n_steps=20,
    num_samples=8,
    callbacks=[
        BestCostRecorder(),                  # mean over envs (default)
        GradNormRecorder(reduction="none"),  # one entry per env
        ActionNormRecorder(reduction="sum"),
    ],
)
solver.configure(action_space=action_space, n_envs=4, config=config)
out = solver.solve(info_dict)

# out['callbacks']['BestCostRecorder'] -> list[list[float]] (batches x steps)
# out['callbacks']['GradNormRecorder'] -> list[list[list[float]]]
```
Reduction modes
Every callback accepts reduction ∈ {'mean', 'sum', 'none'}. Reduction is
applied across the env axis only; within-sample reductions (e.g. min over
samples for BestCostRecorder) are intrinsic to each metric.
| Mode | Output per step |
|---|---|
| 'mean' | scalar (default) |
| 'sum' | scalar |
| 'none' | list[float] — one value per env in batch |
Available callbacks
| Callback | Solver(s) | Records |
|---|---|---|
| BestCostRecorder | any | min cost over samples |
| MeanCostRecorder | any | mean cost over samples |
| GradNormRecorder | GD | L2 norm of action gradient (optional per_step for per-horizon-step values) |
| ActionNormRecorder | GD | L2 norm of action tensor |
| EliteCostRecorder | CEM, iCEM | dict of elite cost stats (mean/min/max) |
| VarNormRecorder | CEM, iCEM | mean variance of action distribution |
| MeanShiftRecorder | CEM, iCEM | L2 distance between consecutive means |
| EliteSpreadRecorder | CEM, iCEM | within-elite std (top-k diversity) |
Writing a custom callback
Subclass Callback and
implement compute(**state). Pull the tensors you need from state and
call self._reduce(per_env_tensor) to honour the reduction mode.
```python
from stable_worldmodel.solver.callbacks import Callback


class CostRangeRecorder(Callback):
    """Records per-env (max - min) cost across the sample population."""

    def compute(self, **state):
        costs = state["costs"].detach()  # (B, N)
        per_env = costs.max(dim=1).values - costs.min(dim=1).values
        return self._reduce(per_env)
```
State keys passed by each solver:
- GD: step, params, cost, costs
- CEM: step, candidates, costs, topk_vals, topk_inds, topk_candidates, mean, var, prev_mean, prev_var
- iCEM: same as CEM plus action_low, action_high
- CategoricalCEM: step, candidates, costs, topk_vals, topk_inds, topk_candidates, probs, prev_probs
Callback
Callback(reduction: Reduction = 'mean')
Base class for solver iteration callbacks.
Subclasses compute a per-env metric and call self._reduce(...) to
apply the configured reduction across envs. history is
list[list[Any]] (batches x steps), matching the shape of
outputs['cost'] in the gradient solver.
Methods:

- reset
- start_batch
- end_solve
- compute

Attributes:

- output_key (str)
BestCostRecorder
BestCostRecorder(reduction: Reduction = 'mean')
MeanCostRecorder
MeanCostRecorder(reduction: Reduction = 'mean')
GradNormRecorder
GradNormRecorder(
reduction: Reduction = 'mean', per_step: bool = False
)
Bases: Callback
Per-step L2 norm of the action gradient (per env, mean over samples).
With per_step=True, returns a list of length H with one grad
norm per horizon step instead of reducing over the horizon dim.
ActionNormRecorder
ActionNormRecorder(reduction: Reduction = 'mean')
EliteCostRecorder
EliteCostRecorder(reduction: Reduction = 'mean')
VarNormRecorder
VarNormRecorder(reduction: Reduction = 'mean')
MeanShiftRecorder
MeanShiftRecorder(reduction: Reduction = 'mean')
EliteSpreadRecorder
EliteSpreadRecorder(reduction: Reduction = 'mean')
[ Example: Constrained Planning with LagrangianSolver ]
The LagrangianSolver extends gradient-based planning to handle inequality
constraints of the form g(a) ≤ 0. It uses the augmented Lagrangian method:
dual variables (λ) are maintained per environment and updated via dual ascent
after each inner optimisation loop, while a quadratic penalty term (controlled
by rho) enforces feasibility.
```python
import torch
import gymnasium as gym
import numpy as np

from stable_worldmodel.solver import LagrangianSolver
from stable_worldmodel.policy import PlanConfig


# ── 1. Define a world model with cost and optional constraints ──────────────
class MyModel(torch.nn.Module):
    """Minimal example: cost is MSE to a goal; two inequality constraints."""

    def get_cost(self, info_dict, action_candidates):
        # action_candidates: (B, S, H, D)
        # returns: (B, S)
        goal = torch.zeros(action_candidates.shape[-1])
        return (action_candidates.mean(dim=2) - goal).pow(2).mean(dim=-1)

    def get_constraints(self, info_dict, action_candidates):
        # returns: (B, S, C) — violated when > 0
        # g0: action L2 norm <= 1
        g0 = action_candidates.norm(dim=-1).mean(dim=2) - 1.0
        # g1: first action dimension <= 0.5
        g1 = action_candidates[..., 0].mean(dim=2) - 0.5
        return torch.stack([g0, g1], dim=-1)


# ── 2. Build and configure the solver ──────────────────────────────────────
model = MyModel()
solver = LagrangianSolver(
    model=model,
    n_steps=30,                # inner gradient steps per outer iteration
    n_outer_steps=10,          # dual-ascent (outer) iterations
    num_samples=8,             # parallel action candidates per env
    rho_init=1.0,              # initial quadratic penalty coefficient
    rho_scale=2.0,             # rho doubles each outer step
    rho_max=1e4,
    persist_multipliers=True,  # warm-start λ across planning calls
    optimizer_kwargs={"lr": 0.05},
)

action_space = gym.spaces.Box(
    low=-np.inf, high=np.inf, shape=(1, 4), dtype=np.float32
)
config = PlanConfig(horizon=10, receding_horizon=1, action_block=1)
solver.configure(action_space=action_space, n_envs=2, config=config)

# ── 3. Solve ────────────────────────────────────────────────────────────────
info_dict = {"obs": torch.zeros(2, 4)}  # current env observations
out = solver.solve(info_dict)

print(out["actions"].shape)          # (2, 10, 4) — best action per env
print(out["lambdas"])                # (2, 2) — dual variables
print(out["constraint_violation"])   # mean ReLU(g) across samples

# ── 4. Receding-horizon planning (warm start) ───────────────────────────────
# Execute the first step, shift the plan, re-plan
executed_steps = 1
remaining = out["actions"][:, executed_steps:, :]  # (2, 9, 4)
out2 = solver.solve(info_dict, init_action=remaining)
```
Key parameters
| Parameter | Default | Description |
|---|---|---|
| n_steps | — | Inner gradient steps per outer iteration |
| n_outer_steps | 5 | Dual-ascent iterations |
| rho_init | 1.0 | Initial quadratic penalty weight |
| rho_scale | 2.0 | Multiplicative growth for rho each outer step |
| rho_max | 1e4 | Upper bound on rho |
| persist_multipliers | True | Keep λ across solve() calls (warm start) |
| num_samples | 1 | Parallel candidate trajectories per environment |
| action_noise | 0.0 | Gaussian noise injected each inner step |
Constraint protocol
Your model must implement get_constraints(info_dict, action_candidates) -> Tensor
returning shape (B, S, C). A constraint is satisfied when its value is ≤ 0.
To enforce an equality h(a) = 0, add two constraints: h(a) ≤ 0 and
-h(a) ≤ 0.
[ Example: Discrete Planning with CategoricalCEMSolver ]
CategoricalCEMSolver is the discrete-action analogue of CEMSolver. Instead
of fitting a Gaussian per timestep, it maintains a categorical distribution
over the Discrete(K) action space and refits it from the empirical
frequencies of top-K elite trajectories. Sampling uses the Gumbel-max trick
(seeded via the solver's torch.Generator) and candidates are passed to
model.get_cost as one-hot tensors — the same layout used by PGDSolver, so
discrete world models work unchanged.
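The Gumbel-max trick mentioned above draws exact categorical samples by adding Gumbel noise to log-probabilities and taking an argmax. An illustrative sketch (`gumbel_max_sample` is a hypothetical name, not part of the library):

```python
import torch


def gumbel_max_sample(probs, num_samples, gen=None):
    """Draw categorical samples via the Gumbel-max trick.

    probs: (H, K) per-timestep categorical probabilities.
    returns: (num_samples, H) int64 category indices.
    """
    logits = probs.clamp_min(1e-20).log()
    u = torch.rand(num_samples, *logits.shape, generator=gen)
    gumbel = -torch.log(-torch.log(u.clamp_min(1e-20)))  # standard Gumbel noise
    return (logits + gumbel).argmax(dim=-1)
```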
```python
import torch
import gymnasium as gym

from stable_worldmodel.solver import CategoricalCEMSolver
from stable_worldmodel.policy import PlanConfig


# ── 1. World model: cost defined over one-hot candidates ────────────────────
class DiscreteModel(torch.nn.Module):
    """Cost is minimized by selecting category 2 at every position."""

    def get_cost(self, info_dict, action_candidates):
        # action_candidates: (B, N, H, action_block * K) one-hot floats
        # returns: (B, N)
        K = 4
        ab = action_candidates.shape[-1] // K
        c = action_candidates.reshape(*action_candidates.shape[:-1], ab, K)
        return -c[..., 2].sum(dim=(-1, -2))


# ── 2. Build and configure the solver ──────────────────────────────────────
solver = CategoricalCEMSolver(
    model=DiscreteModel(),
    n_steps=20,       # CEM iterations
    num_samples=128,  # candidates per iteration
    topk=16,          # elite count
    smoothing=0.01,   # Laplace floor — prevents premature collapse
    alpha=0.1,        # EMA momentum on probs (0 = full overwrite)
    seed=0,
)

action_space = gym.spaces.Discrete(4)
config = PlanConfig(horizon=8, receding_horizon=4, action_block=1)
solver.configure(action_space=action_space, n_envs=2, config=config)

# ── 3. Solve ───────────────────────────────────────────────────────────────
info_dict = {"obs": torch.zeros(2, 4)}
out = solver.solve(info_dict)

print(out["actions"].shape)   # (2, 8, 1) — discrete indices, argmax of probs
print(out["probs"][0].shape)  # (2, 8, 1, 4) — final categorical distribution
print(out["costs"])           # mean elite cost per env
```
Key parameters
| Parameter | Default | Description |
|---|---|---|
| num_samples | 300 | Candidate trajectories sampled per iteration |
| n_steps | 30 | CEM iterations |
| topk | 30 | Elite count for refit |
| smoothing | 0.0 | Laplace smoothing on refit probs (avoids collapse) |
| alpha | 0.0 | EMA momentum: probs ← α · prev + (1−α) · new |
| batch_size | 1 | Envs processed per outer batch |
Output layout
| Key | Shape | Meaning |
|---|---|---|
| actions | (n_envs, horizon, action_block) | argmax of final probs (int64) |
| probs | [(n_envs, horizon, action_block, K)] | final categorical distribution |
| costs | list[float] of length n_envs | mean elite cost on the last iteration |
| callbacks | dict[str, list[list[Any]]] | per-callback history (if any) |
Choosing between PGDSolver and CategoricalCEMSolver
Both target Discrete(K) action spaces.
- PGDSolver does projected gradient descent on simplex-valued action variables. It requires a differentiable model.get_cost and benefits from smooth cost landscapes.
- CategoricalCEMSolver is gradient-free. Use it when the cost is non-differentiable (discrete simulators, ranking losses, learned classifiers used as oracles) or when PGD gets stuck in local minima.