[ Base Class ]
Solver
configure
Configure the solver with environment specifications.
solve
Solve the planning optimization problem to find optimal actions.
Parameters:
- info_dict (dict) – Dictionary containing environment state information.
- init_action (Tensor | None, default: None) – Optional initial action sequence to warm-start the solver.
Returns:
- dict – Dictionary containing optimized actions and other solver-specific info.
[ Implementations ]
CEMSolver
CEMSolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
device: str | device = 'cpu',
seed: int = 1234,
)
Cross Entropy Method solver for action optimization.
Parameters:
- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample per iteration.
- var_scale (float, default: 1) – Initial variance scale for the action distribution.
- n_steps (int, default: 30) – Number of CEM iterations.
- topk (int, default: 30) – Number of elite samples to keep for the distribution update.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using Cross Entropy Method.
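As a rough illustration of what each CEM iteration does (a minimal NumPy sketch, not the library's implementation: `cem_minimize`, the toy cost, and the hyperparameter choices here are all hypothetical), the sample/score/refit cycle looks like this:

```python
import numpy as np

def cem_minimize(cost_fn, dim, num_samples=300, topk=30, n_steps=30,
                 var_scale=1.0, seed=1234):
    """Minimal CEM loop: sample candidates, keep the top-k elites,
    refit the Gaussian sampling distribution to them."""
    rng = np.random.default_rng(seed)
    mean = np.zeros(dim)
    std = np.full(dim, var_scale)
    for _ in range(n_steps):
        # Sample action candidates from the current Gaussian
        samples = mean + std * rng.standard_normal((num_samples, dim))
        costs = cost_fn(samples)                    # (num_samples,)
        elites = samples[np.argsort(costs)[:topk]]  # lowest-cost candidates
        # Refit the distribution to the elite set
        mean, std = elites.mean(axis=0), elites.std(axis=0)
    return mean

# Toy cost: squared distance to a target action vector
target = np.array([0.5, -0.3, 0.8])
best = cem_minimize(lambda a: ((a - target) ** 2).sum(axis=-1), dim=3)
```

In the actual solver the cost comes from rolling candidate action sequences through the world model's get_cost, batched per environment.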
ICEMSolver
ICEMSolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1,
n_steps: int = 30,
topk: int = 30,
noise_beta: float = 2.0,
alpha: float = 0.1,
n_elite_keep: int = 5,
return_mean: bool = True,
device: str | device = 'cpu',
seed: int = 1234,
)
Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention. iCEM improves the sample efficiency over standard CEM and was introduced by [1] for real-time planning.
Parameters:
- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample per iteration.
- var_scale (float, default: 1) – Initial variance scale for the action distribution.
- n_steps (int, default: 30) – Number of CEM iterations.
- topk (int, default: 30) – Number of elite samples to keep for the distribution update.
- noise_beta (float, default: 2.0) – Colored noise exponent. 0 = white noise (standard CEM); >0 = more low-frequency noise.
- alpha (float, default: 0.1) – Momentum for the mean/std EMA update.
- n_elite_keep (int, default: 5) – Number of elites carried over from the previous iteration.
- return_mean (bool, default: True) – If False, return the best single trajectory instead of the mean.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and G. Martius. "Sample-efficient Cross-Entropy Method for Real-time Planning". Conference on Robot Learning, 2020.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using improved Cross Entropy Method.
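The colored-noise sampling controlled by noise_beta is the key difference from plain CEM. It can be sketched as an FFT-based construction of noise whose power spectrum falls off as 1/f^beta along the planning horizon (an illustrative NumPy sketch, not the library's internal sampler):

```python
import numpy as np

def colored_noise(beta, horizon, size, rng):
    """Sample noise whose power spectral density scales as 1/f^beta along
    the time axis: beta=0 gives white noise (standard CEM), beta>0 gives
    smoother, low-frequency-dominated action sequences."""
    freqs = np.fft.rfftfreq(horizon)
    freqs[0] = freqs[1]                  # avoid division by zero at DC
    amplitude = freqs ** (-beta / 2.0)   # amplitude f^(-beta/2) => power f^(-beta)
    # Random complex spectrum shaped by the amplitude envelope
    spectrum = amplitude * (rng.standard_normal((size, len(freqs)))
                            + 1j * rng.standard_normal((size, len(freqs))))
    noise = np.fft.irfft(spectrum, n=horizon, axis=-1)
    # Normalize each sequence to unit variance
    return noise / noise.std(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
eps = colored_noise(beta=2.0, horizon=30, size=300, rng=rng)  # (300, 30)
```

With beta=2.0 (the default), consecutive actions in a sampled sequence are strongly correlated, which is what makes the candidates temporally smooth.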
MPPISolver
MPPISolver(
model: Costable,
batch_size: int = 1,
num_samples: int = 300,
var_scale: float = 1.0,
n_steps: int = 30,
topk: int = 30,
temperature: float = 0.5,
device: str | device = 'cpu',
seed: int = 1234,
)
Model Predictive Path Integral solver for action optimization.
Parameters:
- model (Costable) – World model implementing the Costable protocol.
- batch_size (int, default: 1) – Number of environments to process in parallel.
- num_samples (int, default: 300) – Number of action candidates to sample per iteration.
- var_scale (float, default: 1.0) – Initial variance scale for action noise.
- n_steps (int, default: 30) – Number of MPPI iterations.
- topk (int, default: 30) – Number of elite samples for weighted averaging.
- temperature (float, default: 0.5) – Temperature parameter for softmax weighting.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using MPPI.
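The core MPPI update (perturb the mean, score candidates, then average them with exponential weights) can be sketched on a toy quadratic cost. This is a minimal NumPy illustration; the function name and hyperparameters are hypothetical, not the library API:

```python
import numpy as np

def mppi_update(mean, cost_fn, num_samples=300, var_scale=1.0,
                temperature=0.5, n_steps=30, seed=1234):
    """Minimal MPPI-style loop: perturb the mean with Gaussian noise, then
    average candidates with softmax weights exp(-cost / temperature)."""
    rng = np.random.default_rng(seed)
    for _ in range(n_steps):
        noise = var_scale * rng.standard_normal((num_samples, mean.size))
        candidates = mean + noise
        costs = cost_fn(candidates)
        # Subtract the min cost for numerical stability before exponentiating
        weights = np.exp(-(costs - costs.min()) / temperature)
        weights /= weights.sum()
        mean = (weights[:, None] * candidates).sum(axis=0)
    return mean

target = np.array([0.2, -0.7])
plan = mppi_update(np.zeros(2), lambda a: ((a - target) ** 2).sum(axis=-1))
```

Lower temperature concentrates the weights on the cheapest candidates (approaching greedy selection); higher temperature flattens the weighting toward a plain average.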
GradientSolver
GradientSolver(
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | device = 'cpu',
seed: int = 1234,
optimizer_cls: type[Optimizer] = SGD,
optimizer_kwargs: dict | None = None,
)
Bases: Module
Gradient-based solver using backpropagation through the world model.
Parameters:
- model (Costable) – World model implementing the Costable protocol.
- n_steps (int) – Number of gradient descent iterations.
- batch_size (int | None, default: None) – Number of environments to process in parallel.
- var_scale (float, default: 1) – Initial variance scale for action perturbations.
- num_samples (int, default: 1) – Number of action samples to optimize in parallel.
- action_noise (float, default: 0.0) – Noise added to actions during optimization.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
- optimizer_cls (type[Optimizer], default: SGD) – PyTorch optimizer class to use.
- optimizer_kwargs (dict | None, default: None) – Keyword arguments for the optimizer.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using gradient descent.
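As a sketch of the idea: the real solver obtains gradients by backpropagating through the world model; here a toy quadratic cost with a hand-written analytic gradient stands in for autograd, and all names and hyperparameters are illustrative:

```python
import numpy as np

def gradient_plan(cost_grad, dim, n_steps=100, lr=0.1, action_noise=0.0,
                  seed=1234):
    """Minimal gradient-based planner: repeatedly step the action sequence
    along the negative cost gradient (the solver itself would get this
    gradient via backprop through the world model)."""
    rng = np.random.default_rng(seed)
    actions = 0.1 * rng.standard_normal(dim)       # small random init
    for _ in range(n_steps):
        noisy = actions + action_noise * rng.standard_normal(dim)
        actions = actions - lr * cost_grad(noisy)  # plain SGD step
    return actions

# Toy cost ||a - target||^2 with analytic gradient 2 * (a - target)
target = np.array([1.0, -2.0, 0.5])
plan = gradient_plan(lambda a: 2.0 * (a - target), dim=3)
```

Setting action_noise > 0 perturbs the point at which the gradient is evaluated, which can help escape shallow local minima at the cost of noisier convergence.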
PGDSolver
PGDSolver(
model: Costable,
n_steps: int,
batch_size: int | None = None,
var_scale: float = 1,
num_samples: int = 1,
action_noise: float = 0.0,
device: str | device = 'cpu',
seed: int = 1234,
)
Bases: Module
Projected Gradient Descent solver for discrete action optimization.
Parameters:
- model (Costable) – World model implementing the Costable protocol.
- n_steps (int) – Number of gradient descent iterations.
- batch_size (int | None, default: None) – Number of environments to process in parallel.
- var_scale (float, default: 1) – Initial variance scale for action perturbations.
- num_samples (int, default: 1) – Number of action samples to optimize in parallel.
- action_noise (float, default: 0.0) – Noise added to actions during optimization.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using projected gradient descent.
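A minimal sketch of the projection step. The solver projects onto its valid action set; here a box constraint stands in as an illustrative feasible set, and the function name and hyperparameters are hypothetical:

```python
import numpy as np

def pgd_plan(cost_grad, low, high, n_steps=100, lr=0.1, seed=1234):
    """Minimal projected gradient descent: after each gradient step,
    project the actions back onto the feasible box [low, high]."""
    rng = np.random.default_rng(seed)
    actions = rng.uniform(low, high)
    for _ in range(n_steps):
        actions = actions - lr * cost_grad(actions)
        actions = np.clip(actions, low, high)  # projection onto the box
    return actions

# The unconstrained optimum [2.0, -0.5] lies outside the box, so the
# projected solution ends up on the boundary in the first dimension.
target = np.array([2.0, -0.5])
low, high = np.array([-1.0, -1.0]), np.array([1.0, 1.0])
plan = pgd_plan(lambda a: 2.0 * (a - target), low, high)
```

The same step/project structure applies to any feasible set with a cheap projection operator; only the clip line changes.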
LagrangianSolver
LagrangianSolver(
model: Costable,
n_steps: int,
n_outer_steps: int = 5,
batch_size: int | None = None,
num_samples: int = 1,
var_scale: float = 1.0,
action_noise: float = 0.0,
rho_init: float = 1.0,
rho_max: float = 10000.0,
rho_scale: float = 2.0,
persist_multipliers: bool = True,
device: str | device = 'cpu',
seed: int = 1234,
optimizer_cls: type[Optimizer] = Adam,
optimizer_kwargs: dict | None = None,
)
Bases: Module
Augmented Lagrangian solver for constrained planning with a world model.
The model's get_cost returns the cost tensor (B, S). If the model also implements get_constraints, it should return the constraint values (B, S, C), where C is the number of constraints. Each value constraint_cost_i represents the cost of violating that constraint, and the constraint is satisfied when constraint_cost_i <= 0. The Lagrangian solver optimizes the following objective:
L = cost + sum_{i=1}^C lambda_i * constraint_cost_i + rho * sum_{i=1}^C max(0, constraint_cost_i)^2
To enforce an equality constraint, convert it to two inequality constraints. For example, to enforce constraint_cost_i == 0, add the two constraints constraint_cost_i <= 0 and -constraint_cost_i <= 0.
Parameters:
- model (Costable) – World model implementing the Costable protocol. Its get_cost() returns a plain cost tensor (B, S). If it also has get_constraints(), that method returns constraints of shape (B, S, C).
- n_steps (int) – Number of gradient descent steps per outer iteration.
- n_outer_steps (int, default: 5) – Number of dual ascent (outer) iterations.
- batch_size (int | None, default: None) – Number of environments to process in parallel.
- num_samples (int, default: 1) – Number of action samples to optimize in parallel.
- var_scale (float, default: 1.0) – Initial variance scale for action perturbations.
- action_noise (float, default: 0.0) – Noise added to actions during optimization.
- rho_init (float, default: 1.0) – Initial penalty coefficient for the quadratic constraint term.
- rho_max (float, default: 10000.0) – Maximum value of the penalty coefficient.
- rho_scale (float, default: 2.0) – Multiplicative growth factor for rho after each outer step.
- persist_multipliers (bool, default: True) – Whether to warm-start Lagrange multipliers across solve() calls.
- device (str | device, default: 'cpu') – Device for tensor computations.
- seed (int, default: 1234) – Random seed for reproducibility.
- optimizer_cls (type[Optimizer], default: Adam) – PyTorch optimizer class to use.
- optimizer_kwargs (dict | None, default: None) – Keyword arguments for the optimizer.
configure
Configure the solver with environment specifications.
solve
Solve the planning problem using augmented Lagrangian gradient descent.
[ Example: Constrained Planning with LagrangianSolver ]
The LagrangianSolver extends gradient-based planning to handle inequality
constraints of the form g(a) ≤ 0. It uses the augmented Lagrangian method:
dual variables (λ) are maintained per environment and updated via dual ascent
after each inner optimization loop, while a quadratic penalty term (controlled
by rho) enforces feasibility.
import dataclasses
import torch
import gymnasium as gym
import numpy as np
from stable_worldmodel.solver import LagrangianSolver
from stable_worldmodel.policy import PlanConfig
# ── 1. Define a world model with cost and optional constraints ──────────────
class MyModel(torch.nn.Module):
"""Minimal example: cost is MSE to a goal; two inequality constraints."""
def get_cost(self, info_dict, action_candidates):
# action_candidates: (B, S, H, D)
# returns: (B, S)
goal = torch.zeros(action_candidates.shape[-1])
return (action_candidates.mean(dim=2) - goal).pow(2).mean(dim=-1)
def get_constraints(self, info_dict, action_candidates):
# returns: (B, S, C) — violated when > 0
# g0: action L2 norm <= 1
g0 = action_candidates.norm(dim=-1).mean(dim=2) - 1.0
# g1: first action dimension <= 0.5
g1 = action_candidates[..., 0].mean(dim=2) - 0.5
return torch.stack([g0, g1], dim=-1)
# ── 2. Build and configure the solver ──────────────────────────────────────
model = MyModel()
solver = LagrangianSolver(
model=model,
n_steps=30, # inner gradient steps per outer iteration
n_outer_steps=10, # dual-ascent (outer) iterations
num_samples=8, # parallel action candidates per env
rho_init=1.0, # initial quadratic penalty coefficient
rho_scale=2.0, # rho doubles each outer step
rho_max=1e4,
persist_multipliers=True, # warm-start λ across planning calls
optimizer_kwargs={"lr": 0.05},
)
action_space = gym.spaces.Box(low=-np.inf, high=np.inf,
shape=(1, 4), dtype=np.float32)
config = PlanConfig(horizon=10, receding_horizon=1, action_block=1)
solver.configure(action_space=action_space, n_envs=2, config=config)
# ── 3. Solve ────────────────────────────────────────────────────────────────
info_dict = {"obs": torch.zeros(2, 4)} # current env observations
out = solver.solve(info_dict)
print(out["actions"].shape) # (2, 10, 4) — best action per env
print(out["lambdas"]) # (2, 2) — dual variables
print(out["constraint_violation"]) # mean ReLU(g) across samples
# ── 4. Receding-horizon planning (warm start) ───────────────────────────────
# Execute the first step, shift the plan, re-plan
executed_steps = 1
remaining = out["actions"][:, executed_steps:, :] # (2, 9, 4)
out2 = solver.solve(info_dict, init_action=remaining)
Key parameters

| Parameter | Default | Description |
|---|---|---|
| n_steps | — | Inner gradient steps per outer iteration |
| n_outer_steps | 5 | Dual-ascent iterations |
| rho_init | 1.0 | Initial quadratic penalty weight |
| rho_scale | 2.0 | Multiplicative growth for rho each outer step |
| rho_max | 1e4 | Upper bound on rho |
| persist_multipliers | True | Keep λ across solve() calls (warm start) |
| num_samples | 1 | Parallel candidate trajectories per environment |
| action_noise | 0.0 | Gaussian noise injected each inner step |
Constraint protocol
Your model must implement get_constraints(info_dict, action_candidates) -> Tensor
returning shape (B, S, C). A constraint is satisfied when its value is ≤ 0.
To enforce an equality h(a) = 0, add two constraints: h(a) ≤ 0 and
-h(a) ≤ 0.
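The inner-minimize / dual-ascent structure can also be sketched numerically. This minimal NumPy illustration uses a single hand-written constraint, an analytic cost gradient, and an adaptive step size; it is not the library's implementation, and all names are hypothetical:

```python
import numpy as np

def augmented_lagrangian(cost_grad, g, g_grad, dim,
                         n_steps=50, n_outer_steps=10,
                         rho_init=1.0, rho_scale=2.0, rho_max=1e4):
    """Sketch of the outer loop: minimize the augmented Lagrangian for a
    fixed (lambda, rho) by gradient descent, then update lambda by dual
    ascent and grow the penalty coefficient rho."""
    a = np.zeros(dim)
    lam, rho = 0.0, rho_init
    for _ in range(n_outer_steps):
        lr = 0.5 / (1.0 + rho)  # shrink the step as the penalty grows
        for _ in range(n_steps):
            # d/da [ cost + lam * g + rho * max(0, g)^2 ]
            grad = cost_grad(a) + (lam + 2.0 * rho * max(0.0, g(a))) * g_grad(a)
            a = a - lr * grad
        lam = max(0.0, lam + rho * g(a))     # dual ascent keeps lambda >= 0
        rho = min(rho * rho_scale, rho_max)  # tighten the penalty
    return a, lam

# Minimize ||a||^2 subject to g(a) = 1 - a[0] <= 0 (i.e. a[0] >= 1).
# The KKT optimum is a = [1, 0] with multiplier lambda = 2.
a, lam = augmented_lagrangian(
    cost_grad=lambda a: 2.0 * a,
    g=lambda a: 1.0 - a[0],
    g_grad=lambda a: np.array([-1.0, 0.0]),
    dim=2,
)
```

The dual variable climbs toward its KKT value as rho grows, which is why persist_multipliers helps in receding-horizon use: each solve() starts with λ already near the right magnitude.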