Wrapper
Custom environment wrappers
MegaWrapper
MegaWrapper(
env: Env,
image_shape: tuple[int, int] = (84, 84),
pixels_transform: Callable[[Any], Any] | None = None,
goal_transform: Callable[[Any], Any] | None = None,
required_keys: Iterable[str] | None = None,
separate_goal: bool = True,
image_resample: str | int | None = None,
)
Bases: Wrapper
Combines multiple wrappers for comprehensive environment preprocessing.
Parameters:
- env (Env) – The environment to wrap.
- image_shape (tuple[int, int], default: (84, 84)) – Target (height, width) for all image processing.
- pixels_transform (Callable[[Any], Any] | None, default: None) – Optional transform for rendered pixels.
- goal_transform (Callable[[Any], Any] | None, default: None) – Optional transform for goal images.
- required_keys (Iterable[str] | None, default: None) – Keys that must be present in the info dict.
- separate_goal (bool, default: True) – Whether to handle the goal separately.
- image_resample (str | int | None, default: None) – PIL resample mode used when resizing pixels and goal images. Accepts a PIL constant or one of the strings 'nearest', 'bilinear', 'bicubic', 'lanczos', 'box', 'hamming'. Defaults to bilinear. Use 'nearest' for crisp pixel art.
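Because image_resample accepts either a string or a PIL constant, the value has to be normalized before resizing. The sketch below shows one way to do that; `resolve_resample` is a hypothetical helper, not part of the library, and the integer codes are hard-coded copies of Pillow's `Image.Resampling` enum values so the snippet runs without Pillow installed.

```python
from typing import Union

# Integer codes of Pillow's Image.Resampling enum.
_PIL_RESAMPLE = {
    "nearest": 0,
    "lanczos": 1,
    "bilinear": 2,
    "bicubic": 3,
    "box": 4,
    "hamming": 5,
}


def resolve_resample(mode: Union[str, int, None]) -> int:
    """Map an image_resample value to a Pillow resample code (hypothetical helper)."""
    if mode is None:
        return _PIL_RESAMPLE["bilinear"]  # documented default
    if isinstance(mode, str):
        key = mode.lower()
        if key not in _PIL_RESAMPLE:
            raise ValueError(f"unknown resample mode: {mode!r}")
        return _PIL_RESAMPLE[key]
    if mode in _PIL_RESAMPLE.values():
        return int(mode)  # already a PIL constant
    raise ValueError(f"unknown resample code: {mode!r}")
```

Failing loudly on unknown names avoids silently falling back to bilinear when a mode string is misspelled.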
EnsureInfoKeysWrapper
EnsureImageShape
EnsureGoalInfoWrapper
EverythingToInfoWrapper
EverythingToInfoWrapper(env: Env)
Bases: Wrapper
Moves all transition information into the info dict.
Parameters:
- env (Env) – The environment to wrap.
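As a rough illustration of "moving all transition information into the info dict", the sketch below wraps any Gymnasium-style env (`reset() -> (obs, info)`, `step(a) -> (obs, reward, terminated, truncated, info)`) and copies every element of the transition into info. It is a hypothetical stand-in, not the library's code; the key names `observation`, `action`, `reward`, `terminated` and `truncated` are assumptions.

```python
from typing import Any


class EverythingToInfoSketch:
    """Hypothetical stand-in for EverythingToInfoWrapper (assumed key names)."""

    def __init__(self, env: Any) -> None:
        self.env = env

    def reset(self, **kwargs: Any):
        obs, info = self.env.reset(**kwargs)
        return obs, dict(info, observation=obs)

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Build a new info dict that also carries the full transition.
        info = dict(
            info,
            observation=obs,
            action=action,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
        )
        return obs, reward, terminated, truncated, info


class CounterEnv:
    """Toy env: the observation is a step counter; episodes end after 3 steps."""

    def reset(self, **kwargs: Any):
        self.t = 0
        return self.t, {}

    def step(self, action: Any):
        self.t += 1
        return self.t, float(action), self.t >= 3, False, {"t": self.t}


env = EverythingToInfoSketch(CounterEnv())
env.reset()
_, _, _, _, info = env.step(1)
print(info)
```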
AddPixelsWrapper
AddPixelsWrapper(
env: Env,
pixels_shape: tuple[int, int] = (84, 84),
torchvision_transform: Callable[[Any], Any] | None = None,
resample: int | None = None,
)
Bases: Wrapper
Adds rendered environment pixels to info dict.
Parameters:
- env (Env) – The environment to wrap.
- pixels_shape (tuple[int, int], default: (84, 84)) – Target (height, width) for rendered pixels.
- torchvision_transform (Callable[[Any], Any] | None, default: None) – Optional transform to apply to the pixels.
- resample (int | None, default: None) – PIL resample filter (e.g. Image.BILINEAR, Image.NEAREST). Defaults to BILINEAR.
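The flow described above (render, resize, optionally transform, store in info) can be sketched without any rendering backend. In this hypothetical sketch the frame is a nested list, the resize is a simple nearest-neighbor sampler at pixel centers rather than Pillow's filter, and the info key "pixels" is an assumption.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

Frame = List[List[Any]]  # H x W grid of pixel values


def resize_nearest(frame: Sequence[Sequence[Any]], shape: Tuple[int, int]) -> Frame:
    """Nearest-neighbor resize to (height, width), sampling at pixel centers."""
    out_h, out_w = shape
    in_h, in_w = len(frame), len(frame[0])

    def src(i: int, n_in: int, n_out: int) -> int:
        return min(int((i + 0.5) * n_in / n_out), n_in - 1)

    return [[frame[src(y, in_h, out_h)][src(x, in_w, out_w)] for x in range(out_w)]
            for y in range(out_h)]


class AddPixelsSketch:
    """Hypothetical stand-in for AddPixelsWrapper (assumed info key "pixels")."""

    def __init__(self, env: Any, pixels_shape: Tuple[int, int] = (84, 84),
                 transform: Optional[Callable[[Any], Any]] = None) -> None:
        self.env = env
        self.pixels_shape = pixels_shape
        self.transform = transform

    def _with_pixels(self, info: dict) -> dict:
        pixels = resize_nearest(self.env.render(), self.pixels_shape)
        if self.transform is not None:
            pixels = self.transform(pixels)
        return dict(info, pixels=pixels)

    def reset(self, **kwargs: Any):
        obs, info = self.env.reset(**kwargs)
        return obs, self._with_pixels(info)

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = self.env.step(action)
        return obs, reward, terminated, truncated, self._with_pixels(info)


class ToyRenderEnv:
    """Toy env whose 4 x 4 'frame' encodes (row, col) as 10 * row + col."""

    def reset(self, **kwargs: Any):
        return 0, {}

    def step(self, action: Any):
        return 0, 0.0, False, False, {}

    def render(self) -> Frame:
        return [[10 * r + c for c in range(4)] for r in range(4)]


env = AddPixelsSketch(ToyRenderEnv(), pixels_shape=(2, 2))
_, info = env.reset()
```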
ResizeGoalWrapper
ResizeGoalWrapper(
env: Env,
pixels_shape: tuple[int, int] = (84, 84),
torchvision_transform: Callable[[Any], Any] | None = None,
resample: int | None = None,
)
Bases: Wrapper
Resizes goal images in info dict.
Parameters:
- env (Env) – The environment to wrap.
- pixels_shape (tuple[int, int], default: (84, 84)) – Target (height, width) for resizing goal images.
- torchvision_transform (Callable[[Any], Any] | None, default: None) – Optional transform to apply to goal images.
- resample (int | None, default: None) – PIL resample filter (e.g. Image.BILINEAR, Image.NEAREST). Defaults to BILINEAR.
Visual Wrappers
Visual wrappers operate on rendered frames (the output of env.render() and any info["pixels*"] entries). They are useful for adding distractors at evaluation time or augmentations at training time.

ChromaKeyWrapper
ChromaKeyWrapper(env, key_color, media, tolerance=0.0)
Bases: Wrapper
Replace pixels matching a key color in rendered frames with an image or video background.
Works like a green screen: pixels close to key_color (within tolerance) are swapped out for the corresponding pixels of media. If media is a video, it advances by one frame on each call to render() and loops back to the start after its last frame.
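A minimal sketch of the green-screen logic, assuming RGB pixels as integer tuples and that tolerance is a Euclidean distance in 0..255 RGB space (the library's actual distance metric and scale are not stated here). `chroma_key` and `LoopingMedia` are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple

RGB = Tuple[int, int, int]
Frame = List[List[RGB]]


def chroma_key(frame: Sequence[Sequence[RGB]], background: Sequence[Sequence[RGB]],
               key_color: RGB, tolerance: float = 0.0) -> Frame:
    """Swap in background pixels wherever a pixel is within `tolerance` of key_color."""
    return [
        [bg if math.dist(px, key_color) <= tolerance else px
         for px, bg in zip(row, bg_row)]
        for row, bg_row in zip(frame, background)
    ]


class LoopingMedia:
    """Video background: each call returns the next frame, wrapping at the end."""

    def __init__(self, frames: Sequence[Frame]) -> None:
        self.frames = list(frames)
        self.index = 0

    def next_frame(self) -> Frame:
        frame = self.frames[self.index]
        self.index = (self.index + 1) % len(self.frames)
        return frame


GREEN = (0, 255, 0)
rendered = [[GREEN, (200, 10, 10)]]
video = LoopingMedia([[[(1, 1, 1), (1, 1, 1)]], [[(2, 2, 2), (2, 2, 2)]]])
out1 = chroma_key(rendered, video.next_frame(), GREEN)  # one call per render()
out2 = chroma_key(rendered, video.next_frame(), GREEN)
out3 = chroma_key(rendered, video.next_frame(), GREEN)  # video has looped
```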

NoiseWrapper
NoiseWrapper(env, std=10.0, seed=None)
Bases: _PixelTransform
Add Gaussian pixel noise with a step-dependent standard deviation.
std is either a float or a callable f(step) -> float (e.g. linear,
cosine, exponential, sinusoidal, or any user-provided function).
The wrapper increments an internal step counter on each env.step call and
passes the current count to the schedule before sampling noise.
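The step counter and schedule lookup described above can be sketched as follows. This is an illustrative stand-in, not the library's `NoiseWrapper`: frames are assumed to be H x W grids of 0..255 grayscale values, and rounding plus clipping back to [0, 255] is an assumption about how noisy values are stored.

```python
import random
from typing import Callable, List, Optional, Union

Schedule = Union[float, Callable[[int], float]]


class StepNoise:
    """Hypothetical sketch of step-dependent Gaussian pixel noise."""

    def __init__(self, std: Schedule = 10.0, seed: Optional[int] = None) -> None:
        self.std = std
        self.rng = random.Random(seed)
        self.step_count = 0

    def on_step(self) -> None:
        # Called once per env.step, mirroring the documented counter.
        self.step_count += 1

    def current_std(self) -> float:
        return self.std(self.step_count) if callable(self.std) else float(self.std)

    def apply(self, frame: List[List[int]]) -> List[List[int]]:
        sigma = self.current_std()
        return [[min(255, max(0, round(v + self.rng.gauss(0.0, sigma)))) for v in row]
                for row in frame]


# Noise switches on after two steps.
noise = StepNoise(std=lambda step: 0.0 if step < 2 else 50.0, seed=0)
frame = [[128] * 4 for _ in range(4)]
clean = noise.apply(frame)
noise.on_step()
noise.on_step()
noisy = noise.apply(frame)
```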

ColorJitterWrapper
ColorJitterWrapper(
env,
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.05,
seed=None,
)
Bases: _PixelTransform
Random brightness, contrast, saturation, and hue shifts. The jitter factors are resampled on each reset.
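Sampling factors once per reset means every frame within an episode receives the same color shift. The hypothetical sketch below shows this for brightness and contrast only, on grayscale frames, and borrows torchvision's convention of drawing each factor uniformly from [max(0, 1 - s), 1 + s]; whether the library uses the same ranges is an assumption.

```python
import random
from typing import List, Optional

Frame = List[List[float]]  # grayscale values in [0, 255]


class JitterSketch:
    """Brightness/contrast jitter with factors resampled on reset (sketch)."""

    def __init__(self, brightness: float = 0.2, contrast: float = 0.2,
                 seed: Optional[int] = None) -> None:
        self.strengths = (brightness, contrast)
        self.rng = random.Random(seed)
        self.reset()

    def reset(self) -> None:
        b, c = self.strengths
        self.brightness_factor = self.rng.uniform(max(0.0, 1 - b), 1 + b)
        self.contrast_factor = self.rng.uniform(max(0.0, 1 - c), 1 + c)

    def apply(self, frame: Frame) -> Frame:
        # Brightness: scale every pixel.
        out = [[v * self.brightness_factor for v in row] for row in frame]
        # Contrast: blend each pixel toward the frame's mean intensity.
        values = [v for row in out for v in row]
        mean = sum(values) / len(values)
        return [[min(255.0, max(0.0, mean + (v - mean) * self.contrast_factor))
                 for v in row] for row in out]


jit = JitterSketch(seed=1)
first_factors = (jit.brightness_factor, jit.contrast_factor)
frame = [[0.0, 100.0], [200.0, 50.0]]
```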

BlurWrapper
BlurWrapper(env, kernel=5, sigma=0.0)
Bases: _PixelTransform
Gaussian blur with an odd kernel size and standard deviation sigma (sigma=0 derives the standard deviation from kernel).
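When sigma is 0 it has to be derived from the kernel size. The sketch below assumes the common OpenCV rule sigma = 0.3 * ((kernel - 1) * 0.5 - 1) + 0.8, which equals torchvision's default of 0.15 * kernel + 0.35; the library's exact rule and border handling (edge replication here) are assumptions.

```python
import math
from typing import List


def gaussian_kernel1d(kernel: int, sigma: float = 0.0) -> List[float]:
    """Normalized 1-D Gaussian weights for an odd kernel size."""
    if kernel <= 0 or kernel % 2 == 0:
        raise ValueError("kernel must be a positive odd integer")
    if sigma <= 0:
        sigma = 0.3 * ((kernel - 1) * 0.5 - 1) + 0.8  # OpenCV-style default
    half = kernel // 2
    weights = [math.exp(-(x * x) / (2 * sigma * sigma)) for x in range(-half, half + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def blur(frame: List[List[float]], kernel: int = 5, sigma: float = 0.0) -> List[List[float]]:
    """Separable Gaussian blur with edge-replicate borders."""
    k = gaussian_kernel1d(kernel, sigma)
    half = kernel // 2
    h, w = len(frame), len(frame[0])

    def clamp(i: int, n: int) -> int:
        return min(max(i, 0), n - 1)

    # Horizontal pass, then vertical pass over the intermediate result.
    rows = [[sum(k[j] * frame[y][clamp(x + j - half, w)] for j in range(kernel))
             for x in range(w)] for y in range(h)]
    return [[sum(k[j] * rows[clamp(y + j - half, h)][x] for j in range(kernel))
             for x in range(w)] for y in range(h)]
```

A separable blur costs 2 * kernel multiplications per pixel instead of kernel squared, which is why Gaussian blurs are usually implemented this way.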

OcclusionWrapper
OcclusionWrapper(
env, num_patches=1, size=(0.1, 0.3), color=0, seed=None
)
Bases: _PixelTransform
Cover the frame with num_patches rectangles of fractional size. Patches are resampled on each reset.
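The key property is that patches stay fixed within an episode and change only on reset. In this hypothetical sketch, size is interpreted as a (min, max) range of side lengths given as fractions of the frame's height and width; that interpretation is an assumption.

```python
import random
from typing import List, Optional, Tuple

Rect = Tuple[int, int, int, int]  # (top, left, height, width)


class OcclusionSketch:
    """Static occluding rectangles, resampled only on reset (sketch)."""

    def __init__(self, height: int, width: int, num_patches: int = 1,
                 size: Tuple[float, float] = (0.1, 0.3), color: int = 0,
                 seed: Optional[int] = None) -> None:
        self.height, self.width = height, width
        self.num_patches, self.size, self.color = num_patches, size, color
        self.rng = random.Random(seed)
        self.patches: List[Rect] = []
        self.reset()

    def reset(self) -> None:
        self.patches = []
        for _ in range(self.num_patches):
            ph = max(1, int(self.rng.uniform(*self.size) * self.height))
            pw = max(1, int(self.rng.uniform(*self.size) * self.width))
            top = self.rng.randint(0, self.height - ph)
            left = self.rng.randint(0, self.width - pw)
            self.patches.append((top, left, ph, pw))

    def apply(self, frame: List[List[int]]) -> List[List[int]]:
        out = [list(row) for row in frame]  # leave the input untouched
        for top, left, ph, pw in self.patches:
            for y in range(top, top + ph):
                for x in range(left, left + pw):
                    out[y][x] = self.color
        return out


occ = OcclusionSketch(20, 20, num_patches=2, color=0, seed=0)
frame = [[255] * 20 for _ in range(20)]
```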

MovingPatchWrapper
MovingPatchWrapper(
env,
num_patches=1,
size=(0.1, 0.2),
color=255,
speed=2.0,
seed=None,
)
Bases: _PixelTransform
Overlay num_patches solid-color rectangles that drift with their own velocity.
Each patch has an independent position and velocity sampled at reset. Positions advance
by their velocity once per env.step and reflect off the frame edges, so motion is
smooth and continuous (no teleporting). speed is the velocity magnitude in pixels
per step.
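The reflection rule can be illustrated with a single patch. In this hypothetical sketch, position and heading are sampled once (as at reset), the position advances by the velocity on each step(), and any overshoot past an edge is mirrored back inside the frame while the matching velocity component flips sign, so the speed is preserved and no step moves the patch farther than speed pixels.

```python
import math
import random
from typing import Optional, Tuple


class BouncingPatch:
    """One patch drifting at constant speed and reflecting off frame edges (sketch)."""

    def __init__(self, frame_h: int, frame_w: int, patch_h: int, patch_w: int,
                 speed: float = 2.0, seed: Optional[int] = None) -> None:
        self.max_y = frame_h - patch_h  # largest valid top coordinate
        self.max_x = frame_w - patch_w  # largest valid left coordinate
        rng = random.Random(seed)
        self.y = rng.uniform(0, self.max_y)
        self.x = rng.uniform(0, self.max_x)
        angle = rng.uniform(0, 2 * math.pi)
        self.vy = speed * math.sin(angle)
        self.vx = speed * math.cos(angle)

    @staticmethod
    def _advance(pos: float, vel: float, hi: float) -> Tuple[float, float]:
        if hi <= 0:
            return 0.0, vel  # patch fills this axis; nowhere to move
        pos += vel
        # Fold the position back into [0, hi], flipping velocity per bounce.
        while not 0 <= pos <= hi:
            pos, vel = (-pos, -vel) if pos < 0 else (2 * hi - pos, -vel)
        return pos, vel

    def step(self) -> None:
        self.y, self.vy = self._advance(self.y, self.vy, self.max_y)
        self.x, self.vx = self._advance(self.x, self.vx, self.max_x)

    def top_left(self) -> Tuple[int, int]:
        return round(self.y), round(self.x)


patch = BouncingPatch(20, 20, 4, 4, speed=3.0, seed=0)
```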

RandomShiftWrapper
RandomShiftWrapper(env, pad=4, seed=None)
Bases: _PixelTransform
DrQ-style random shift: replicate-pad by pad pixels, then randomly crop back to the original size. The shift is resampled on every call.
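Replicate padding followed by a crop is equivalent to reading the original frame at clamped, offset coordinates, so the padded frame never has to be materialized. The sketch below (a hypothetical helper, operating on an H x W nested list) samples a new offset on each call.

```python
import random
from typing import List, Optional


def random_shift(frame: List[List[int]], pad: int = 4,
                 rng: Optional[random.Random] = None) -> List[List[int]]:
    """Replicate-pad by `pad` pixels, then crop a random window of the original size."""
    rng = rng or random.Random()
    h, w = len(frame), len(frame[0])
    dy = rng.randint(0, 2 * pad)  # crop offset inside the padded frame
    dx = rng.randint(0, 2 * pad)
    # padded[y + dy][x + dx] equals the original pixel at the clamped coordinate,
    # which is exactly what replicate (edge) padding produces.
    return [
        [frame[min(max(y + dy - pad, 0), h - 1)][min(max(x + dx - pad, 0), w - 1)]
         for x in range(w)]
        for y in range(h)
    ]


frame = [[10 * r + c for c in range(5)] for r in range(5)]
shifted = random_shift(frame, pad=2, rng=random.Random(3))
```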

CutoutWrapper
CutoutWrapper(
env, num=1, size=(0.1, 0.2), color=0, seed=None
)
Bases: _PixelTransform
Mask num random rectangles per frame with color. The rectangles are resampled on every call.

RandomConvWrapper
RandomConvWrapper(env, kernel_size=3, seed=None)
Bases: _PixelTransform
Pass the frame through a randomly initialized convolution (3 -> 3 channels). The weights are resampled on each reset.
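A pure-Python sketch of a random 3 -> 3 channel convolution whose weights change only on reset. The weight distribution (zero-mean Gaussian with variance 1/fan_in), zero padding to preserve the frame size, and clipping to [0, 255] are all assumptions, not details taken from the library.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Frame = List[List[Sequence[float]]]  # H x W grid of (r, g, b)


class RandomConvSketch:
    """3 -> 3 channel convolution with random weights resampled on reset (sketch)."""

    def __init__(self, kernel_size: int = 3, seed: Optional[int] = None) -> None:
        self.k = kernel_size
        self.rng = random.Random(seed)
        self.reset()

    def reset(self) -> None:
        std = 1.0 / math.sqrt(3 * self.k * self.k)  # fan_in = in_channels * k * k
        # weights[out_channel][in_channel][dy][dx]
        self.weights = [[[[self.rng.gauss(0.0, std) for _ in range(self.k)]
                          for _ in range(self.k)] for _ in range(3)] for _ in range(3)]

    def apply(self, frame: Frame) -> List[List[Tuple[float, ...]]]:
        h, w, half = len(frame), len(frame[0]), self.k // 2
        out = []
        for y in range(h):
            row = []
            for x in range(w):
                pixel = []
                for oc in range(3):
                    acc = 0.0
                    for dy in range(self.k):
                        for dx in range(self.k):
                            yy, xx = y + dy - half, x + dx - half
                            if 0 <= yy < h and 0 <= xx < w:  # zero padding
                                for ic in range(3):
                                    acc += self.weights[oc][ic][dy][dx] * frame[yy][xx][ic]
                    pixel.append(min(255.0, max(0.0, acc)))
                row.append(tuple(pixel))
            out.append(row)
        return out


conv = RandomConvSketch(kernel_size=3, seed=0)
frame = [[(10.0 * x, 10.0 * y, 128.0) for x in range(4)] for y in range(4)]
```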

GrayscaleWrapper
GrayscaleWrapper(env, keep_channels=True)
Bases: _PixelTransform
Convert the frame to grayscale. keep_channels=True broadcasts the result back to 3 channels.
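A sketch of grayscale conversion using the ITU-R BT.601 luma weights (0.299, 0.587, 0.114), the weights Pillow uses for "L" conversion; whether the library uses the same weights is an assumption. `to_grayscale` is a hypothetical helper.

```python
from typing import List, Tuple

RGB = Tuple[int, int, int]


def to_grayscale(frame: List[List[RGB]], keep_channels: bool = True) -> list:
    """Luma conversion with BT.601 weights."""
    out = []
    for row in frame:
        gray_row = [round(0.299 * r + 0.587 * g + 0.114 * b) for r, g, b in row]
        # keep_channels=True repeats the gray value so the frame stays 3-channel.
        out.append([(v, v, v) for v in gray_row] if keep_channels else gray_row)
    return out
```

Keeping three channels lets the grayscale frame flow into models that expect RGB input without changing their first layer.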

ResolutionWrapper
ResolutionWrapper(env, scale=0.5)
Bases: _PixelTransform
Downsample the frame by scale then upsample back to the original size.
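The down-then-up round trip keeps the frame size but discards detail. The sketch below uses nearest-neighbor sampling in both directions; the library's actual interpolation mode is an assumption, and `degrade_resolution` is a hypothetical name.

```python
from typing import List


def _resize_nearest(frame: List[List[int]], out_h: int, out_w: int) -> List[List[int]]:
    in_h, in_w = len(frame), len(frame[0])

    def src(i: int, n_in: int, n_out: int) -> int:
        return min(int((i + 0.5) * n_in / n_out), n_in - 1)

    return [[frame[src(y, in_h, out_h)][src(x, in_w, out_w)] for x in range(out_w)]
            for y in range(out_h)]


def degrade_resolution(frame: List[List[int]], scale: float = 0.5) -> List[List[int]]:
    """Downsample by `scale`, then upsample back to the original size."""
    h, w = len(frame), len(frame[0])
    small = _resize_nearest(frame, max(1, round(h * scale)), max(1, round(w * scale)))
    return _resize_nearest(small, h, w)


frame = [[10 * r + c for c in range(4)] for r in range(4)]
blocky = degrade_resolution(frame, scale=0.5)
```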
Noise Schedules
Time-dependent schedules. Each function returns a callable f(step) -> float that can be passed to NoiseWrapper(std=...). The wrapper increments step once per env.step call.
from stable_worldmodel.wrapper import NoiseWrapper, linear
# env is an existing environment instance; std ramps from 0 to 25 over 10,000 steps.
env = NoiseWrapper(env, std=linear(0, 25, horizon=10_000))
constant
constant(value)
Schedule that always returns value.
linear
linear(start, end, horizon)
Linear ramp from start to end over horizon steps; held at end afterwards.
cosine
cosine(start, end, horizon)
Cosine ramp from start to end over horizon steps; held at end afterwards.
exponential
exponential(start, decay, floor=0.0)
Exponential decay start * decay**step, lower-bounded by floor.
sinusoidal
sinusoidal(low, high, period)
Sinusoid oscillating in [low, high] with the given period (in steps).
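The one-line descriptions above translate directly into small closures. The following are hypothetical re-implementations written from those descriptions, not the library source; in particular, the half-cosine ramp shape and the sinusoid starting at its midpoint (phase 0) are assumptions.

```python
import math
from typing import Callable

Schedule = Callable[[int], float]


def constant(value: float) -> Schedule:
    return lambda step: float(value)


def linear(start: float, end: float, horizon: int) -> Schedule:
    def f(step: int) -> float:
        t = min(max(step / horizon, 0.0), 1.0)  # progress, held at 1 after horizon
        return start + (end - start) * t
    return f


def cosine(start: float, end: float, horizon: int) -> Schedule:
    def f(step: int) -> float:
        t = min(max(step / horizon, 0.0), 1.0)
        return start + (end - start) * (1 - math.cos(math.pi * t)) / 2
    return f


def exponential(start: float, decay: float, floor: float = 0.0) -> Schedule:
    return lambda step: max(floor, start * decay ** step)


def sinusoidal(low: float, high: float, period: float) -> Schedule:
    def f(step: int) -> float:
        phase = 2 * math.pi * step / period
        return low + (high - low) * (1 + math.sin(phase)) / 2
    return f
```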