Model-based planning solvers for action optimization

[ Base Class ]

Solver

Bases: Protocol

Protocol for model-based planning solvers.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment and planning specifications.

Parameters:

  • action_space (Space) –

    The action space of the environment.

  • n_envs (int) –

    Number of parallel environments.

  • config (Any) –

    Planning configuration object.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning optimization problem to find optimal actions.

Parameters:

  • info_dict (dict) –

    Dictionary containing environment state information.

  • init_action (Tensor | None, default: None ) –

    Optional initial action sequence to warm-start the solver.

Returns:

  • dict

    Dictionary containing optimized actions and other solver-specific info.

action_dim property

action_dim: int

Flattened action dimension including action_block grouping.

n_envs property

n_envs: int

Number of parallel environments being planned for.

horizon property

horizon: int

Planning horizon length in timesteps.

[ Implementations ]

CEMSolver

CEMSolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1,
    n_steps: int = 30,
    topk: int = 30,
    device: str | device = 'cpu',
    seed: int = 1234,
)

Cross Entropy Method solver for action optimization.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample per iteration.

  • var_scale (float, default: 1 ) –

    Initial variance scale for the action distribution.

  • n_steps (int, default: 30 ) –

    Number of CEM iterations.

  • topk (int, default: 30 ) –

    Number of elite samples to keep for distribution update.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using Cross Entropy Method.
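
For intuition, the core CEM iteration (sample from a Gaussian, keep the topk lowest-cost elites, refit the Gaussian to them) can be sketched in a few lines. This is an illustrative standalone loop, not the solver's actual implementation:

```python
import torch


def cem_step(mean, std, cost_fn, num_samples=300, topk=30):
    """One CEM iteration: sample, score, refit the Gaussian to the elites.

    mean, std: (H, D) parameters of the sampling distribution.
    cost_fn:   maps samples of shape (S, H, D) to costs of shape (S,).
    """
    samples = mean + std * torch.randn(num_samples, *mean.shape)
    costs = cost_fn(samples)                              # (S,)
    elite_idx = costs.topk(topk, largest=False).indices   # lowest-cost samples
    elites = samples[elite_idx]                           # (topk, H, D)
    return elites.mean(dim=0), elites.std(dim=0)


# Iterating cem_step for n_steps concentrates the distribution around
# low-cost action sequences; the final mean is returned as the plan.
```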

ICEMSolver

ICEMSolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1,
    n_steps: int = 30,
    topk: int = 30,
    noise_beta: float = 2.0,
    alpha: float = 0.1,
    n_elite_keep: int = 5,
    return_mean: bool = True,
    device: str | device = 'cpu',
    seed: int = 1234,
)

Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention. iCEM improves sample efficiency over standard CEM and was introduced in [1] for real-time planning.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample per iteration.

  • var_scale (float, default: 1 ) –

    Initial variance scale for the action distribution.

  • n_steps (int, default: 30 ) –

    Number of CEM iterations.

  • topk (int, default: 30 ) –

    Number of elite samples to keep for distribution update.

  • noise_beta (float, default: 2.0 ) –

    Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.

  • alpha (float, default: 0.1 ) –

    Momentum for mean/std EMA update.

  • n_elite_keep (int, default: 5 ) –

    Number of elites carried from previous iteration.

  • return_mean (bool, default: True ) –

    If False, return best single trajectory instead of mean.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek and G. Martius. "Sample-efficient Cross-Entropy Method for Real-time Planning". Conference on Robot Learning, 2020.
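
The effect of noise_beta can be illustrated with a standalone power-law noise sampler, the same idea iCEM uses for its colored-noise proposals. This is a sketch, not the library's code:

```python
import torch


def colored_noise(beta: float, size: tuple, horizon: int) -> torch.Tensor:
    """Gaussian noise with power spectral density ~ 1/f**beta over the horizon.

    Returns a tensor of shape (*size, horizon). beta=0 reproduces white noise
    (standard CEM); larger beta concentrates power at low frequencies, yielding
    smoother action sequences. Illustrative sketch of the iCEM sampling idea.
    """
    freqs = torch.fft.rfftfreq(horizon)
    freqs[0] = freqs[1]  # avoid the zero frequency blowing up f**(-beta/2)
    scale = freqs ** (-beta / 2)
    # Random complex spectrum with power-law amplitudes, then back to time domain.
    spectrum = scale * (
        torch.randn(*size, len(freqs)) + 1j * torch.randn(*size, len(freqs))
    )
    noise = torch.fft.irfft(spectrum, n=horizon, dim=-1)
    return noise / noise.std(dim=-1, keepdim=True)  # unit variance per sequence
```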

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using improved Cross Entropy Method.

MPPISolver

MPPISolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1.0,
    n_steps: int = 30,
    topk: int = 30,
    temperature: float = 0.5,
    device: str | device = 'cpu',
    seed: int = 1234,
)

Model Predictive Path Integral solver for action optimization.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample per iteration.

  • var_scale (float, default: 1.0 ) –

    Initial variance scale for action noise.

  • n_steps (int, default: 30 ) –

    Number of MPPI iterations.

  • topk (int, default: 30 ) –

    Number of elite samples for weighted averaging.

  • temperature (float, default: 0.5 ) –

    Temperature parameter for softmax weighting.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using MPPI.
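
The distinguishing step of MPPI is the exponentiated-cost weighting: instead of averaging elites uniformly, candidates are combined with softmax weights controlled by temperature. A standalone sketch of that update (not the solver's actual code):

```python
import torch


def mppi_update(samples: torch.Tensor, costs: torch.Tensor,
                temperature: float = 0.5, topk: int = 30) -> torch.Tensor:
    """Combine action candidates by exponentiated-cost weighting.

    samples: (S, H, D) candidate action sequences, costs: (S,).
    temperature -> 0 approaches picking the single best sample; a large
    temperature approaches a plain average over the topk elites.
    """
    elite_costs, elite_idx = costs.topk(topk, largest=False)
    # Subtract the minimum cost for numerical stability before the softmax.
    weights = torch.softmax(-(elite_costs - elite_costs.min()) / temperature, dim=0)
    return (weights[:, None, None] * samples[elite_idx]).sum(dim=0)
```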

GradientSolver

GradientSolver(
    model: Costable,
    n_steps: int,
    batch_size: int | None = None,
    var_scale: float = 1,
    num_samples: int = 1,
    action_noise: float = 0.0,
    device: str | device = 'cpu',
    seed: int = 1234,
    optimizer_cls: type[Optimizer] = SGD,
    optimizer_kwargs: dict | None = None,
)

Bases: Module

Gradient-based solver using backpropagation through the world model.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • n_steps (int) –

    Number of gradient descent iterations.

  • batch_size (int | None, default: None ) –

    Number of environments to process in parallel.

  • var_scale (float, default: 1 ) –

    Initial variance scale for action perturbations.

  • num_samples (int, default: 1 ) –

    Number of action samples to optimize in parallel.

  • action_noise (float, default: 0.0 ) –

    Noise added to actions during optimization.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

  • optimizer_cls (type[Optimizer], default: SGD ) –

    PyTorch optimizer class to use.

  • optimizer_kwargs (dict | None, default: None ) –

    Keyword arguments for the optimizer.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using gradient descent.
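
The GradientSolver's inner loop is ordinary gradient descent on the action sequence, backpropagating through a differentiable cost. The standalone sketch below mirrors that pattern with a plain cost function standing in for the world model's get_cost(); it is illustrative, not the solver's implementation:

```python
import torch


def gradient_plan(cost_fn, init_action, n_steps=30, lr=0.05, action_noise=0.0):
    """Optimize an action sequence by descending a differentiable cost.

    init_action: (B, H, D) starting plan; cost_fn maps (B, H, D) -> (B,).
    """
    actions = init_action.clone().requires_grad_(True)
    opt = torch.optim.SGD([actions], lr=lr)  # any torch optimizer works here
    for _ in range(n_steps):
        opt.zero_grad()
        evaluated = actions
        if action_noise > 0:
            # Exploration noise on the evaluated (not the stored) actions.
            evaluated = actions + action_noise * torch.randn_like(actions)
        cost_fn(evaluated).sum().backward()
        opt.step()
    return actions.detach()
```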

PGDSolver

PGDSolver(
    model: Costable,
    n_steps: int,
    batch_size: int | None = None,
    var_scale: float = 1,
    num_samples: int = 1,
    action_noise: float = 0.0,
    device: str | device = 'cpu',
    seed: int = 1234,
)

Bases: Module

Projected Gradient Descent solver for discrete action optimization.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • n_steps (int) –

    Number of gradient descent iterations.

  • batch_size (int | None, default: None ) –

    Number of environments to process in parallel.

  • var_scale (float, default: 1 ) –

    Initial variance scale for action perturbations.

  • num_samples (int, default: 1 ) –

    Number of action samples to optimize in parallel.

  • action_noise (float, default: 0.0 ) –

    Noise added to actions during optimization.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict,
    init_action: Tensor | None = None,
    from_scalar: bool = False,
) -> dict

Solve the planning problem using projected gradient descent.
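
Projected gradient descent alternates a gradient step with a projection back onto the feasible set. The sketch below uses nearest-value snapping onto a finite set as a stand-in projection for a discrete action space; the library's actual projection may differ:

```python
import torch


def pgd_step(actions, cost_fn, lr=0.1, valid_values=None):
    """One projected-gradient step: descend the cost, then project.

    valid_values: 1-D tensor of allowed action values (a stand-in for a
    discrete action set). Illustrative sketch, not the library's code.
    """
    actions = actions.detach().requires_grad_(True)
    cost_fn(actions).sum().backward()
    with torch.no_grad():
        stepped = actions - lr * actions.grad
        if valid_values is not None:
            # Project each entry onto its nearest allowed discrete value.
            dist = (stepped.unsqueeze(-1) - valid_values).abs()
            stepped = valid_values[dist.argmin(dim=-1)]
    return stepped
```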

LagrangianSolver

LagrangianSolver(
    model: Costable,
    n_steps: int,
    n_outer_steps: int = 5,
    batch_size: int | None = None,
    num_samples: int = 1,
    var_scale: float = 1.0,
    action_noise: float = 0.0,
    rho_init: float = 1.0,
    rho_max: float = 10000.0,
    rho_scale: float = 2.0,
    persist_multipliers: bool = True,
    device: str | device = 'cpu',
    seed: int = 1234,
    optimizer_cls: type[Optimizer] = Adam,
    optimizer_kwargs: dict | None = None,
)

Bases: Module

Augmented Lagrangian solver for constrained action optimization.

The model's get_cost returns the cost tensor of shape (B, S). If the model also implements get_constraints, that method must return the constraint values of shape (B, S, C), where C is the number of constraints; constraint i is satisfied when g_i <= 0. The solver minimizes the augmented Lagrangian objective:

L = cost + sum_{i=1}^C lambda_i * g_i + rho * sum_{i=1}^C max(0, g_i)^2

To enforce an equality constraint, convert it into two inequalities: for g_i == 0, add the constraints g_i <= 0 and -g_i <= 0.
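
The objective can be evaluated directly from the tensors described above. A sketch (shapes follow the protocol; this is illustrative, not the solver's internal code):

```python
import torch


def augmented_lagrangian(cost, g, lam, rho):
    """Evaluate the objective the solver minimizes over actions.

    cost: (B, S) task cost, g: (B, S, C) constraint values (feasible when <= 0),
    lam:  (B, C) Lagrange multipliers, rho: scalar penalty coefficient.
    """
    violation = g.clamp(min=0)                        # max(0, g_i)
    linear = (lam.unsqueeze(1) * g).sum(dim=-1)       # sum_i lambda_i * g_i
    quadratic = rho * violation.pow(2).sum(dim=-1)    # rho * sum_i max(0, g_i)^2
    return cost + linear + quadratic                  # (B, S)
```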

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol. Its get_cost() returns a plain cost tensor (B, S). If it also has get_constraints(), that method returns constraints of shape (B, S, C).

  • n_steps (int) –

    Number of gradient descent steps per outer iteration.

  • n_outer_steps (int, default: 5 ) –

    Number of dual ascent (outer) iterations.

  • batch_size (int | None, default: None ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 1 ) –

    Number of action samples to optimize in parallel.

  • var_scale (float, default: 1.0 ) –

    Initial variance scale for action perturbations.

  • action_noise (float, default: 0.0 ) –

    Noise added to actions during optimization.

  • rho_init (float, default: 1.0 ) –

    Initial penalty coefficient for the quadratic constraint term.

  • rho_max (float, default: 10000.0 ) –

    Maximum value of the penalty coefficient.

  • rho_scale (float, default: 2.0 ) –

    Multiplicative growth factor for rho after each outer step.

  • persist_multipliers (bool, default: True ) –

    Whether to warm-start Lagrange multipliers across solve() calls.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

  • optimizer_cls (type[Optimizer], default: Adam ) –

    PyTorch optimizer class to use.

  • optimizer_kwargs (dict | None, default: None ) –

    Keyword arguments for the optimizer.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using augmented Lagrangian gradient descent.

[ Example: Constrained Planning with LagrangianSolver ]

The LagrangianSolver extends gradient-based planning to handle inequality constraints of the form g(a) ≤ 0. It uses the augmented Lagrangian method: dual variables (λ) are maintained per environment and updated via dual ascent after each inner optimization loop, while a quadratic penalty term (controlled by rho) enforces feasibility.

import dataclasses
import torch
import gymnasium as gym
import numpy as np
from stable_worldmodel.solver import LagrangianSolver
from stable_worldmodel.policy import PlanConfig


# ── 1. Define a world model with cost and optional constraints ──────────────

class MyModel(torch.nn.Module):
    """Minimal example: cost is MSE to a goal; two inequality constraints."""

    def get_cost(self, info_dict, action_candidates):
        # action_candidates: (B, S, H, D)
        # returns:           (B, S)
        goal = torch.zeros(action_candidates.shape[-1])
        return (action_candidates.mean(dim=2) - goal).pow(2).mean(dim=-1)

    def get_constraints(self, info_dict, action_candidates):
        # returns: (B, S, C)  — violated when > 0
        # g0: action L2 norm <= 1
        g0 = action_candidates.norm(dim=-1).mean(dim=2) - 1.0
        # g1: first action dimension <= 0.5
        g1 = action_candidates[..., 0].mean(dim=2) - 0.5
        return torch.stack([g0, g1], dim=-1)


# ── 2. Build and configure the solver ──────────────────────────────────────

model = MyModel()

solver = LagrangianSolver(
    model=model,
    n_steps=30,            # inner gradient steps per outer iteration
    n_outer_steps=10,      # dual-ascent (outer) iterations
    num_samples=8,         # parallel action candidates per env
    rho_init=1.0,          # initial quadratic penalty coefficient
    rho_scale=2.0,         # rho doubles each outer step
    rho_max=1e4,
    persist_multipliers=True,  # warm-start λ across planning calls
    optimizer_kwargs={"lr": 0.05},
)

action_space = gym.spaces.Box(low=-np.inf, high=np.inf,
                              shape=(1, 4), dtype=np.float32)
config = PlanConfig(horizon=10, receding_horizon=1, action_block=1)
solver.configure(action_space=action_space, n_envs=2, config=config)


# ── 3. Solve ────────────────────────────────────────────────────────────────

info_dict = {"obs": torch.zeros(2, 4)}  # current env observations
out = solver.solve(info_dict)

print(out["actions"].shape)        # (2, 10, 4)  — best action per env
print(out["lambdas"])              # (2, 2)       — dual variables
print(out["constraint_violation"]) # mean ReLU(g) across samples


# ── 4. Receding-horizon planning (warm start) ───────────────────────────────

# Execute the first step, shift the plan, re-plan
executed_steps = 1
remaining = out["actions"][:, executed_steps:, :]   # (2, 9, 4)
out2 = solver.solve(info_dict, init_action=remaining)

Key parameters

Parameter            Default   Description
n_steps              required  Inner gradient steps per outer iteration
n_outer_steps        5         Dual-ascent iterations
rho_init             1.0       Initial quadratic penalty weight
rho_scale            2.0       Multiplicative growth for rho each outer step
rho_max              1e4       Upper bound on rho
persist_multipliers  True      Keep λ across solve() calls (warm start)
num_samples          1         Parallel candidate trajectories per environment
action_noise         0.0       Gaussian noise injected each inner step

Constraint protocol

Your model must implement get_constraints(info_dict, action_candidates) -> Tensor returning shape (B, S, C). A constraint is satisfied when its value is ≤ 0.

To enforce an equality h(a) = 0, add two constraints: h(a) ≤ 0 and -h(a) ≤ 0.