Custom environment wrappers


MegaWrapper

MegaWrapper(
    env: Env,
    image_shape: tuple[int, int] = (84, 84),
    pixels_transform: Callable[[Any], Any] | None = None,
    goal_transform: Callable[[Any], Any] | None = None,
    required_keys: Iterable[str] | None = None,
    separate_goal: bool = True,
    image_resample: str | int | None = None,
)

Bases: Wrapper

Combines multiple wrappers for comprehensive environment preprocessing.

Parameters:

  • env (Env) –

    The environment to wrap.

  • image_shape (tuple[int, int], default: (84, 84) ) –

    Target (height, width) for all image processing.

  • pixels_transform (Callable[[Any], Any] | None, default: None ) –

    Optional transform for rendered pixels.

  • goal_transform (Callable[[Any], Any] | None, default: None ) –

    Optional transform for goal images.

  • required_keys (Iterable[str] | None, default: None ) –

    Keys that must be present in the info dict.

  • separate_goal (bool, default: True ) –

    Whether to handle goal separately.

  • image_resample (str | int | None, default: None ) –

    PIL resample mode used when resizing pixels and goal images. Accepts a PIL resampling constant or one of the strings 'nearest', 'bilinear', 'bicubic', 'lanczos', 'box', 'hamming'. If None, bilinear is used. Use 'nearest' for crisp pixel art.
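
As a rough illustration of how the image_resample argument could be normalized, the sketch below maps a string, an integer constant, or None to a PIL resample integer. The helper name resolve_resample and the case-insensitive matching are assumptions made for this example and are not part of the documented API; the integer values mirror Pillow's Image.Resampling enum.

```python
# Hypothetical helper, not part of the documented API.
# Integer values mirror Pillow's Image.Resampling enum.
_RESAMPLE_BY_NAME = {
    "nearest": 0,
    "lanczos": 1,
    "bilinear": 2,
    "bicubic": 3,
    "box": 4,
    "hamming": 5,
}


def resolve_resample(image_resample=None):
    """Map a mode name, a PIL integer constant, or None to a PIL resample integer."""
    if image_resample is None:
        return _RESAMPLE_BY_NAME["bilinear"]  # documented default
    if isinstance(image_resample, str):
        key = image_resample.lower()  # assumption: names are case-insensitive
        if key not in _RESAMPLE_BY_NAME:
            raise ValueError(f"unknown resample mode: {image_resample!r}")
        return _RESAMPLE_BY_NAME[key]
    return int(image_resample)
```

With this sketch, resolve_resample("nearest") yields the constant that keeps pixel art crisp, while resolve_resample() falls back to bilinear.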

EnsureInfoKeysWrapper

EnsureInfoKeysWrapper(
    env: Env, required_keys: Iterable[str]
)

Bases: Wrapper

Validates that required keys are present in the info dict.

Parameters:

  • env (Env) –

    The environment to wrap.

  • required_keys (Iterable[str]) –

    Iterable of regex patterns that must match keys in info.
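
The pattern-based check described above might look like the following minimal sketch. The function name missing_patterns is hypothetical, and the use of re.fullmatch (rather than re.search or re.match) is an assumption; the source only states that each pattern must match a key in info.

```python
import re


def missing_patterns(info, required_keys):
    """Return the patterns that match no key of the info dict."""
    keys = list(info)
    return [
        pattern
        for pattern in required_keys
        # Assumption: a pattern must match an entire key, not a substring.
        if not any(re.fullmatch(pattern, key) for key in keys)
    ]


info = {"pixels": None, "pixels.top": None, "goal": None}
satisfied = missing_patterns(info, [r"pixels(\..+)?", "goal"])
unsatisfied = missing_patterns(info, ["state"])
```

Here satisfied is empty because every pattern matches at least one key, while unsatisfied contains "state"; a validating wrapper would raise in the second case.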

EnsureImageShape

EnsureImageShape(
    env: Env, image_key: str, image_shape: tuple[int, int]
)

Bases: Wrapper

Validates that an image in the info dict has the expected spatial dimensions.

Parameters:

  • env (Env) –

    The environment to wrap.

  • image_key (str) –

    Key in the info dict that holds the image.

  • image_shape (tuple[int, int]) –

    Expected (height, width) of the image.

EnsureGoalInfoWrapper

EnsureGoalInfoWrapper(
    env: Env, check_reset: bool, check_step: bool = False
)

Bases: Wrapper

Validates that the 'goal' key is present in the info dict.

Parameters:

  • env (Env) –

    The environment to wrap.

  • check_reset (bool) –

    Whether to check 'goal' presence on reset.

  • check_step (bool, default: False ) –

    Whether to check 'goal' presence on each step.
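
The effect of check_reset and check_step can be sketched with a small stand-in class. This is not the library's implementation: GoalCheck and StubEnv are hypothetical names, the Gymnasium-style return tuples are an assumption about Env, and raising KeyError is an assumed failure mode.

```python
class GoalCheck:
    """Hypothetical stand-in for the validation logic; not the library class."""

    def __init__(self, env, check_reset, check_step=False):
        self.env = env
        self.check_reset = check_reset
        self.check_step = check_step

    @staticmethod
    def _require_goal(info, where):
        if "goal" not in info:
            raise KeyError(f"'goal' missing from info after {where}")

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        if self.check_reset:
            self._require_goal(info, "reset")
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        if self.check_step:
            self._require_goal(info, "step")
        return obs, reward, terminated, truncated, info


class StubEnv:
    """Toy environment that reports a goal on reset only."""

    def reset(self, **kwargs):
        return 0, {"goal": "g"}

    def step(self, action):
        return 0, 0.0, False, False, {}
```

Wrapping StubEnv with check_reset=True passes on both reset and step, whereas check_step=True would fail on the first step because the toy environment omits 'goal' there.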

EverythingToInfoWrapper

EverythingToInfoWrapper(env: Env)

Bases: Wrapper

Moves all transition information into the info dict.

Parameters:

  • env (Env) –

    The environment to wrap.

AddPixelsWrapper

AddPixelsWrapper(
    env: Env,
    pixels_shape: tuple[int, int] = (84, 84),
    torchvision_transform: Callable[[Any], Any] | None = None,
    resample: int | None = None,
)

Bases: Wrapper

Adds rendered environment pixels to the info dict.

Parameters:

  • env (Env) –

    The environment to wrap.

  • pixels_shape (tuple[int, int], default: (84, 84) ) –

    Target (height, width) for rendered pixels.

  • torchvision_transform (Callable[[Any], Any] | None, default: None ) –

    Optional transform to apply to the pixels.

  • resample (int | None, default: None ) –

    PIL resample filter (e.g. Image.BILINEAR, Image.NEAREST). Defaults to BILINEAR.

ResizeGoalWrapper

ResizeGoalWrapper(
    env: Env,
    pixels_shape: tuple[int, int] = (84, 84),
    torchvision_transform: Callable[[Any], Any] | None = None,
    resample: int | None = None,
)

Bases: Wrapper

Resizes goal images in the info dict.

Parameters:

  • env (Env) –

    The environment to wrap.

  • pixels_shape (tuple[int, int], default: (84, 84) ) –

    Target (height, width) for resizing goal images.

  • torchvision_transform (Callable[[Any], Any] | None, default: None ) –

    Optional transform to apply to goal images.

  • resample (int | None, default: None ) –

    PIL resample filter (e.g. Image.BILINEAR, Image.NEAREST). Defaults to BILINEAR.

Visual Wrappers

Visual wrappers operate on rendered frames (the output of env.render() and any info["pixels*"] entries). They are useful for adding distractors at evaluation time or augmentations at training time.

chromakey

ChromaKeyWrapper

ChromaKeyWrapper(env, key_color, media, tolerance=0.0)

Bases: Wrapper

Replace pixels matching a key color in rendered frames with an image or video background.

Works like a green-screen: pixels close to key_color (within tolerance) are swapped out for the corresponding pixels of media. If media is a video, frames advance and loop on each call to render.
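
The green-screen substitution can be illustrated with a dependency-free sketch on nested lists of RGB tuples. The distance measure (largest per-channel absolute difference) and both function names are assumptions for this example; the actual wrapper presumably operates on image arrays and may measure closeness differently.

```python
def chroma_key(frame, key_color, background, tolerance=0.0):
    """Replace pixels within tolerance of key_color by the background pixels.

    frame and background are equally sized H x W grids of (r, g, b) tuples.
    """

    def is_key(pixel):
        # Assumption: distance is the largest per-channel absolute difference.
        return max(abs(c - k) for c, k in zip(pixel, key_color)) <= tolerance

    return [
        [bg if is_key(px) else px for px, bg in zip(frame_row, bg_row)]
        for frame_row, bg_row in zip(frame, background)
    ]


def background_for_call(frames, call_index):
    """Pick the video frame for the n-th render call, looping at the end."""
    return frames[call_index % len(frames)]


green = (0, 255, 0)
frame = [[green, (200, 10, 10)], [(2, 250, 3), green]]
background = [[(1, 1, 1), (2, 2, 2)], [(3, 3, 3), (4, 4, 4)]]
out = chroma_key(frame, green, background, tolerance=5)
```

In out, the two pure-green pixels and the near-green pixel (2, 250, 3) are replaced by the background, while the red pixel is kept.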

noise

NoiseWrapper

NoiseWrapper(env, std=10.0, seed=None)

Bases: _PixelTransform

Add Gaussian pixel noise with a step-dependent standard deviation.

std is either a float or a callable f(step) -> float (e.g. linear, cosine, exponential, sinusoidal, or any user-provided function). The wrapper increments an internal step counter on each env.step call and passes the current count to the schedule before sampling noise.
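
The interaction between the step counter and the schedule might be sketched as follows. NoiseSketch is a hypothetical stand-in that works on a flat list of pixel values; clipping to the 0..255 range and rounding are assumptions, not documented behavior.

```python
import random


class NoiseSketch:
    """Hypothetical stand-in showing how a std schedule could be consumed."""

    def __init__(self, std=10.0, seed=None):
        # Accept either a constant or a schedule f(step) -> float.
        self._std = std if callable(std) else (lambda step: float(std))
        self._rng = random.Random(seed)
        self.step_count = 0

    def on_step(self):
        """Called once per env.step in this sketch."""
        self.step_count += 1

    def current_std(self):
        return self._std(self.step_count)

    def apply(self, pixels):
        """Add zero-mean Gaussian noise to a flat list of 0..255 values."""
        std = self.current_std()
        noisy = (p + self._rng.gauss(0.0, std) for p in pixels)
        # Assumption: results are rounded and clipped to the 8-bit range.
        return [min(255, max(0, round(v))) for v in noisy]
```

Passing std=lambda step: 2.0 * step gives a noise level of 0 before the first step and 2.0 after it, so the first rendered frame is left unchanged.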

colorjitter

ColorJitterWrapper

ColorJitterWrapper(
    env,
    brightness=0.2,
    contrast=0.2,
    saturation=0.2,
    hue=0.05,
    seed=None,
)

Bases: _PixelTransform

Random brightness, contrast, saturation, and hue shifts. The jitter factors are resampled on each reset.

blur

BlurWrapper

BlurWrapper(env, kernel=5, sigma=0.0)

Bases: _PixelTransform

Gaussian blur with an odd kernel size (kernel) and standard deviation sigma. Setting sigma=0 derives the standard deviation from the kernel size.

occlusion

OcclusionWrapper

OcclusionWrapper(
    env, num_patches=1, size=(0.1, 0.3), color=0, seed=None
)

Bases: _PixelTransform

Cover parts of the frame with num_patches rectangles whose size is given as a fraction of the frame. Patches are resampled on each reset.

moving patches

MovingPatchWrapper

MovingPatchWrapper(
    env,
    num_patches=1,
    size=(0.1, 0.2),
    color=255,
    speed=2.0,
    seed=None,
)

Bases: _PixelTransform

Overlay num_patches solid-color rectangles that drift with their own velocity.

Each patch has an independent position and velocity sampled at reset. Positions advance by their velocity once per env.step and reflect off the frame edges, so motion is smooth and continuous (no teleporting). speed is the velocity magnitude in pixels per step.
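
The reflect-at-the-edges motion can be sketched per axis as below. Both function names are hypothetical, and sampling the direction uniformly on the circle is an assumption; the source only states that speed is the velocity magnitude in pixels per step.

```python
import math
import random


def sample_patch(width, height, patch_w, patch_h, speed, rng):
    """Sample a top-left position and a velocity of magnitude speed."""
    angle = rng.uniform(0.0, 2.0 * math.pi)  # assumption: uniform direction
    return {
        "x": rng.uniform(0.0, width - patch_w),
        "y": rng.uniform(0.0, height - patch_h),
        "vx": speed * math.cos(angle),
        "vy": speed * math.sin(angle),
    }


def advance(pos, vel, limit):
    """Move one step along one axis, reflecting at 0 and at limit.

    Assumes abs(vel) <= limit, so at most one bounce happens per step.
    """
    pos += vel
    if pos < 0.0:
        pos, vel = -pos, -vel
    elif pos > limit:
        pos, vel = 2.0 * limit - pos, -vel
    return pos, vel
```

For a patch of width patch_w in a frame of width width, the horizontal limit would be width - patch_w, so the patch never leaves the frame and never jumps.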

random shift

RandomShiftWrapper

RandomShiftWrapper(env, pad=4, seed=None)

Bases: _PixelTransform

DrQ-style random shift: replicate-pad the frame by pad pixels, then take a random crop of the original size. The shift is resampled on every call.
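
A minimal sketch of the pad-then-crop operation on a 2D grid follows. The function name random_shift is hypothetical; instead of building the padded frame explicitly, the sketch reads the source at clamped indices, which gives the same result as replicate padding followed by a crop.

```python
import random


def random_shift(frame, pad=4, rng=None):
    """Replicate-pad a 2D grid by pad, then crop a random window of the same size."""
    rng = rng or random.Random()
    height, width = len(frame), len(frame[0])
    # Offsets of the crop window inside the padded frame, each in [0, 2 * pad].
    dy = rng.randint(0, 2 * pad)
    dx = rng.randint(0, 2 * pad)

    def clamp(i, n):
        return min(n - 1, max(0, i))

    # Padded index p corresponds to source index clamp(p - pad).
    return [
        [
            frame[clamp(y + dy - pad, height)][clamp(x + dx - pad, width)]
            for x in range(width)
        ]
        for y in range(height)
    ]
```

The output keeps the input shape, and with pad=0 the frame is returned unchanged.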

cutout

CutoutWrapper

CutoutWrapper(
    env, num=1, size=(0.1, 0.2), color=0, seed=None
)

Bases: _PixelTransform

Mask num random rectangles per frame with color. The rectangles are resampled on every call.

random conv

RandomConvWrapper

RandomConvWrapper(env, kernel_size=3, seed=None)

Bases: _PixelTransform

Pass the frame through a randomly initialized convolution (3 -> 3 channels). The weights are resampled on each reset.

grayscale

GrayscaleWrapper

GrayscaleWrapper(env, keep_channels=True)

Bases: _PixelTransform

Convert frame to grayscale. keep_channels=True broadcasts back to 3 channels.

resolution

ResolutionWrapper

ResolutionWrapper(env, scale=0.5)

Bases: _PixelTransform

Downsample the frame by a factor of scale, then upsample it back to the original size.

Noise Schedules

Time-dependent schedules. Each function returns a callable f(step) -> float that can be passed to NoiseWrapper(std=...). The wrapper increments step once per env.step call.

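from stable_worldmodel.wrapper import NoiseWrapper, linear

# env is an existing environment instance.
# The noise std ramps linearly from 0 to 25 over the first 10,000 steps.
env = NoiseWrapper(env, std=linear(0, 25, horizon=10_000))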

constant

constant(value)

Schedule that always returns value.

linear

linear(start, end, horizon)

Linear ramp from start to end over horizon steps; held at end after.
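
The ramp-and-hold behavior described above can be written as a small closure. The name linear_sketch is used to make clear that this is an illustrative reimplementation, not the library function.

```python
def linear_sketch(start, end, horizon):
    """Return f(step) that ramps linearly from start to end, then holds end."""

    def schedule(step):
        # Clamp progress to [0, 1] so the value stays at end after horizon.
        t = min(max(step / horizon, 0.0), 1.0)
        return start + (end - start) * t

    return schedule
```

With start=0, end=25, and horizon=10_000, the schedule reaches 12.5 halfway through and stays at 25 for every later step.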

cosine

cosine(start, end, horizon)

Cosine ramp from start to end over horizon steps; held at end after.
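
One common reading of a cosine ramp is half-cosine easing, sketched below. The exact formula used by the library is not stated in the source, so the weighting 0.5 * (1 - cos(pi * t)) is an assumption, as is the name cosine_sketch.

```python
import math


def cosine_sketch(start, end, horizon):
    """Return f(step) that eases from start to end along a half cosine."""

    def schedule(step):
        t = min(max(step / horizon, 0.0), 1.0)
        # Half-cosine weight: 0 at t=0, 1 at t=1, zero slope at both ends.
        w = 0.5 * (1.0 - math.cos(math.pi * t))
        return start + (end - start) * w

    return schedule
```

Compared with the linear ramp, this form changes slowly near the start and the end of the horizon and fastest in the middle.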

exponential

exponential(start, decay, floor=0.0)

Exponential decay start * decay**step, lower-bounded by floor.
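
The formula above translates directly into a closure; exponential_sketch is an illustrative name rather than the library function.

```python
def exponential_sketch(start, decay, floor=0.0):
    """Return f(step) = max(start * decay**step, floor)."""

    def schedule(step):
        return max(start * decay ** step, floor)

    return schedule
```

With start=20.0 and decay=0.5, the value halves on every step until it reaches floor, after which it stays constant.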

sinusoidal

sinusoidal(low, high, period)

Sinusoid oscillating in [low, high] with the given period (in steps).
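
A sinusoid bounded by low and high can be built from its midpoint and amplitude, as in the sketch below. The phase is an assumption: here the value starts at the midpoint and rises first, which the source does not specify. The name sinusoidal_sketch is likewise illustrative.

```python
import math


def sinusoidal_sketch(low, high, period):
    """Return f(step) oscillating between low and high with the given period."""
    mid = 0.5 * (low + high)
    amp = 0.5 * (high - low)

    def schedule(step):
        # Assumption: zero phase, so the value starts at mid and rises first.
        return mid + amp * math.sin(2.0 * math.pi * step / period)

    return schedule
```

Under this phase choice the schedule peaks at high a quarter of the way through each period and bottoms out at low three quarters of the way through.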