Model-based planning solvers for action optimization

[ Base Class ]

Solver

Bases: Protocol

Protocol for model-based planning solvers.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment and planning specifications.

Parameters:

  • action_space (Space) –

    The action space of the environment.

  • n_envs (int) –

    Number of parallel environments.

  • config (Any) –

    Planning configuration object.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning optimization problem to find optimal actions.

Parameters:

  • info_dict (dict) –

    Dictionary containing environment state information.

  • init_action (Tensor | None, default: None ) –

    Optional initial action sequence to warm-start the solver.

Returns:

  • dict

    Dictionary containing optimized actions and other solver-specific info.

action_dim property

action_dim: int

Flattened action dimension including action_block grouping.

n_envs property

n_envs: int

Number of parallel environments being planned for.

horizon property

horizon: int

Planning horizon length in timesteps.
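Any object providing these methods and properties satisfies the protocol. As a purely illustrative sketch (RandomShootingSolver is hypothetical, not part of the library; the `config` and `action_space` objects are assumed to expose `horizon` and `shape` as in the examples further below), a minimal random-shooting solver could look like:

```python
import torch


class RandomShootingSolver:
    """Hypothetical minimal Solver: sample random plans, keep the cheapest.

    Illustrates the configure/solve contract only; not part of the library.
    """

    def __init__(self, model, num_samples=64, seed=0):
        self.model = model
        self.num_samples = num_samples
        self.gen = torch.Generator().manual_seed(seed)

    def configure(self, *, action_space, n_envs, config):
        self._n_envs = n_envs
        self._horizon = config.horizon
        self._action_dim = action_space.shape[-1]

    def solve(self, info_dict, init_action=None):
        B, N = self._n_envs, self.num_samples
        H, D = self._horizon, self._action_dim
        candidates = torch.randn(B, N, H, D, generator=self.gen)
        if init_action is not None:
            # warm start: make the previous plan one of the candidates
            candidates[:, 0, : init_action.shape[1]] = init_action
        costs = self.model.get_cost(info_dict, candidates)  # (B, N)
        best = costs.argmin(dim=1)
        return {
            "actions": candidates[torch.arange(B), best],  # (B, H, D)
            "costs": costs.min(dim=1).values,              # (B,)
        }
```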

[ Implementations ]

CEMSolver

CEMSolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1,
    n_steps: int = 30,
    topk: int = 30,
    device: str | device = 'cpu',
    seed: int = 1234,
    callbacks: list[Callback] | None = None,
)

Cross Entropy Method solver for action optimization.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample per iteration.

  • var_scale (float, default: 1 ) –

    Initial variance scale for the action distribution.

  • n_steps (int, default: 30 ) –

    Number of CEM iterations.

  • topk (int, default: 30 ) –

    Number of elite samples to keep for distribution update.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using Cross Entropy Method.
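Internally, solve repeats a sample → evaluate → refit loop. A simplified sketch of one CEM iteration on a toy cost (illustrative only; the library additionally handles batching, action bounds, warm starts, and callbacks):

```python
import torch

def cem_step(cost_fn, mean, var, num_samples=300, topk=30, gen=None):
    """One schematic CEM iteration over action sequences of shape (H, D)."""
    noise = torch.randn(num_samples, *mean.shape, generator=gen)
    candidates = mean + var.sqrt() * noise               # sample N plans
    costs = cost_fn(candidates)                          # (N,)
    elite_idx = costs.topk(topk, largest=False).indices  # cheapest top-k
    elites = candidates[elite_idx]
    return elites.mean(dim=0), elites.var(dim=0)         # refit the Gaussian

# Toy cost: squared distance of every action to 0.5.
gen = torch.Generator().manual_seed(0)
cost = lambda c: (c - 0.5).pow(2).mean(dim=(1, 2))
mean, var = torch.zeros(5, 2), torch.ones(5, 2)
for _ in range(30):
    mean, var = cem_step(cost, mean, var, gen=gen)
# mean converges toward 0.5 while var shrinks
```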

ICEMSolver

ICEMSolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1,
    n_steps: int = 30,
    topk: int = 30,
    noise_beta: float = 2.0,
    alpha: float = 0.1,
    n_elite_keep: int = 5,
    return_mean: bool = True,
    device: str | device = 'cpu',
    seed: int = 1234,
    callbacks: list[Callback] | None = None,
)

Improved Cross Entropy Method (iCEM) solver with colored noise and elite retention. iCEM improves sample efficiency over standard CEM and was introduced by [1] for real-time planning.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample per iteration.

  • var_scale (float, default: 1 ) –

    Initial variance scale for the action distribution.

  • n_steps (int, default: 30 ) –

    Number of CEM iterations.

  • topk (int, default: 30 ) –

    Number of elite samples to keep for distribution update.

  • noise_beta (float, default: 2.0 ) –

    Colored noise exponent. 0 = white (standard CEM), >0 = more low-frequency noise.

  • alpha (float, default: 0.1 ) –

    Momentum for mean/std EMA update.

  • n_elite_keep (int, default: 5 ) –

    Number of elites carried from previous iteration.

  • return_mean (bool, default: True ) –

    If False, return best single trajectory instead of mean.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

[1] C. Pinneri, S. Sawant, S. Blaes, J. Achterhold, J. Stueckler, M. Rolinek, and G. Martius. "Sample-efficient Cross-Entropy Method for Real-time Planning". Conference on Robot Learning, 2020.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using improved Cross Entropy Method.
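noise_beta shapes the sampling noise in frequency space: the power spectrum falls off as 1/f^beta along the horizon, so beta > 0 yields smoother, temporally correlated action sequences. One way such noise can be generated, shown as an illustration of the idea rather than the library's implementation:

```python
import torch

def colored_noise(beta, num_samples, horizon, dim, gen=None):
    """Sample noise whose power spectrum ~ 1/f^beta along the horizon axis.

    beta=0 recovers white noise (standard CEM); beta>0 favours smooth,
    low-frequency action sequences. Illustrative sketch only.
    """
    freqs = torch.fft.rfftfreq(horizon)        # (H//2 + 1,)
    scale = freqs.clone()
    scale[0] = scale[1]                        # avoid division by zero at DC
    scale = scale ** (-beta / 2.0)
    spec = torch.randn(num_samples, dim, freqs.numel(), 2, generator=gen)
    spec = torch.view_as_complex(spec) * scale  # shape noise in frequency space
    noise = torch.fft.irfft(spec, n=horizon)    # back to time domain: (N, D, H)
    noise = noise / noise.std(dim=-1, keepdim=True)
    return noise.transpose(1, 2)                # (N, H, D)

gen = torch.Generator().manual_seed(0)
smooth = colored_noise(2.0, num_samples=8, horizon=16, dim=3, gen=gen)
```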

MPPISolver

MPPISolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    var_scale: float = 1.0,
    n_steps: int = 30,
    topk: int = 30,
    temperature: float = 0.5,
    device: str | device = 'cpu',
    seed: int = 1234,
)

Model Predictive Path Integral solver for action optimization.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample per iteration.

  • var_scale (float, default: 1.0 ) –

    Initial variance scale for action noise.

  • n_steps (int, default: 30 ) –

    Number of MPPI iterations.

  • topk (int, default: 30 ) –

    Number of elite samples for weighted averaging.

  • temperature (float, default: 0.5 ) –

    Temperature parameter for softmax weighting.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using MPPI.
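The update behind solve weights candidate plans by a softmax of their negative costs, so temperature controls how sharply the average concentrates on the cheapest plans. A schematic sketch of this weighting (illustrative, not the library internals):

```python
import torch

def mppi_update(candidates, costs, temperature=0.5, topk=30):
    """Schematic MPPI mean update for candidates (N, H, D) with costs (N,)."""
    vals, idx = costs.topk(topk, largest=False)          # keep the elite set
    weights = torch.softmax(-vals / temperature, dim=0)  # cheaper -> heavier
    return (weights[:, None, None] * candidates[idx]).sum(dim=0)  # (H, D)

torch.manual_seed(0)
cands = torch.randn(300, 10, 2)
costs = cands.pow(2).mean(dim=(1, 2))
new_mean = mppi_update(cands, costs)
```

Lower temperature approaches a hard argmin over the elites; higher temperature averages them more uniformly.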

PredictiveSamplingSolver

PredictiveSamplingSolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    noise_scale: float = 1.0,
    device: str | device = 'cpu',
    seed: int = 1234,
)

Predictive Sampling solver for action optimization.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample.

  • noise_scale (float, default: 1.0 ) –

    Standard deviation of additive Gaussian noise.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using Predictive Sampling.
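Predictive sampling is a single round of sample-and-select around a nominal plan: add Gaussian noise, evaluate, keep the best (retaining the unperturbed nominal as a candidate so the plan never regresses). A schematic sketch:

```python
import torch

def predictive_sampling(cost_fn, nominal, num_samples=300,
                        noise_scale=1.0, gen=None):
    """Schematic predictive sampling around a nominal plan of shape (H, D)."""
    noise = noise_scale * torch.randn(num_samples, *nominal.shape, generator=gen)
    candidates = torch.cat([nominal[None], nominal[None] + noise])  # keep nominal
    costs = cost_fn(candidates)                                     # (N + 1,)
    return candidates[costs.argmin()]                               # best (H, D)

gen = torch.Generator().manual_seed(0)
cost = lambda c: (c - 0.25).pow(2).mean(dim=(1, 2))
best = predictive_sampling(cost, torch.zeros(6, 2), gen=gen)
```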

GradientSolver

GradientSolver(
    model: Costable,
    n_steps: int,
    batch_size: int | None = None,
    var_scale: float = 1,
    num_samples: int = 1,
    action_noise: float = 0.0,
    device: str | device = 'cpu',
    seed: int = 1234,
    optimizer_cls: type[Optimizer] = SGD,
    optimizer_kwargs: dict | None = None,
    grad_clip: float | None = None,
    callbacks: list[Callback] | None = None,
)

Bases: Module

Gradient-based solver using backpropagation through the world model.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • n_steps (int) –

    Number of gradient descent iterations.

  • batch_size (int | None, default: None ) –

    Number of environments to process in parallel.

  • var_scale (float, default: 1 ) –

    Initial variance scale for action perturbations.

  • num_samples (int, default: 1 ) –

    Number of action samples to optimize in parallel.

  • action_noise (float, default: 0.0 ) –

    Noise added to actions during optimization.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

  • optimizer_cls (type[Optimizer], default: SGD ) –

    PyTorch optimizer class to use.

  • optimizer_kwargs (dict | None, default: None ) –

    Keyword arguments for the optimizer.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using gradient descent.
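At its core, solve treats the action sequence as a learnable tensor and runs a standard optimizer loop on the model's cost. A simplified sketch, assuming a differentiable cost (the library additionally supports multiple samples, action noise, and callbacks):

```python
import torch

def gradient_plan(cost_fn, init, n_steps=100, lr=0.1, grad_clip=None):
    """Schematic gradient-based planning: backprop the cost into the actions."""
    actions = init.clone().requires_grad_(True)
    opt = torch.optim.SGD([actions], lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        cost_fn(actions).backward()    # d cost / d actions through the model
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_([actions], grad_clip)
        opt.step()
    return actions.detach()

# Quadratic toy cost: the plan converges to 0.5 everywhere.
plan = gradient_plan(lambda a: (a - 0.5).pow(2).sum(), torch.zeros(5, 2))
```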

PGDSolver

PGDSolver(
    model: Costable,
    n_steps: int,
    batch_size: int | None = None,
    var_scale: float = 1,
    num_samples: int = 1,
    action_noise: float = 0.0,
    device: str | device = 'cpu',
    seed: int = 1234,
)

Bases: Module

Projected Gradient Descent solver for discrete action optimization.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • n_steps (int) –

    Number of gradient descent iterations.

  • batch_size (int | None, default: None ) –

    Number of environments to process in parallel.

  • var_scale (float, default: 1 ) –

    Initial variance scale for action perturbations.

  • num_samples (int, default: 1 ) –

    Number of action samples to optimize in parallel.

  • action_noise (float, default: 0.0 ) –

    Noise added to actions during optimization.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict,
    init_action: Tensor | None = None,
    from_scalar: bool = False,
) -> dict

Solve the planning problem using projected gradient descent.
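The "projected" part keeps relaxed discrete actions on the probability simplex after each gradient step. Below is the classic sort-based Euclidean projection, shown for illustration (not necessarily the library's exact projection):

```python
import torch

def project_simplex(x):
    """Project each row of x onto the probability simplex (sum 1, nonneg)."""
    u, _ = x.sort(dim=-1, descending=True)
    css = u.cumsum(dim=-1) - 1.0
    k = torch.arange(1, x.shape[-1] + 1, dtype=x.dtype, device=x.device)
    rho = (u * k > css).sum(dim=-1, keepdim=True)  # active support size
    theta = css.gather(-1, rho - 1) / rho.to(x.dtype)
    return (x - theta).clamp_min(0.0)

# PGD alternates a gradient step on the relaxed one-hot actions with this
# projection, so every iterate stays a valid categorical distribution.
p = project_simplex(torch.tensor([[0.2, 0.3, 0.1]]))
```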

CategoricalCEMSolver

CategoricalCEMSolver(
    model: Costable,
    batch_size: int = 1,
    num_samples: int = 300,
    n_steps: int = 30,
    topk: int = 30,
    smoothing: float = 0.0,
    alpha: float = 0.0,
    device: str | device = 'cpu',
    seed: int = 1234,
    callbacks: list[Callback] | None = None,
)

Cross Entropy Method solver for discrete action optimization.

Maintains a per-timestep categorical distribution over discrete actions, samples candidate trajectories via Gumbel-max, and refits the distribution from the top-K elites' empirical frequencies.

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol.

  • batch_size (int, default: 1 ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 300 ) –

    Number of action candidates to sample per iteration.

  • n_steps (int, default: 30 ) –

    Number of CEM iterations.

  • topk (int, default: 30 ) –

    Number of elite samples to keep for distribution update.

  • smoothing (float, default: 0.0 ) –

    Laplace smoothing added to refit probs to avoid collapse.

  • alpha (float, default: 0.0 ) –

    Momentum for probs EMA update (0 = full overwrite).

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

  • callbacks (list[Callback] | None, default: None ) –

    Optional list of callbacks.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(info_dict: dict, init_action: Any = None) -> dict

Solve the planning problem using Categorical CEM.

init_action is accepted for API parity with other solvers but is ignored; probs are always initialized to a uniform distribution.

LagrangianSolver

LagrangianSolver(
    model: Costable,
    n_steps: int,
    n_outer_steps: int = 5,
    batch_size: int | None = None,
    num_samples: int = 1,
    var_scale: float = 1.0,
    action_noise: float = 0.0,
    rho_init: float = 1.0,
    rho_max: float = 10000.0,
    rho_scale: float = 2.0,
    persist_multipliers: bool = True,
    device: str | device = 'cpu',
    seed: int = 1234,
    optimizer_cls: type[Optimizer] = Adam,
    optimizer_kwargs: dict | None = None,
)

Bases: Module

Augmented Lagrangian solver for constrained planning with a world model.

get_cost returns the cost tensor (B, S). If the model also implements get_constraints, it should return the constraint values (B, S, C), where C is the number of constraints. Each constraint_cost represents the cost of violating a constraint; a constraint is satisfied when constraint_cost <= 0. The solver minimizes the following augmented Lagrangian objective:

L = cost + sum_{i=1}^C lambda_i * constraint_cost_i + rho * sum_{i=1}^C max(0, constraint_cost_i)^2

To use an equality constraint, convert it to two inequality constraints: to enforce constraint_cost_i == 0, add both constraint_cost_i <= 0 and -constraint_cost_i <= 0.
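Concretely, the objective above is minimized in an inner gradient loop while the multipliers follow dual ascent in an outer loop. A toy 1-D sketch of the scheme (the dual update lambda <- max(0, lambda + rho * g) is the standard augmented-Lagrangian rule; this is an illustration, not the library's exact code):

```python
import torch

def aug_lagrangian(cost, g, lam, rho):
    """L = cost + sum_i lam_i * g_i + rho * sum_i max(0, g_i)^2."""
    return cost + (lam * g).sum() + rho * g.clamp_min(0.0).pow(2).sum()

# Toy problem: minimize a^2 subject to a >= 1, i.e. g(a) = 1 - a <= 0.
a = torch.tensor([2.0], requires_grad=True)
lam, rho = torch.zeros(1), 1.0
opt = torch.optim.SGD([a], lr=0.01)
for outer in range(6):                  # n_outer_steps
    for _ in range(300):                # n_steps inner gradient iterations
        opt.zero_grad()
        aug_lagrangian(a.pow(2).sum(), 1.0 - a, lam, rho).backward()
        opt.step()
    with torch.no_grad():               # dual ascent + penalty growth
        lam = (lam + rho * (1.0 - a)).clamp_min(0.0)
        rho = min(rho * 2.0, 1e4)       # rho_scale, rho_max
# a converges to the constrained optimum a = 1
```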

Parameters:

  • model (Costable) –

    World model implementing the Costable protocol. Its get_cost() returns a plain cost tensor (B, S). If it also has get_constraints(), that method returns constraints of shape (B, S, C).

  • n_steps (int) –

    Number of gradient descent steps per outer iteration.

  • n_outer_steps (int, default: 5 ) –

    Number of dual ascent (outer) iterations.

  • batch_size (int | None, default: None ) –

    Number of environments to process in parallel.

  • num_samples (int, default: 1 ) –

    Number of action samples to optimize in parallel.

  • var_scale (float, default: 1.0 ) –

    Initial variance scale for action perturbations.

  • action_noise (float, default: 0.0 ) –

    Noise added to actions during optimization.

  • rho_init (float, default: 1.0 ) –

    Initial penalty coefficient for the quadratic constraint term.

  • rho_max (float, default: 10000.0 ) –

    Maximum value of the penalty coefficient.

  • rho_scale (float, default: 2.0 ) –

    Multiplicative growth factor for rho after each outer step.

  • persist_multipliers (bool, default: True ) –

    Whether to warm-start Lagrange multipliers across solve() calls.

  • device (str | device, default: 'cpu' ) –

    Device for tensor computations.

  • seed (int, default: 1234 ) –

    Random seed for reproducibility.

  • optimizer_cls (type[Optimizer], default: Adam ) –

    PyTorch optimizer class to use.

  • optimizer_kwargs (dict | None, default: None ) –

    Keyword arguments for the optimizer.

configure

configure(
    *, action_space: Space, n_envs: int, config: Any
) -> None

Configure the solver with environment specifications.

solve

solve(
    info_dict: dict, init_action: Tensor | None = None
) -> dict

Solve the planning problem using augmented Lagrangian gradient descent.

[ Callbacks ]

Solvers accept a callbacks=[...] list of Callback objects. Each callback fires once per inner-loop step and accumulates a per-batch buffer; final histories are returned in outputs['callbacks'], keyed by cb.output_key (defaults to the class name).

from stable_worldmodel.solver import GradientSolver
from stable_worldmodel.solver.callbacks import (
    BestCostRecorder, GradNormRecorder, ActionNormRecorder,
)

solver = GradientSolver(
    model=model, n_steps=20, num_samples=8,
    callbacks=[
        BestCostRecorder(),                # mean over envs (default)
        GradNormRecorder(reduction='none'), # one entry per env
        ActionNormRecorder(reduction='sum'),
    ],
)
solver.configure(action_space=action_space, n_envs=4, config=config)
out = solver.solve(info_dict)

# out['callbacks']['BestCostRecorder']  -> list[list[float]]   (batches x steps)
# out['callbacks']['GradNormRecorder']  -> list[list[list[float]]]

Reduction modes

Every callback accepts reduction ∈ {'mean', 'sum', 'none'}. Reduction is applied across the env axis only; within-sample reductions (e.g. min over samples for BestCostRecorder) are intrinsic to each metric.

Mode     Output per step
'mean'   scalar (default)
'sum'    scalar
'none'   list[float] — one value per env in batch

Available callbacks

Callback             Solver(s)   Records
BestCostRecorder     any         min cost over samples
MeanCostRecorder     any         mean cost over samples
GradNormRecorder     GD          L2 norm of action gradient (optional per_step for per-horizon-step values)
ActionNormRecorder   GD          L2 norm of action tensor
EliteCostRecorder    CEM, iCEM   dict of elite cost stats (mean/min/max)
VarNormRecorder      CEM, iCEM   mean variance of action distribution
MeanShiftRecorder    CEM, iCEM   L2 distance between consecutive means
EliteSpreadRecorder  CEM, iCEM   within-elite std (top-k diversity)

Writing a custom callback

Subclass Callback and implement compute(**state). Pull the tensors you need from state and call self._reduce(per_env_tensor) to honour the reduction mode.

from stable_worldmodel.solver.callbacks import Callback

class CostRangeRecorder(Callback):
    """Records per-env (max - min) cost across the sample population."""

    def compute(self, **state):
        costs = state['costs'].detach()           # (B, N)
        per_env = costs.max(dim=1).values - costs.min(dim=1).values
        return self._reduce(per_env)

State keys passed by each solver:

  • GD: step, params, cost, costs
  • CEM: step, candidates, costs, topk_vals, topk_inds, topk_candidates, mean, var, prev_mean, prev_var
  • iCEM: same as CEM plus action_low, action_high
  • CategoricalCEM: step, candidates, costs, topk_vals, topk_inds, topk_candidates, probs, prev_probs

Callback

Callback(reduction: Reduction = 'mean')

Base class for solver iteration callbacks.

Subclasses compute a per-env metric and call self._reduce(...) to apply the configured reduction across envs. history is list[list[Any]] (batches x steps), matching the shape of outputs['cost'] in the gradient solver.


output_key property

output_key: str

reset

reset() -> None

start_batch

start_batch() -> None

end_solve

end_solve() -> None

compute

compute(**state: Any) -> Any

BestCostRecorder

BestCostRecorder(reduction: Reduction = 'mean')

Bases: Callback

Per-step minimum cost over the sample population (per env).

MeanCostRecorder

MeanCostRecorder(reduction: Reduction = 'mean')

Bases: Callback

Per-step mean cost over the sample population (per env).

GradNormRecorder

GradNormRecorder(
    reduction: Reduction = 'mean', per_step: bool = False
)

Bases: Callback

Per-step L2 norm of the action gradient (per env, mean over samples).

With per_step=True, returns a list of length H with one grad norm per horizon step instead of reducing over the horizon dim.

ActionNormRecorder

ActionNormRecorder(reduction: Reduction = 'mean')

Bases: Callback

Per-step L2 norm of the action tensor (per env, mean over samples).

EliteCostRecorder

EliteCostRecorder(reduction: Reduction = 'mean')

Bases: Callback

Per-step elite cost stats (mean, min, max), per env.

VarNormRecorder

VarNormRecorder(reduction: Reduction = 'mean')

Bases: Callback

Per-step mean variance of the action distribution (per env).

MeanShiftRecorder

MeanShiftRecorder(reduction: Reduction = 'mean')

Bases: Callback

Per-step L2 distance between consecutive distribution means (per env).

EliteSpreadRecorder

EliteSpreadRecorder(reduction: Reduction = 'mean')

Bases: Callback

Per-step within-elite std (diversity of the top-k elites, per env).

[ Example: Constrained Planning with LagrangianSolver ]

The LagrangianSolver extends gradient-based planning to handle inequality constraints of the form g(a) ≤ 0. It uses the augmented Lagrangian method: dual variables (λ) are maintained per environment and updated via dual ascent after each inner optimization loop, while a quadratic penalty term (controlled by rho) enforces feasibility.

import dataclasses
import torch
import gymnasium as gym
import numpy as np
from stable_worldmodel.solver import LagrangianSolver
from stable_worldmodel.policy import PlanConfig


# ── 1. Define a world model with cost and optional constraints ──────────────

class MyModel(torch.nn.Module):
    """Minimal example: cost is MSE to a goal; two inequality constraints."""

    def get_cost(self, info_dict, action_candidates):
        # action_candidates: (B, S, H, D)
        # returns:           (B, S)
        goal = torch.zeros(action_candidates.shape[-1])
        return (action_candidates.mean(dim=2) - goal).pow(2).mean(dim=-1)

    def get_constraints(self, info_dict, action_candidates):
        # returns: (B, S, C)  — violated when > 0
        # g0: action L2 norm <= 1
        g0 = action_candidates.norm(dim=-1).mean(dim=2) - 1.0
        # g1: first action dimension <= 0.5
        g1 = action_candidates[..., 0].mean(dim=2) - 0.5
        return torch.stack([g0, g1], dim=-1)


# ── 2. Build and configure the solver ──────────────────────────────────────

model = MyModel()

solver = LagrangianSolver(
    model=model,
    n_steps=30,            # inner gradient steps per outer iteration
    n_outer_steps=10,      # dual-ascent (outer) iterations
    num_samples=8,         # parallel action candidates per env
    rho_init=1.0,          # initial quadratic penalty coefficient
    rho_scale=2.0,         # rho doubles each outer step
    rho_max=1e4,
    persist_multipliers=True,  # warm-start λ across planning calls
    optimizer_kwargs={"lr": 0.05},
)

action_space = gym.spaces.Box(low=-np.inf, high=np.inf,
                              shape=(1, 4), dtype=np.float32)
config = PlanConfig(horizon=10, receding_horizon=1, action_block=1)
solver.configure(action_space=action_space, n_envs=2, config=config)


# ── 3. Solve ────────────────────────────────────────────────────────────────

info_dict = {"obs": torch.zeros(2, 4)}  # current env observations
out = solver.solve(info_dict)

print(out["actions"].shape)        # (2, 10, 4)  — best action per env
print(out["lambdas"])              # (2, 2)       — dual variables
print(out["constraint_violation"]) # mean ReLU(g) across samples


# ── 4. Receding-horizon planning (warm start) ───────────────────────────────

# Execute the first step, shift the plan, re-plan
executed_steps = 1
remaining = out["actions"][:, executed_steps:, :]   # (2, 9, 4)
out2 = solver.solve(info_dict, init_action=remaining)

Key parameters

Parameter            Default     Description
n_steps              (required)  Inner gradient steps per outer iteration
n_outer_steps        5           Dual-ascent iterations
rho_init             1.0         Initial quadratic penalty weight
rho_scale            2.0         Multiplicative growth for rho each outer step
rho_max              1e4         Upper bound on rho
persist_multipliers  True        Keep λ across solve() calls (warm start)
num_samples          1           Parallel candidate trajectories per environment
action_noise         0.0         Gaussian noise injected each inner step

Constraint protocol

Your model must implement get_constraints(info_dict, action_candidates) -> Tensor returning shape (B, S, C). A constraint is satisfied when its value is ≤ 0.

To enforce an equality h(a) = 0, add two constraints: h(a) ≤ 0 and -h(a) ≤ 0.
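For example, a hypothetical helper (not part of the library) that stacks both signs of an equality residual into the (B, S, C) constraint layout:

```python
import torch

def equality_as_inequalities(h):
    """Encode h(a) = 0 as the pair h(a) <= 0 and -h(a) <= 0.

    h: (B, S) equality residual -> (B, S, 2) constraint tensor, satisfied
    (both entries <= 0) exactly when h == 0. Hypothetical helper.
    """
    return torch.stack([h, -h], dim=-1)
```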

[ Example: Discrete Planning with CategoricalCEMSolver ]

CategoricalCEMSolver is the discrete-action analogue of CEMSolver. Instead of fitting a Gaussian per timestep, it maintains a categorical distribution over the Discrete(K) action space and refits it from the empirical frequencies of top-K elite trajectories. Sampling uses the Gumbel-max trick (seeded via the solver's torch.Generator) and candidates are passed to model.get_cost as one-hot tensors — the same layout used by PGDSolver, so discrete world models work unchanged.
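The Gumbel-max trick mentioned above draws exact categorical samples by perturbing log-probabilities with Gumbel noise and taking the argmax. A standalone sketch of the trick (the library seeds this through its own torch.Generator):

```python
import torch

def gumbel_max_sample(probs, num_samples, gen=None):
    """Draw categorical samples from per-timestep probs of shape (H, K).

    Returns integer indices of shape (num_samples, H).
    """
    logits = probs.clamp_min(1e-12).log()
    u = torch.rand(num_samples, *probs.shape, generator=gen)
    gumbel = -(-u.clamp_min(1e-12).log()).log()  # Gumbel(0, 1) noise
    return (logits + gumbel).argmax(dim=-1)      # argmax == exact sample

gen = torch.Generator().manual_seed(0)
samples = gumbel_max_sample(torch.full((8, 4), 0.25), 128, gen=gen)
```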

import torch
import gymnasium as gym
from stable_worldmodel.solver import CategoricalCEMSolver
from stable_worldmodel.policy import PlanConfig


# ── 1. World model: cost defined over one-hot candidates ────────────────────

class DiscreteModel(torch.nn.Module):
    """Cost is minimized by selecting category 2 at every position."""

    def get_cost(self, info_dict, action_candidates):
        # action_candidates: (B, N, H, action_block * K) one-hot floats
        # returns:          (B, N)
        K = 4
        ab = action_candidates.shape[-1] // K
        c = action_candidates.reshape(*action_candidates.shape[:-1], ab, K)
        return -c[..., 2].sum(dim=(-1, -2))


# ── 2. Build and configure the solver ──────────────────────────────────────

solver = CategoricalCEMSolver(
    model=DiscreteModel(),
    n_steps=20,        # CEM iterations
    num_samples=128,   # candidates per iteration
    topk=16,           # elite count
    smoothing=0.01,    # Laplace floor — prevents premature collapse
    alpha=0.1,         # EMA momentum on probs (0 = full overwrite)
    seed=0,
)

action_space = gym.spaces.Discrete(4)
config = PlanConfig(horizon=8, receding_horizon=4, action_block=1)
solver.configure(action_space=action_space, n_envs=2, config=config)


# ── 3. Solve ───────────────────────────────────────────────────────────────

info_dict = {"obs": torch.zeros(2, 4)}
out = solver.solve(info_dict)

print(out["actions"].shape)     # (2, 8, 1)  — discrete indices, argmax of probs
print(out["probs"][0].shape)    # (2, 8, 1, 4) — final categorical distribution
print(out["costs"])             # mean elite cost per env

Key parameters

Parameter    Default  Description
num_samples  300      Candidate trajectories sampled per iteration
n_steps      30       CEM iterations
topk         30       Elite count for refit
smoothing    0.0      Laplace smoothing on refit probs (avoids collapse)
alpha        0.0      EMA momentum: probs ← α · prev + (1−α) · new
batch_size   1        Envs processed per outer batch

Output layout

Key        Shape                                 Meaning
actions    (n_envs, horizon, action_block)       argmax of final probs (int64)
probs      [(n_envs, horizon, action_block, K)]  final categorical distribution
costs      list[float] of length n_envs          mean elite cost on the last iteration
callbacks  dict[str, list[list[Any]]]            per-callback history (if any)

Choosing between PGDSolver and CategoricalCEMSolver

Both target Discrete(K) action spaces.

  • PGDSolver does projected gradient descent on simplex-valued action variables. Requires a differentiable model.get_cost and benefits from smooth cost landscapes.
  • CategoricalCEMSolver is gradient-free. Use when the cost is non-differentiable (discrete simulators, ranking losses, learned classifiers used as oracles) or when PGD gets stuck in local minima.