Environment and Wrappers

rex.rl.BaseEnv

graph = graph (instance attribute)

max_steps: Union[int, jax.typing.ArrayLike] (property)
The maximum number of steps in the environment.
By default, this is the maximum number of steps the supervisor (i.e. the agent) is stepped in the provided computation graph. You can override this property with a custom value (smaller than the default). This value is used as the episode length when evaluating the environment during training.
observation_space(graph_state: base.GraphState) -> Box

Returns the observation space.

Parameters:
- graph_state (GraphState) – The graph state.

Returns:
- Box – The observation space.
action_space(graph_state: base.GraphState) -> Box

Returns the action space.

Parameters:
- graph_state (GraphState) – The graph state.

Returns:
- Box – The action space.
reset(rng: jax.Array = None) -> ResetReturn

Reset the environment.

Parameters:
- rng (Array, default: None) – Random number generator. Used to initialize a new graph state.

Returns:
- ResetReturn – The initial graph state, observation, and info.
step(graph_state: base.GraphState, action: jax.Array) -> StepReturn

Step the environment.

Parameters:
- graph_state (GraphState) – The current graph state.
- action (Array) – The action to take.

Returns:
- StepReturn – The updated graph state, observation, reward, terminated, truncated, and info.
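Example: a minimal rollout sketch using the reset/step interface above. MyEnv and policy_fn are hypothetical placeholders for a concrete BaseEnv subclass and a policy; the plain Python loop is for illustration and would need jax.lax.scan (or similar) to be jit-compiled.

```python
import jax

env = MyEnv(graph)                       # hypothetical BaseEnv subclass
rng = jax.random.PRNGKey(0)

gs, obs, info = env.reset(rng)           # ResetReturn: graph state, observation, info
total_reward = 0.0
for _ in range(int(env.max_steps)):
    action = policy_fn(obs)              # hypothetical policy: observation -> action
    gs, obs, reward, terminated, truncated, info = env.step(gs, action)  # StepReturn
    total_reward += reward
    # Without AutoResetWrapper, stop when the episode ends.
    if bool(terminated) or bool(truncated):
        break
```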
rex.rl.BaseWrapper

Bases: object

Base class for wrappers.

__init__(env: Union[BaseEnv, Environment, BaseWrapper])

Initialize the wrapper.

Parameters:
- env (Union[BaseEnv, Environment, BaseWrapper]) – The environment to wrap.

__getattr__(name: str) -> Any

Proxy access to regular attributes of the wrapped object.

Parameters:
- name (str) – The name of the attribute.

Returns:
- Any – The attribute of the wrapped object.
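Example: a sketch of a custom wrapper built on BaseWrapper. The reward scaling is purely illustrative, and the assumption that the wrapped environment is reachable via self._env is not confirmed by the reference above (only that __getattr__ proxies other attribute access).

```python
from rex.rl import BaseWrapper

class ScaleRewardWrapper(BaseWrapper):
    """Hypothetical wrapper that multiplies the reward by a constant factor."""

    def __init__(self, env, factor: float = 0.1):
        super().__init__(env)
        self._factor = factor

    def step(self, graph_state, action):
        # Assumes BaseWrapper stores the wrapped environment as self._env.
        gs, obs, reward, terminated, truncated, info = self._env.step(graph_state, action)
        return gs, obs, self._factor * reward, terminated, truncated, info
```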
rex.rl.AutoResetWrapper

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], fixed_init: bool = True)

The AutoResetWrapper resets the environment in the step method when the episode is done.
When fixed_init is True, a fixed initial state is reused instead of actually resetting the environment. This is useful when you want the same initial state for every episode. In some cases, resetting the environment is expensive, so this can be used to avoid that cost.

Parameters:
- env (Union[BaseEnv, Environment, BaseWrapper]) – The environment to wrap.
- fixed_init (bool, default: True) – Whether to use a fixed initial state.
reset(rng: jax.Array = None) -> ResetReturn

Reset the environment and return the initial state.
If fixed_init is True, the initial state is stored in the aux of the graph state.

Parameters:
- rng (Array, default: None) – Random number generator.

Returns:
- ResetReturn – The initial graph state, observation, and info.
step(graph_state: base.GraphState, action: jax.Array) -> StepReturn

Step the environment and reset the state if the episode is done.

Parameters:
- graph_state (GraphState) – The current graph state.
- action (Array) – The action to take.

Returns:
- StepReturn – The updated graph state, observation, reward, terminated, truncated, and info.
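Example: wrapping an environment with AutoResetWrapper so that episodes restart inside step. MyEnv and policy_fn are hypothetical placeholders.

```python
import jax
from rex.rl import AutoResetWrapper

env = AutoResetWrapper(MyEnv(graph), fixed_init=True)  # reuse one initial state per episode

gs, obs, info = env.reset(jax.random.PRNGKey(0))
for _ in range(1000):
    gs, obs, reward, terminated, truncated, info = env.step(gs, policy_fn(obs))
    # When terminated or truncated is True, the returned graph state is already reset,
    # so the loop can keep stepping without calling env.reset again.
```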
rex.rl.LogWrapper

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper])

Log the episode returns and lengths.

Parameters:
- env (Union[BaseEnv, Environment, BaseWrapper]) – The environment to wrap.
reset(rng: jax.Array = None) -> ResetReturn

Stores the log state in the aux of the graph state.

Parameters:
- rng (Array, default: None) – Random number generator.

Returns:
- ResetReturn – The initial graph state, observation, and info.
step(graph_state: base.GraphState, action: jax.Array) -> StepReturn

Logs the episode returns and lengths.

Parameters:
- graph_state (GraphState) – The current graph state.
- action (Array) – The action to take.

Returns:
- StepReturn – The updated graph state, observation, reward, terminated, truncated, and info.
rex.rl.LogState

Attributes:
- episode_returns (float) – The sum of the rewards in the episode.
- episode_lengths (int) – The number of steps in the episode.
- returned_episode_returns (float) – The sum of the rewards in the episode that was returned.
- returned_episode_lengths (int) – The number of steps in the episode that was returned.
- timestep (int) – The current timestep.
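Example: reading the logged statistics after a step. The reset docstring above states that the log state is stored in the aux of the graph state; the exact aux key ("log" below) is an assumption, and MyEnv/policy_fn are hypothetical.

```python
import jax
from rex.rl import LogWrapper

env = LogWrapper(MyEnv(graph))
gs, obs, info = env.reset(jax.random.PRNGKey(0))
gs, obs, reward, terminated, truncated, info = env.step(gs, policy_fn(obs))

log = gs.aux["log"]  # LogState; the aux key name is an assumption
print(log.returned_episode_returns, log.returned_episode_lengths)
```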
rex.rl.SquashActionWrapper

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], squash: bool = True)

Squashes the action space to [-1, 1] and unsquashes incoming actions before they are passed to the environment.

Parameters:
- env (Union[BaseEnv, Environment, BaseWrapper]) – The environment to wrap.
- squash (bool, default: True) – Whether to squash the action space.
reset(rng: jax.Array = None) -> ResetReturn

Puts the action space scaling in the aux of the graph state.

Parameters:
- rng (Array, default: None) – Random number generator.

Returns:
- ResetReturn – The initial graph state, observation, and info.
step(graph_state: base.GraphState, action: jax.Array) -> StepReturn

Unscales the action to the original range of the action space before stepping the environment.

Parameters:
- graph_state (GraphState) – The current graph state.
- action (Array) – The (scaled) action to take.

Returns:
- StepReturn – The updated graph state, observation, reward, terminated, truncated, and info.
action_space(graph_state: base.GraphState) -> Box

Scales the action space to [-1, 1] if squash is True.

Parameters:
- graph_state (GraphState) – The graph state.

Returns:
- Box – The scaled action space.
rex.rl.SquashState

Attributes:
- low (Array) – The lower bound of the action space.
- high (Array) – The upper bound of the action space.
- squash (bool) – Whether to squash the action space.

action_space: Box (property)

Returns:
- Box – The scaled action space.
scale(x: jax.Array) -> jax.Array

Scales the input to [-1, 1] and unsquashes it.

Parameters:
- x (Array) – The input to scale.

Returns:
- Array – The scaled input.
unsquash(x: jax.Array) -> jax.Array

If squash is True, x is squashed to [-1, 1] and then unscaled to the original range [low, high]; otherwise, x is clipped to the range of the action space.

Parameters:
- x (Array) – The input to unscale.

Returns:
- Array – The unscaled input.
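Example: mapping a raw policy output to the environment's action range with unsquash. Constructing a SquashState directly from its listed attributes is an assumption; in practice the SquashActionWrapper places this state in the aux of the graph state.

```python
import jax.numpy as jnp
from rex.rl import SquashState

# Hypothetical bounds for a 2-dimensional action space.
squash_state = SquashState(low=jnp.array([-2.0, 0.0]), high=jnp.array([2.0, 1.0]), squash=True)

raw_action = jnp.array([0.5, -0.5])              # e.g. an unbounded policy output
env_action = squash_state.unsquash(raw_action)    # squashed to [-1, 1], then mapped to [low, high]
```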
rex.rl.ClipActionWrapper

Bases: BaseWrapper

step(graph_state: base.GraphState, action: jax.Array) -> StepReturn

Clips the action to the action space before stepping the environment.

Parameters:
- graph_state (GraphState) – The current graph state.
- action (Array) – The action to take.

Returns:
- StepReturn – The updated graph state, observation, reward, terminated, truncated, and info.
rex.rl.VecEnvWrapper

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], in_axes: Union[int, None, Sequence[Any]] = 0)

Vectorizes the environment.

Parameters:
- env (Union[BaseEnv, Environment, BaseWrapper]) – The environment to wrap.
- in_axes (Union[int, None, Sequence[Any]], default: 0) – The axes to map over.
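Example: a vectorized reset/step sketch. It assumes that, once wrapped, reset and step accept batched rngs, graph states, and actions along axis 0 (the default in_axes); MyEnv and the action dimension are hypothetical.

```python
import jax
import jax.numpy as jnp
from rex.rl import VecEnvWrapper

num_envs = 8
act_dim = 2                                              # placeholder action dimension
env = VecEnvWrapper(MyEnv(graph))

rngs = jax.random.split(jax.random.PRNGKey(0), num_envs)
gs, obs, info = env.reset(rngs)                           # batched ResetReturn
actions = jnp.zeros((num_envs, act_dim))                  # placeholder batch of actions
gs, obs, reward, terminated, truncated, info = env.step(gs, actions)  # batched StepReturn
```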
rex.rl.NormalizeVecObservationWrapper

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], clip_obs: float = 10.0)

Normalize the observations to have zero mean and unit variance.

Parameters:
- env (Union[BaseEnv, Environment, BaseWrapper]) – The environment to wrap.
- clip_obs (float, default: 10.0) – The clipping value.
reset(rng: jax.Array = None) -> ResetReturn

Places the normalization state in the aux of the graph state.

Parameters:
- rng (Array, default: None) – Random number generator.

Returns:
- ResetReturn – The initial graph state, observation, and info.
step(graph_state: base.GraphState, action: jax.Array) -> StepReturn

Normalize the observations to have zero mean and unit variance before returning them.

Parameters:
- graph_state (GraphState) – The current graph state.
- action (Array) – The action to take.

Returns:
- StepReturn – The updated graph state, observation, reward, terminated, truncated, and info.
rex.rl.NormalizeVecReward

Bases: BaseWrapper

__init__(env: Union[BaseEnv, Environment, BaseWrapper], gamma: Union[float, jax.typing.ArrayLike], clip_reward: float = 10.0)

Normalize the rewards to have zero mean and unit variance.

Parameters:
- env (Union[BaseEnv, Environment, BaseWrapper]) – The environment to wrap.
- gamma (Union[float, ArrayLike]) – The discount factor.
- clip_reward (float, default: 10.0) – The clipping value.
reset(rng: jax.Array = None) -> ResetReturn

Places the normalization state in the aux of the graph state.

Parameters:
- rng (Array, default: None) – Random number generator.

Returns:
- ResetReturn – The initial graph state, observation, and info.
step(graph_state: base.GraphState, action: jax.Array) -> StepReturn

Normalize the rewards to have zero mean and unit variance before returning them.

Parameters:
- graph_state (GraphState) – The current graph state.
- action (Array) – The action to take.

Returns:
- StepReturn – The updated graph state, observation, reward, terminated, truncated, and info.
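Example: a typical composition of the wrappers above into a vectorized, normalized training environment. The ordering mirrors common JAX RL pipelines and is illustrative rather than prescriptive; MyEnv is a hypothetical BaseEnv subclass.

```python
from rex.rl import (AutoResetWrapper, LogWrapper, SquashActionWrapper, VecEnvWrapper,
                    NormalizeVecObservationWrapper, NormalizeVecReward)

env = MyEnv(graph)
env = LogWrapper(env)                     # track episode returns and lengths
env = SquashActionWrapper(env)            # expose a [-1, 1] action space to the policy
env = AutoResetWrapper(env)               # restart episodes inside step
env = VecEnvWrapper(env)                  # map reset/step over a batch of environments
env = NormalizeVecObservationWrapper(env, clip_obs=10.0)
env = NormalizeVecReward(env, gamma=0.99, clip_reward=10.0)
```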
rex.rl.NormalizeVec

Attributes:
- mean – The mean of the observations.
- var – The variance of the observations.
- count – The number of observations.
- return_val – The return value.
- clip – The clipping value.
normalize(x: jax.Array, clip: bool = True, subtract_mean: bool = True) -> jax.Array

Normalize x to have zero mean and unit variance.

Parameters:
- x (Array) – The input to normalize.
- clip (bool, default: True) – Whether to clip the input.
- subtract_mean (bool, default: True) – Whether to subtract the mean.

Returns:
- Array – The normalized input.
denormalize(x: jax.Array, add_mean: bool = True) -> jax.Array

Denormalize x to have the original mean and variance.

Parameters:
- x (Array) – The input to denormalize.
- add_mean (bool, default: True) – Whether to add the mean.

Returns:
- Array – The denormalized input.
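Example: standalone use of normalize/denormalize. Constructing a NormalizeVec directly from the attribute names listed above is an assumption; in practice the normalization wrappers maintain this state in the aux of the graph state.

```python
import jax.numpy as jnp
from rex.rl import NormalizeVec

# Hypothetical running statistics for a 3-dimensional observation.
norm = NormalizeVec(mean=jnp.zeros(3), var=jnp.ones(3), count=1e-4,
                    return_val=jnp.zeros(()), clip=10.0)

obs = jnp.array([1.0, -2.0, 0.5])
obs_n = norm.normalize(obs, clip=True, subtract_mean=True)  # zero mean, unit variance, clipped
obs_back = norm.denormalize(obs_n, add_mean=True)           # back to the original scale
```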