
Base

rex.base.InputState

A ring buffer that holds the inputs for a node's input channel.

The size of the buffer is determined by the window size of the corresponding connection (i.e. node.connect(..., window=...)).

Attributes:

  • seq (ArrayLike) –

    the sequence numbers of the received messages

  • ts_sent (ArrayLike) –

    the times at which the messages were sent

  • ts_recv (ArrayLike) –

    the times at which the messages were received

  • data (Output) –

    the messages of the connection (arbitrary pytree structure)

  • delay_dist (DelayDistribution) –

    the delay distribution of the connection

__getitem__(val: int) -> InputState

Get the entry of the ring buffer at a specific index.

The index is applied to all buffered values at once (sequence numbers, timestamps, and every leaf of data), so the result is an InputState holding only the entry at that position.

Parameters:

  • val (int) –

    the index to get the value from

Returns:

  • InputState

    The input state at the specific index
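As a rough illustration, the sketch below mimics this indexing on a plain dict standing in for an InputState. The dict layout and the `get_item` helper are hypothetical, not rex API. It assumes the buffer stores the oldest message first (as from_outputs expects), so index -1 selects the most recent message.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an InputState with window=3 (oldest message first).
buffer = {
    "seq": jnp.array([0, 1, 2]),
    "ts_sent": jnp.array([0.0, 0.1, 0.2]),
    "ts_recv": jnp.array([0.05, 0.15, 0.25]),
    "data": {"x": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])},
}


def get_item(buf, val):
    """Apply the same index to every leaf, keeping the pytree structure."""
    return jax.tree_util.tree_map(lambda leaf: leaf[val], buf)


latest = get_item(buffer, -1)  # most recent message in the window
```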

push(seq: int, ts_sent: float, ts_recv: float, data: Any) -> InputState

Push a new message into the ring buffer. Because the buffer has a fixed size, the oldest message is discarded.

Parameters:

  • seq (int) –

    the sequence number of the received message

  • ts_sent (float) –

    the time the message was sent

  • ts_recv (float) –

    the time the message was received

  • data (Any) –

    the message of the connection (arbitrary pytree structure)

Returns:

  • InputState

    The new input state with the message pushed into the ring buffer.
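A minimal sketch of push semantics on a fixed-size buffer, assuming (as for from_outputs) that the oldest message sits at index 0: the oldest entry is dropped and the new message is appended at the end. The dict layout and the `push` function are hypothetical; rex's own implementation may differ in detail.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an InputState with window=3 (oldest message first).
buffer = {
    "seq": jnp.array([0, 1, 2]),
    "ts_sent": jnp.array([0.0, 0.1, 0.2]),
    "ts_recv": jnp.array([0.05, 0.15, 0.25]),
    "data": {"x": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])},
}


def push(buf, seq, ts_sent, ts_recv, data):
    """Return a new buffer with the oldest entry dropped and the new one appended."""
    new = {"seq": seq, "ts_sent": ts_sent, "ts_recv": ts_recv, "data": data}

    def shift_in(old, value):
        # Cast the new value to the buffer's dtype and add a leading axis of size 1.
        value = jnp.asarray(value, dtype=old.dtype)[None]
        return jnp.concatenate([old[1:], value], axis=0)

    return jax.tree_util.tree_map(shift_in, buf, new)


buffer = push(buffer, 3, 0.3, 0.35, {"x": jnp.array([7.0, 8.0])})
```

Because every leaf keeps its shape, this style of update stays compatible with `jax.jit` and `jax.lax.scan`.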

from_outputs(seq: ArrayLike, ts_sent: ArrayLike, ts_recv: ArrayLike, outputs: Any, delay_dist: DelayDistribution, is_data: bool = False) -> InputState (classmethod)

Create an InputState from a list of messages, timestamps, and sequence numbers.

The oldest message should be first in the list.

Parameters:

  • seq (ArrayLike) –

    the sequence numbers of the received messages

  • ts_sent (ArrayLike) –

    the times at which the messages were sent

  • ts_recv (ArrayLike) –

    the times at which the messages were received

  • outputs (Any) –

    the messages of the connection (arbitrary pytree structure)

  • delay_dist (DelayDistribution) –

    the delay distribution of the connection

  • is_data (bool, default: False ) –

    if True, the outputs are already a stacked pytree structure; otherwise, outputs is a list of messages that is stacked into one

Returns:

  • InputState

    The input state with the messages and timestamps in the ring buffer.
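The stacking step can be sketched as follows, assuming that a list of identically structured messages is combined along a new leading axis (oldest first) and that is_data=True skips this step. `from_outputs_sketch` and `stack_outputs` are hypothetical helpers that return a plain dict rather than an InputState, and delay_dist is omitted for brevity.

```python
import jax
import jax.numpy as jnp


def stack_outputs(outputs):
    """Stack a list of identically structured pytrees along a new leading axis."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *outputs)


def from_outputs_sketch(seq, ts_sent, ts_recv, outputs, is_data=False):
    """Hypothetical: build an InputState-like dict (delay_dist omitted)."""
    data = outputs if is_data else stack_outputs(outputs)
    return {
        "seq": jnp.asarray(seq),
        "ts_sent": jnp.asarray(ts_sent),
        "ts_recv": jnp.asarray(ts_recv),
        "data": data,
    }


# Three messages, oldest first.
msgs = [{"x": jnp.array([1.0, 2.0])}, {"x": jnp.array([3.0, 4.0])}, {"x": jnp.array([5.0, 6.0])}]
state = from_outputs_sketch([0, 1, 2], [0.0, 0.1, 0.2], [0.05, 0.15, 0.25], msgs)
# Same result when passing an already stacked pytree with is_data=True.
state2 = from_outputs_sketch(
    [0, 1, 2], [0.0, 0.1, 0.2], [0.05, 0.15, 0.25], stack_outputs(msgs), is_data=True
)
```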


rex.base.StepState

Step state definition.

It holds all the information that is required to step a node.

Attributes:

  • rng (Array) –

    The random number generator key, used for sampling random processes. If the key is used, it should be split and the updated key stored, so the same randomness is not reused in the next step.

  • state (State) –

    The state of the node. Usually dynamic during an episode.

  • params (Params) –

    The parameters of the node. Usually static during an episode.

  • inputs (FrozenDict[str, InputState]) –

    The inputs of the node. See InputState.

  • eps (Union[int, ArrayLike]) –

    The current episode number. Relates to the computation graph, not the episode counter of an environment.

  • seq (Union[int, ArrayLike]) –

    The current step number. Automatically increases by 1 every step.

  • ts (Union[float, ArrayLike]) –

    The time at the start of the current step, as determined by the computation graph.
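The rng field follows JAX's explicit PRNG model, in which a consumed key should not be reused. A minimal sketch of the split-and-carry pattern, with a plain array standing in for the node state (`noisy_step` is a hypothetical function, not part of rex):

```python
import jax
import jax.numpy as jnp


def noisy_step(rng, state):
    """Hypothetical step: add Gaussian noise and return the updated key."""
    # Split the key: `sub` is consumed now, `rng` is carried to the next step.
    rng, sub = jax.random.split(rng)
    noise = jax.random.normal(sub, shape=state.shape)
    return rng, state + 0.1 * noise


rng0 = jax.random.PRNGKey(0)
rng1, s1 = noisy_step(rng0, jnp.zeros(3))
rng2, s2 = noisy_step(rng1, jnp.zeros(3))
```

The returned key is what would be stored back in the step state, so the next step draws fresh noise while the whole run stays reproducible from the initial key.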


rex.base.GraphState

Graph state definition.

It holds all the information that is required to step a graph.

Attributes:

  • step (Union[int, ArrayLike]) –

    The current step number. Automatically increases by 1 every step.

  • eps (Union[int, ArrayLike]) –

    The current episode number. To update the episode, use GraphState.replace_eps.

  • rng (FrozenDict[str, Array]) –

    The random number generators for each node in the graph.

  • seq (FrozenDict[str, Union[int, ArrayLike]]) –

    The current step number for each node in the graph.

  • ts (FrozenDict[str, Union[float, ArrayLike]]) –

    The start time of the step for each node in the graph.

  • params (FrozenDict[str, Params]) –

    The parameters for each node in the graph.

  • state (FrozenDict[str, State]) –

    The state for each node in the graph.

  • inputs (FrozenDict[str, FrozenDict[str, InputState]]) –

    The inputs for each node in the graph.

  • timings_eps (Timings) –

    The timings data structure that describes the execution and partitioning of the graph.

  • buffer (FrozenDict[str, Output]) –

    The output buffer of the graph, which holds the outputs of nodes during execution. Input buffers are filled automatically with the outputs of previously executed step calls of other nodes.

  • aux (FrozenDict[str, Any]) –

    Auxiliary data that can be used to store additional information (e.g., records or wrappers).


rex.base.Base

Base functionality shared by all dataclasses. These methods let dataclasses be operated on like arrays (element-wise arithmetic, indexing, reshaping, slicing, and concatenation).

Note: credit to the authors of the brax library for this implementation.

Tip

Use this base class for all state, output, and param pytrees.

__str__()

Return a string representation of the dataclass.

__add__(o: Any) -> Any

Element-wise addition of two pytrees.

Parameters:

  • o (Any) –

    The other pytree to add.

Returns:

  • Any

    The resulting pytree after applying the element-wise operation.

__sub__(o: Any) -> Any

Element-wise subtraction of two pytrees.

Parameters:

  • o (Any) –

    The other pytree to subtract.

Returns:

  • Any

    The resulting pytree after applying the element-wise operation.

__mul__(o: Any) -> Any

Element-wise multiplication of two pytrees.

Parameters:

  • o (Any) –

    The other pytree to multiply by.

Returns:

  • Any

    The resulting pytree after applying the element-wise operation.

__neg__() -> Any

Element-wise negation of the pytree.

Returns:

  • Any

    The resulting pytree after applying the element-wise operation.

__truediv__(o: Any) -> Any

Element-wise division of two pytrees.

Parameters:

  • o (Any) –

    The other pytree to divide by.
Returns:

  • Any

    The resulting pytree after applying the element-wise operation.
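Operators like these can be implemented by mapping a unary or binary function over the leaves with `jax.tree_util.tree_map`. The sketch below is an assumption-based stand-in, not rex's source: `MiniBase` and `Point` are hypothetical names, and every binary operator here expects `o` to be a pytree with the same structure as `self`.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class MiniBase:
    """Hypothetical stand-in for rex.base.Base (arithmetic only)."""

    def __add__(self, o):
        return jax.tree_util.tree_map(lambda x, y: x + y, self, o)

    def __sub__(self, o):
        return jax.tree_util.tree_map(lambda x, y: x - y, self, o)

    def __mul__(self, o):
        return jax.tree_util.tree_map(lambda x, y: x * y, self, o)

    def __neg__(self):
        return jax.tree_util.tree_map(lambda x: -x, self)

    def __truediv__(self, o):
        return jax.tree_util.tree_map(lambda x, y: x / y, self, o)


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Point(MiniBase):
    x: jax.Array
    y: jax.Array

    # Pytree registration so tree_map can see the two array leaves.
    def tree_flatten(self):
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


p = Point(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
q = Point(jnp.array([1.0, 1.0]), jnp.array([2.0, 2.0]))
```

Since `tree_unflatten` rebuilds the original class, `p + q` is again a `Point`.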

__getitem__(val: int) -> Any

Index every leaf of the dataclass with the same index.

Parameters:

  • val (int) –

    The index to apply to every leaf.

Returns:

  • Any

    A dataclass of the same type holding the indexed leaves.

replace(*args: Any, **kwargs: Any) -> Any

Replace fields in the dataclass.

Parameters:

  • *args (Any, default: () ) –

    The fields to replace.

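  • **kwargs (Any, default: {} ) –

    The fields to replace, given as field_name=new_value.

Returns:

  • Any

    A copy of the dataclass with the given fields replaced.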
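Assuming replace behaves like the standard library's `dataclasses.replace`, it returns a new object and leaves the original untouched, which suits immutable pytrees. A minimal sketch with a hypothetical `Params` class:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Params:  # hypothetical parameter dataclass
    gain: float
    offset: float


p = Params(gain=1.0, offset=0.0)
p2 = dataclasses.replace(p, gain=2.0)  # new object with `gain` replaced
```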

reshape(shape: Sequence[int]) -> Any

Reshape the dataclass.

Parameters:

  • shape (Sequence[int]) –

    The shape to reshape the dataclass to.

Returns:

  • Any

    The reshaped dataclass.
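Assuming reshape applies the same target shape to every leaf (so all leaves must hold the same number of elements), its effect can be sketched with `tree_map`. The `reshape` helper and the example pytree are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree: two leaves with the same number of elements (6).
tree = {"a": jnp.arange(6.0).reshape(2, 3), "b": jnp.ones((3, 2))}


def reshape(t, shape):
    """Reshape every leaf to `shape`; each leaf must have prod(shape) elements."""
    return jax.tree_util.tree_map(lambda leaf: jnp.reshape(leaf, shape), t)


flat = reshape(tree, (6,))
batched = reshape(tree, (3, 2))
```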

select(o: Any, cond: jax.Array) -> Any

Select elements from this pytree or from another pytree based on a condition.

Parameters:

  • o (Any) –

    The other pytree to select elements from.

  • cond (Array) –

    The condition that determines which pytree each element is taken from.

Returns:

  • Any

    The resulting pytree after applying the condition.
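A minimal sketch of a conditional select over two pytrees with identical structure. It assumes that a true entry in `cond` picks the element from the calling pytree, a false entry picks it from `o`, and `cond` indexes the leading (batch) axis; both conventions are assumptions to check against rex. The `select` helper is hypothetical.

```python
import jax
import jax.numpy as jnp


def select(a, b, cond):
    """Hypothetical: take leaves from `a` where cond is True, else from `b`."""

    def pick(x, y):
        # Append singleton axes so cond broadcasts along the leading (batch) axes.
        c = jnp.reshape(cond, cond.shape + (1,) * (x.ndim - cond.ndim))
        return jnp.where(c, x, y)

    return jax.tree_util.tree_map(pick, a, b)


a = {"pos": jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]), "vel": jnp.array([0.0, 1.0, 2.0])}
b = jax.tree_util.tree_map(lambda x: x + 10.0, a)
out = select(a, b, jnp.array([True, False, True]))
```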

slice(beg: int, end: int) -> Any

Slice the dataclass.

Parameters:

  • beg (int) –

    The beginning of the slice.

  • end (int) –

    The end of the slice.

Returns:

  • Any

    The sliced dataclass.
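`__getitem__` and `slice` can both be read as ordinary array indexing applied to every leaf. A sketch under that assumption, with hypothetical helper names and assuming the usual Python convention that `end` is exclusive:

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree with a shared leading axis of length 5.
tree = {"x": jnp.arange(5), "y": jnp.arange(10).reshape(5, 2)}


def get_item(t, val):
    # Same index applied to every leaf.
    return jax.tree_util.tree_map(lambda leaf: leaf[val], t)


def slice_tree(t, beg, end):
    # Python-style slice on the leading axis: `end` is exclusive.
    return jax.tree_util.tree_map(lambda leaf: leaf[beg:end], t)


row = get_item(tree, 1)
part = slice_tree(tree, 1, 3)
```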

take(i: int, axis: int = 0) -> Any

Take elements from the dataclass along an axis.

Parameters:

  • i (int) –

    The index of the elements to take.

  • axis (int, default: 0 ) –

    The axis to take the elements from.

Returns:

  • Any

    The dataclass containing the taken elements.
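Assuming take forwards to `jnp.take` on every leaf, the sketch below gathers one or several rows; unlike a slice, the indices need not be contiguous. The `take` helper is hypothetical, and out-of-bounds handling (`jnp.take`'s `mode` argument) is left at JAX's default.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree with a shared leading axis of length 5.
tree = {"x": jnp.arange(5), "y": jnp.arange(10).reshape(5, 2)}


def take(t, i, axis=0):
    # Gather the given index (or array of indices) along `axis` for every leaf.
    return jax.tree_util.tree_map(lambda leaf: jnp.take(leaf, i, axis=axis), t)


picked = take(tree, jnp.array([0, 4]))
single = take(tree, 2)
```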

concatenate(*others: Any, axis: int = 0) -> Any

Concatenate the dataclass with other dataclasses.

Parameters:

  • *others (Any, default: () ) –

    The other dataclasses to concatenate.

  • axis (int, default: 0 ) –

    The axis to concatenate the dataclasses on.

Returns:

  • Any

    The concatenated dataclass.
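A sketch of concatenation, assuming matching leaves of all pytrees are joined with `jnp.concatenate` along `axis` (all pytrees must share one structure, and leaves must agree on every other axis). The `concatenate` helper is hypothetical.

```python
import jax
import jax.numpy as jnp

a = {"x": jnp.array([1.0, 2.0]), "y": jnp.zeros((2, 3))}
b = {"x": jnp.array([3.0]), "y": jnp.ones((1, 3))}


def concatenate(first, *others, axis=0):
    # Concatenate matching leaves of all pytrees along `axis`.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.concatenate(leaves, axis=axis), first, *others
    )


c = concatenate(a, b)
```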

index_set(idx: Union[jax.Array, Sequence[jax.Array]], o: Any) -> Any

Set elements in the dataclass at an index.

Parameters:

  • idx (Union[Array, Sequence[Array]]) –

    The index at which to set the elements.

  • o (Any) –

    The elements to set.

Returns:

  • Any

    The dataclass with the elements set at idx.

index_sum(idx: Union[jax.Array, Sequence[jax.Array]], o: Any) -> Any

Add elements to the dataclass at an index.

Parameters:

  • idx (Union[Array, Sequence[Array]]) –

    The index at which to add the elements.

  • o (Any) –

    The elements to sum.

Returns:

  • Any

    The dataclass with the summed elements.
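Assuming index_set and index_sum map JAX's functional `.at[...]` updates over the leaves, the sketch below contrasts them: set overwrites, while sum adds, so repeated indices accumulate. Helper names are hypothetical, and `o` is assumed to share the dataclass's pytree structure.

```python
import jax
import jax.numpy as jnp

tree = {"x": jnp.zeros(4)}


def index_set(t, idx, o):
    # Functional update: new arrays with leaf[idx] = val; `t` is left unchanged.
    return jax.tree_util.tree_map(lambda leaf, val: leaf.at[idx].set(val), t, o)


def index_sum(t, idx, o):
    # Like index_set, but adds; repeated indices accumulate.
    return jax.tree_util.tree_map(lambda leaf, val: leaf.at[idx].add(val), t, o)


set_result = index_set(tree, jnp.array([0, 2]), {"x": jnp.array([1.0, 2.0])})
sum_result = index_sum(tree, jnp.array([1, 1]), {"x": jnp.array([1.0, 2.0])})
```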