Environment and Wrappers¤

rex.rl.BaseEnv ¤

graph = graph instance-attribute ¤
max_steps: Union[int, jax.typing.ArrayLike] property ¤

The maximum number of steps in the environment.

By default, this is the maximum number of times the supervisor (i.e. the agent) is stepped in the provided computation graph. You can override this property to provide a custom value (smaller than the default). This value is used as the episode length when evaluating the environment during training.

observation_space(graph_state: base.GraphState) -> Box ¤

Returns the observation space.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

Returns:

  • Box

    The observation space

action_space(graph_state: base.GraphState) -> Box ¤

Returns the action space.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

Returns:

  • Box

    The action space

reset(rng: jax.Array = None) -> ResetReturn ¤

Reset the environment.

Parameters:

  • rng (Array, default: None ) –

    Random number generator. Used to initialize a new graph state.

Returns:

  • ResetReturn

    The initial graph state, observation, and info

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn ¤

Step the environment.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • action (Array) –

    The action to take.

Returns:

  • StepReturn

    The updated graph state, observation, reward, terminated, truncated, and info
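
A minimal rollout sketch using this interface. It assumes `env` is an already-constructed BaseEnv (or subclass) instance and that the returned Box exposes a `.shape` attribute; constructing the underlying graph is out of scope here.

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
graph_state, obs, info = env.reset(rng)                    # ResetReturn: (graph_state, obs, info)
action = jnp.zeros(env.action_space(graph_state).shape)    # placeholder action (assumes Box.shape)
for _ in range(int(env.max_steps)):
    # StepReturn: (graph_state, obs, reward, terminated, truncated, info)
    graph_state, obs, reward, terminated, truncated, info = env.step(graph_state, action)
    if terminated or truncated:
        break
```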


rex.rl.BaseWrapper ¤

Bases: object

Base class for wrappers.

__init__(env: Union[BaseEnv, Environment, BaseWrapper]) ¤

Initialize the wrapper.

Parameters:

  • env (Union[BaseEnv, Environment, BaseWrapper]) –

    The environment to wrap.

__getattr__(name: str) -> Any ¤

Proxy access to regular attributes of wrapped object.

Parameters:

  • name (str) –

    The name of the attribute.

Returns:

  • Any

    The attribute of the wrapped object.
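
For illustration, any attribute that the wrapper itself does not define falls through to the wrapped environment. A small sketch, assuming `env` is an existing BaseEnv instance:

```python
from rex.rl import BaseWrapper

wrapped = BaseWrapper(env)
# max_steps is not defined on the wrapper, so __getattr__ proxies it to env
assert wrapped.max_steps == env.max_steps
```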


rex.rl.AutoResetWrapper ¤

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], fixed_init: bool = True) ¤

The AutoResetWrapper resets the environment in the step method when the episode is done.

When fixed_init is True, a fixed initial state is reused instead of actually resetting the environment. This is useful when you want the same initial state for every episode, or when resetting the environment is expensive and you want to avoid that cost.

Parameters:

  • env (Union[BaseEnv, Environment, BaseWrapper]) –

    The environment to wrap.

  • fixed_init (bool, default: True ) –

    Whether to use a fixed initial state.

reset(rng: jax.Array = None) -> ResetReturn ¤

Reset the environment and return the initial state.

If fixed_init is True, the initial state is stored in the aux of the graph state.

Parameters:

  • rng (Array, default: None ) –

    Random number generator.

Returns:

  • ResetReturn

    The initial graph state, observation, and info

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn ¤

Step the environment and reset the state if the episode is done.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • action (Array) –

    The action to take.

Returns:

  • StepReturn

    The updated graph state, observation, reward, terminated, truncated, and info
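
A usage sketch: the wrapper resets inside step() when the episode ends, so the rollout loop never calls reset() itself. It assumes `env` is an existing BaseEnv instance and that the action-space Box exposes a `.shape` attribute.

```python
import jax
import jax.numpy as jnp
from rex.rl import AutoResetWrapper

env_ar = AutoResetWrapper(env, fixed_init=True)
gs, obs, info = env_ar.reset(jax.random.PRNGKey(0))
action = jnp.zeros(env_ar.action_space(gs).shape)
for _ in range(100):
    # On episode end, the returned graph state is the (fixed) initial state.
    gs, obs, reward, terminated, truncated, info = env_ar.step(gs, action)
```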


rex.rl.LogWrapper ¤

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper]) ¤

Log the episode returns and lengths.

Parameters:

  • env (Union[BaseEnv, Environment, BaseWrapper]) –

    The environment to wrap.

reset(rng: jax.Array = None) -> ResetReturn ¤

Stores the log state in the aux of the graph state.

Parameters:

  • rng (Array, default: None ) –

    Random number generator.

Returns:

  • ResetReturn

    The initial graph state, observation, and info

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn ¤

Logs the episode returns and lengths.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • action (Array) –

    The action to take.

Returns:

  • StepReturn

    The updated graph state, observation, reward, terminated, truncated, and info

rex.rl.LogState ¤

Attributes:

  • episode_returns (float) –

    The sum of the rewards in the episode.

  • episode_lengths (int) –

    The number of steps in the episode.

  • returned_episode_returns (float) –

    The sum of the rewards in the episode that was returned.

  • returned_episode_lengths (int) –

    The number of steps in the episode that was returned.

  • timestep (int) –

    The current timestep.
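
A sketch of reading the logged statistics after a step. The exact location of the LogState inside the graph state's aux is an implementation detail not shown in this reference, so the "log" key below is a hypothetical placeholder.

```python
import jax
import jax.numpy as jnp
from rex.rl import LogWrapper

env_log = LogWrapper(env)                          # assumes `env` is an existing BaseEnv
gs, obs, info = env_log.reset(jax.random.PRNGKey(0))
action = jnp.zeros(env_log.action_space(gs).shape)
gs, obs, reward, terminated, truncated, info = env_log.step(gs, action)
log_state = gs.aux["log"]                          # hypothetical key; see note above
print(log_state.returned_episode_returns, log_state.returned_episode_lengths)
```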


rex.rl.SquashActionWrapper ¤

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], squash: bool = True) ¤

Squashes the action space to [-1, 1] and unsquashes actions back to the original range before stepping the environment.

Parameters:

  • env (Union[BaseEnv, Environment, BaseWrapper]) –

    The environment to wrap.

  • squash (bool, default: True ) –

    Whether to squash the action space.

reset(rng: jax.Array = None) -> ResetReturn ¤

Puts the action space scaling in the aux of the graph state.

Parameters:

  • rng (Array, default: None ) –

    Random number generator.

Returns:

  • ResetReturn

    The initial graph state, observation, and info

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn ¤

Unscales the action to the original range of the action space before stepping the environment.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • action (Array) –

    The (scaled) action to take.

Returns:

  • StepReturn

    The updated graph state, observation, reward, terminated, truncated, and info

action_space(graph_state: base.GraphState) -> Box ¤

Scales the action space to [-1, 1] if squash is True.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

Returns:

  • Box

    The scaled action space
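
A usage sketch, assuming `env` is an existing BaseEnv instance: the wrapped action space is reported as [-1, 1], and actions passed to step() are mapped back to the original [low, high] range before being applied.

```python
import jax
import jax.numpy as jnp
from rex.rl import SquashActionWrapper

env_sq = SquashActionWrapper(env, squash=True)
gs, obs, info = env_sq.reset(jax.random.PRNGKey(0))
act_space = env_sq.action_space(gs)       # Box scaled to low=-1, high=1
action = jnp.zeros(act_space.shape)       # any value in [-1, 1]
gs, obs, reward, terminated, truncated, info = env_sq.step(gs, action)
```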

rex.rl.SquashState ¤

Attributes:

  • low (Array) –

    The lower bound of the action space.

  • high (Array) –

    The upper bound of the action space.

  • squash (bool) –

    Whether to squash the action space.

action_space: Box property ¤

Returns:

  • Box

    The scaled action space.

scale(x: jax.Array) -> jax.Array ¤

Scales the input from the original range [low, high] to [-1, 1].

Parameters:

  • x (Array) –

    The input to scale.

Returns:

  • Array

    The scaled input.

unsquash(x: jax.Array) -> jax.Array ¤

If squash is True, squashes x to [-1, 1] and then unscales it to the original range [low, high]. Otherwise, x is clipped to the range of the action space.

Parameters:

  • x (Array) –

    The input to unscale.

Returns:

  • Array

    Unscaled input.
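
The exact mapping SquashState implements is not spelled out in this reference. For orientation, the conventional affine mapping between [low, high] and [-1, 1] looks like the sketch below; this is an assumption for illustration, not necessarily the actual implementation (which may, for instance, apply an additional squashing nonlinearity).

```python
import jax.numpy as jnp

def scale(x, low, high):
    # map [low, high] -> [-1, 1] (assumed conventional affine mapping)
    return 2.0 * (x - low) / (high - low) - 1.0

def unscale(x, low, high):
    # map [-1, 1] -> [low, high] (inverse of the mapping above)
    return low + 0.5 * (x + 1.0) * (high - low)

low, high = jnp.array([-2.0]), jnp.array([2.0])
assert jnp.allclose(unscale(scale(jnp.array([1.5]), low, high), low, high), 1.5)
```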


rex.rl.ClipActionWrapper ¤

Bases: BaseWrapper

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn ¤

Clips the action to the action space before stepping the environment.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • action (Array) –

    The action to take.

Returns:

  • StepReturn

    The updated graph state, observation, reward, terminated, truncated, and info
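
The effect is equivalent to limiting the action to the Box bounds, as in the illustrative snippet below (not the wrapper's actual code):

```python
import jax.numpy as jnp

low, high = jnp.array([-1.0]), jnp.array([1.0])
action = jnp.array([3.0])                  # out-of-range action
clipped = jnp.clip(action, low, high)      # -> [1.0]
```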


rex.rl.VecEnvWrapper ¤

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], in_axes: Union[int, None, Sequence[Any]] = 0) ¤

Vectorizes the environment.

Parameters:

  • env (Union[BaseEnv, Environment, BaseWrapper]) –

    The environment to wrap.

  • in_axes (Union[int, None, Sequence[Any]], default: 0 ) –

    The axes to map over.
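
A vectorization sketch, assuming the wrapper vmaps reset and step over the leading axis (in_axes=0) and that `env` is an existing BaseEnv instance: a batch of rngs produces a batch of graph states, observations, and so on, and actions must carry the same batch dimension.

```python
import jax
from rex.rl import VecEnvWrapper

num_envs = 8
env_vec = VecEnvWrapper(env)
rngs = jax.random.split(jax.random.PRNGKey(0), num_envs)
gs, obs, info = env_vec.reset(rngs)   # leading dimension of size num_envs
# gs, obs, reward, terminated, truncated, info = env_vec.step(gs, actions)
```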


rex.rl.NormalizeVecObservationWrapper ¤

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], clip_obs: float = 10.0) ¤

Normalize the observations to have zero mean and unit variance.

Parameters:

  • env (Union[BaseEnv, Environment, BaseWrapper]) –

    The environment to wrap.

  • clip_obs (float, default: 10.0 ) –

    The clipping value.

reset(rng: jax.Array = None) -> ResetReturn ¤

Places the normalization state in the aux of the graph state.

Parameters:

  • rng (Array, default: None ) –

    Random number generator.

Returns:

  • ResetReturn

    The initial graph state, observation, and info

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn ¤

Normalize the observations to have zero mean and unit variance before returning them.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • action (Array) –

    The action to take.

Returns:

  • StepReturn

    The updated graph state, observation, reward, terminated, truncated, and info

rex.rl.NormalizeVecReward ¤

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], gamma: Union[float, jax.typing.ArrayLike], clip_reward: float = 10.0) ¤

Normalize the rewards to have zero mean and unit variance.

Parameters:

  • env (Union[BaseEnv, Environment, BaseWrapper]) –

    The environment to wrap.

  • gamma (Union[float, ArrayLike]) –

    The discount factor.

  • clip_reward (float, default: 10.0 ) –

    The clipping value.

reset(rng: jax.Array = None) -> ResetReturn ¤

Places the normalization state in the aux of the graph state.

Parameters:

  • rng (Array, default: None ) –

    Random number generator.

Returns:

  • ResetReturn

    The initial graph state, observation, and info

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn ¤

Normalize the rewards to have zero mean and unit variance before returning them.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • action (Array) –

    The action to take.

Returns:

  • StepReturn

    The updated graph state, observation, reward, terminated, truncated, and info
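
A typical training-time wrapper stack, sketched under the assumption that the environment is vectorized first and the normalization wrappers (whose running statistics live in the aux of the graph state) are applied on top. `env` is assumed to be an existing BaseEnv instance.

```python
import jax
from rex.rl import VecEnvWrapper, NormalizeVecObservationWrapper, NormalizeVecReward

env_train = VecEnvWrapper(env)
env_train = NormalizeVecObservationWrapper(env_train, clip_obs=10.0)
env_train = NormalizeVecReward(env_train, gamma=0.99, clip_reward=10.0)

rngs = jax.random.split(jax.random.PRNGKey(0), 8)
gs, obs, info = env_train.reset(rngs)   # observations are already normalized
```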

rex.rl.NormalizeVec ¤

Attributes:

  • mean –

    The mean of the observations.

  • var –

    The variance of the observations.

  • count –

    The number of observations.

  • return_val –

    The return value.

  • clip –

    The clipping value.

normalize(x: jax.Array, clip: bool = True, subtract_mean: bool = True) -> jax.Array ¤

Normalize x to have zero mean and unit variance.

Parameters:

  • x (Array) –

    The input to normalize.

  • clip (bool, default: True ) –

    Whether to clip the input.

  • subtract_mean (bool, default: True ) –

    Whether to subtract the mean.

Returns:

  • Array

    The normalized input.

denormalize(x: jax.Array, add_mean: bool = True) -> jax.Array ¤

Denormalize x to have the original mean and variance.

Parameters:

  • x (Array) –

    The input to denormalize.

  • add_mean (bool, default: True ) –

    Whether to add the mean.

Returns:

  • Array

    The denormalized input.
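
A minimal sketch of what normalize/denormalize conventionally compute with running statistics; this is an assumption for illustration, not the actual NormalizeVec code (the epsilon and clipping details may differ).

```python
import jax.numpy as jnp

def normalize(x, mean, var, clip_val=10.0, eps=1e-8):
    # subtract the running mean, divide by the running std, then clip
    z = (x - mean) / jnp.sqrt(var + eps)
    return jnp.clip(z, -clip_val, clip_val)

def denormalize(z, mean, var, eps=1e-8):
    # invert the normalization: rescale by the running std and add the mean back
    return z * jnp.sqrt(var + eps) + mean
```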