Base
rex.base.InputState
A ring buffer that holds the inputs for a node's input channel.
The size of the buffer is determined by the window size of the corresponding connection (i.e. node.connect(..., window=...)).
Attributes:
- seq (ArrayLike) – the sequence numbers of the received messages
- ts_sent (ArrayLike) – the times at which the messages were sent
- ts_recv (ArrayLike) – the times at which the messages were received
- data (Output) – the messages of the connection (arbitrary pytree structure)
- delay_dist (DelayDistribution) – the delay distribution of the connection
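To make the layout concrete, here is a minimal pure-Python sketch. It assumes a window of 3, entries stored oldest first (the order `from_outputs` expects for its inputs), and `data` held as a stacked pytree, here a dict whose leaves are tuples with one slot per window entry. `InputStateSketch` and its `latest` helper are hypothetical names, and `delay_dist` is omitted; the real class holds JAX arrays.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class InputStateSketch:
    """Hypothetical pure-Python stand-in for rex.base.InputState (no JAX)."""

    seq: Tuple[int, ...]              # sequence numbers, oldest first (assumed)
    ts_sent: Tuple[float, ...]        # send times, oldest first
    ts_recv: Tuple[float, ...]        # receive times, oldest first
    data: Dict[str, Tuple[Any, ...]]  # stacked pytree: one slot per entry in each leaf

    @property
    def window(self) -> int:
        # The number of slots equals the window size of the connection.
        return len(self.seq)

    def latest(self) -> Dict[str, Any]:
        # Hypothetical helper: with oldest-first ordering, the newest entry is last.
        return {k: v[-1] for k, v in self.data.items()}


buf = InputStateSketch(
    seq=(4, 5, 6),
    ts_sent=(0.40, 0.50, 0.60),
    ts_recv=(0.42, 0.53, 0.61),
    data={"x": (1.0, 2.0, 3.0)},
)
print(buf.window, buf.latest())  # 3 {'x': 3.0}
```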
__getitem__(val: int) -> InputState
Get the value of the ring buffer at a specific index.
The same index is applied to all values held in the ring buffer, which is useful for extracting a single entry.
Parameters:
- val (int) – the index to get the value from
Returns:
- InputState – the input state at the specified index
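A minimal pure-Python sketch of this indexing, assuming every buffered field (and every leaf of `data`) has one slot per window entry. `InputStateSketch` is a hypothetical stand-in, `delay_dist` is omitted, and the real class indexes JAX arrays rather than tuples.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class InputStateSketch:
    """Hypothetical stand-in; every field has one slot per window entry."""

    seq: Any
    ts_sent: Any
    ts_recv: Any
    data: Dict[str, Any]

    def __getitem__(self, val: int) -> "InputStateSketch":
        # Apply the same index to every field and to every leaf of `data`.
        return InputStateSketch(
            seq=self.seq[val],
            ts_sent=self.ts_sent[val],
            ts_recv=self.ts_recv[val],
            data={k: v[val] for k, v in self.data.items()},
        )


buf = InputStateSketch(
    seq=(4, 5, 6),
    ts_sent=(0.40, 0.50, 0.60),
    ts_recv=(0.42, 0.53, 0.61),
    data={"x": (1.0, 2.0, 3.0)},
)
newest = buf[-1]  # assuming oldest-first order, -1 is the most recent entry
print(newest.seq, newest.data)  # 6 {'x': 3.0}
```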
push(seq: int, ts_sent: float, ts_recv: float, data: Any) -> InputState
Push a new message into the ring buffer.
Parameters:
- seq (int) – the sequence number of the received message
- ts_sent (float) – the time the message was sent
- ts_recv (float) – the time the message was received
- data (Any) – the message of the connection (arbitrary pytree structure)
Returns:
- InputState – The new input state with the message pushed into the ring buffer.
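The pure-Python sketch below illustrates one plausible reading of `push`, under stated assumptions: the buffer is ordered oldest first, a push drops the oldest slot and appends the new message at the end, and the update is functional (a new object is returned, as the `-> InputState` return type suggests). `InputStateSketch` is hypothetical, and rex itself operates on JAX arrays.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class InputStateSketch:
    """Hypothetical pure-Python stand-in for rex.base.InputState."""

    seq: Tuple[int, ...]
    ts_sent: Tuple[float, ...]
    ts_recv: Tuple[float, ...]
    data: Dict[str, Tuple[Any, ...]]

    def push(
        self, seq: int, ts_sent: float, ts_recv: float, data: Dict[str, Any]
    ) -> "InputStateSketch":
        # Drop the oldest slot (index 0) and append the new message at the end,
        # so the number of slots (the window size) stays constant.
        return InputStateSketch(
            seq=self.seq[1:] + (seq,),
            ts_sent=self.ts_sent[1:] + (ts_sent,),
            ts_recv=self.ts_recv[1:] + (ts_recv,),
            data={k: v[1:] + (data[k],) for k, v in self.data.items()},
        )


buf = InputStateSketch(
    seq=(4, 5, 6),
    ts_sent=(0.40, 0.50, 0.60),
    ts_recv=(0.42, 0.53, 0.61),
    data={"x": (1.0, 2.0, 3.0)},
)
new_buf = buf.push(seq=7, ts_sent=0.70, ts_recv=0.72, data={"x": 4.0})
print(new_buf.seq, new_buf.data)  # (5, 6, 7) {'x': (2.0, 3.0, 4.0)}
print(buf.seq)  # (4, 5, 6): the original buffer is unchanged
```

Returning a new object instead of mutating in place matches JAX's functional style, where arrays are immutable.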
from_outputs(seq: ArrayLike, ts_sent: ArrayLike, ts_recv: ArrayLike, outputs: Any, delay_dist: DelayDistribution, is_data: bool = False) -> InputState
classmethod
Create an InputState from a list of messages, timestamps, and sequence numbers.
The oldest message should be first in the list.
Parameters:
- seq (ArrayLike) – the sequence numbers of the received messages
- ts_sent (ArrayLike) – the times at which the messages were sent
- ts_recv (ArrayLike) – the times at which the messages were received
- outputs (Any) – the messages of the connection (arbitrary pytree structure)
- delay_dist (DelayDistribution) – the delay distribution of the connection
- is_data (bool, default: False) – if True, the outputs are already a stacked pytree structure
Returns:
- InputState – The input state with the messages and timestamps in the ring buffer.
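A pure-Python sketch of the `is_data` switch, assuming that `is_data=False` means `outputs` is a list of per-message pytrees (oldest first) to be stacked leaf by leaf, while `is_data=True` means the leaves are already stacked. `stack_outputs` and `from_outputs_sketch` are hypothetical names, pytrees are flat dicts, and `delay_dist` is left out.

```python
from typing import Any, Dict, Sequence, Tuple


def stack_outputs(outputs: Sequence[Dict[str, Any]]) -> Dict[str, Tuple[Any, ...]]:
    # Turn a list of per-message dicts (oldest first) into one dict of tuples.
    return {k: tuple(o[k] for o in outputs) for k in outputs[0]}


def from_outputs_sketch(
    seq: Sequence[int],
    ts_sent: Sequence[float],
    ts_recv: Sequence[float],
    outputs: Any,
    is_data: bool = False,
) -> Dict[str, Any]:
    # With is_data=True, `outputs` is assumed to be stacked already.
    data = outputs if is_data else stack_outputs(outputs)
    return {
        "seq": tuple(seq),
        "ts_sent": tuple(ts_sent),
        "ts_recv": tuple(ts_recv),
        "data": data,
    }


msgs = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}]  # oldest message first
a = from_outputs_sketch([4, 5, 6], [0.4, 0.5, 0.6], [0.42, 0.53, 0.61], msgs)
b = from_outputs_sketch(
    [4, 5, 6], [0.4, 0.5, 0.6], [0.42, 0.53, 0.61], {"x": (1.0, 2.0, 3.0)}, is_data=True
)
print(a == b)  # True
```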
rex.base.StepState
Step state definition.
It holds all the information that is required to step a node.
Attributes:
- rng (Array) – The random number generator key, used for sampling random processes. If it is used, it should be updated (e.g., split) so that the same key is not reused in the next step.
- state (State) – The state of the node. Usually dynamic during an episode.
- params (Params) – The parameters of the node. Usually static during an episode.
- inputs (FrozenDict[str, InputState]) – The inputs of the node. See InputState.
- eps (Union[int, ArrayLike]) – The current episode number. Refers to the episode of the computation graph, not to an environment's episode counter.
- seq (Union[int, ArrayLike]) – The current step number. Automatically increases by 1 every step.
- ts (Union[float, ArrayLike]) – The current time at the start of the step. Determined by the computation graph.
rex.base.GraphState
Graph state definition.
It holds all the information that is required to step a graph.
Attributes:
- step (Union[int, ArrayLike]) – The current step number. Automatically increases by 1 every step.
- eps (Union[int, ArrayLike]) – The current episode number. To update the episode, use GraphState.replace_eps.
- rng (FrozenDict[str, Array]) – The random number generators for each node in the graph.
- seq (FrozenDict[str, Union[int, ArrayLike]]) – The current step number for each node in the graph.
- ts (FrozenDict[str, Union[float, ArrayLike]]) – The start time of the step for each node in the graph.
- params (FrozenDict[str, Params]) – The parameters for each node in the graph.
- state (FrozenDict[str, State]) – The state for each node in the graph.
- inputs (FrozenDict[str, FrozenDict[str, InputState]]) – The inputs for each node in the graph.
- timings_eps (Timings) – The timings data structure that describes the execution and partitioning of the graph.
- buffer (FrozenDict[str, Output]) – The output buffer of the graph. It holds the outputs of nodes during execution. Input buffers are automatically filled with the outputs of previously executed step calls of other nodes.
- aux (FrozenDict[str, Any]) – Auxiliary data that can be used to store additional information (e.g. records, wrappers etc.).
rex.base.Base
Base class that extends dataclasses with array-like functionality. Its methods allow dataclass pytrees to be operated on like arrays/matrices.
Note: Credits to the authors of the brax library for this implementation.
Tip: Use this base class for all state, output, and param pytrees.
__str__()
Return a string representation of the dataclass.
__add__(o: Any) -> Any
Element-wise addition of two pytrees.
Parameters:
- o (Any) – The other pytree to add.
Returns:
- Any – The resulting pytree after applying the element-wise operation.
__sub__(o: Any) -> Any
Element-wise subtraction of two pytrees.
Parameters:
- o (Any) – The other pytree to subtract.
Returns:
- Any – The resulting pytree after applying the element-wise operation.
__mul__(o: Any) -> Any
Element-wise multiplication of two pytrees.
Parameters:
- o (Any) – The other pytree to multiply by.
Returns:
- Any – The resulting pytree after applying the element-wise operation.
__neg__() -> Any
Element-wise negation of the pytree.
Returns:
- Any – The resulting pytree after applying the element-wise operation.
__truediv__(o: Any) -> Any
Element-wise division of two pytrees.
Parameters:
- o (Any) – The other pytree to divide by.
Returns:
- Any – The resulting pytree after applying the element-wise operation.
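In brax, these operators map a binary (or unary) function over the leaves of pytrees with the same structure. The pure-Python sketch below mimics that pattern with a small recursive `_tree_map` over dataclass fields; `_tree_map` and `PairSketch` are hypothetical names, and the real implementation operates on JAX arrays rather than floats.

```python
import dataclasses
import operator
from dataclasses import dataclass
from typing import Any, Callable


def _tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    # Recurse into dataclass fields; apply `fn` to the matching leaves.
    if dataclasses.is_dataclass(tree):
        return type(tree)(**{
            f.name: _tree_map(fn, getattr(tree, f.name), *(getattr(r, f.name) for r in rest))
            for f in dataclasses.fields(tree)
        })
    return fn(tree, *rest)


@dataclass(frozen=True)
class PairSketch:
    """Hypothetical Base-like dataclass with float leaves."""

    pos: float
    vel: float

    def __add__(self, o: Any) -> Any:
        return _tree_map(operator.add, self, o)

    def __sub__(self, o: Any) -> Any:
        return _tree_map(operator.sub, self, o)

    def __mul__(self, o: Any) -> Any:
        return _tree_map(operator.mul, self, o)

    def __truediv__(self, o: Any) -> Any:
        return _tree_map(operator.truediv, self, o)

    def __neg__(self) -> Any:
        return _tree_map(operator.neg, self)


a = PairSketch(pos=1.0, vel=2.0)
b = PairSketch(pos=0.5, vel=4.0)
print(a + b)  # PairSketch(pos=1.5, vel=6.0)
print(-a)     # PairSketch(pos=-1.0, vel=-2.0)
```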
__getitem__(val: int) -> Any
Get a specific value from the dataclass.
Parameters:
- val (int) – The index of the value to get from the dataclass.
Returns:
- Any – The value of the dataclass at that index.
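Assuming `__getitem__` follows brax and applies the same index to every leaf, indexing a batched dataclass yields one element of the batch. The pure-Python sketch below uses tuple leaves; `BatchSketch` is a hypothetical name.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class BatchSketch:
    """Hypothetical Base-like dataclass; each leaf has a leading batch axis."""

    pos: Any
    vel: Any

    def __getitem__(self, val: Any) -> "BatchSketch":
        # Apply the same index to every leaf.
        return dataclasses.replace(
            self, **{f.name: getattr(self, f.name)[val] for f in dataclasses.fields(self)}
        )


batch = BatchSketch(pos=(0.0, 1.0, 2.0), vel=(3.0, 4.0, 5.0))
print(batch[1])  # BatchSketch(pos=1.0, vel=4.0)
```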
replace(*args: Any, **kwargs: Any) -> Any
Replace fields in the dataclass.
Parameters:
- *args (Any, default: ()) – The fields to replace.
- **kwargs (Any, default: {}) – The fields to replace.
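Assuming `replace` behaves like the standard library's `dataclasses.replace` (as in brax), it returns an updated copy instead of mutating the instance, which is how updates are expressed on frozen, immutable pytrees. `StateSketch` is a hypothetical frozen dataclass used for illustration.

```python
import dataclasses
from dataclasses import dataclass


@dataclass(frozen=True)
class StateSketch:
    """Hypothetical frozen dataclass standing in for a Base-derived state."""

    pos: float
    vel: float

    def replace(self, **kwargs) -> "StateSketch":
        # Return a copy with the given fields swapped; the original is untouched.
        return dataclasses.replace(self, **kwargs)


s0 = StateSketch(pos=1.0, vel=2.0)
s1 = s0.replace(vel=0.0)
print(s0, s1)  # StateSketch(pos=1.0, vel=2.0) StateSketch(pos=1.0, vel=0.0)

try:
    s0.vel = 0.0  # frozen dataclasses reject in-place assignment
except dataclasses.FrozenInstanceError:
    print("in-place update rejected")
```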
reshape(shape: Sequence[int]) -> Any
Reshape the dataclass.
Parameters:
- shape (Sequence[int]) – The target shape.
Returns:
- Any – The reshaped dataclass.
select(o: Any, cond: jax.Array) -> Any
Select elements from two pytrees based on a condition.
Parameters:
- o (Any) – The other pytree to select elements from.
- cond (Array) – The condition to select elements based on.
Returns:
- Any – The resulting pytree after applying the condition.
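In brax, `select` keeps the leaves of `self` where `cond` is true and takes those of `o` elsewhere; the pure-Python sketch below assumes the same convention for rex, which is worth verifying. It selects per batch element on tuple leaves; `BatchSketch` is a hypothetical name. This kind of selection is handy, for example, for swapping in reset states only for the batch entries that finished.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass(frozen=True)
class BatchSketch:
    """Hypothetical Base-like dataclass; each leaf has a leading batch axis."""

    pos: Any
    vel: Any

    def select(self, o: "BatchSketch", cond: Sequence[bool]) -> "BatchSketch":
        # Per batch element: keep self's value where cond is True, else take o's.
        return dataclasses.replace(self, **{
            f.name: tuple(
                x if c else y
                for x, y, c in zip(getattr(self, f.name), getattr(o, f.name), cond)
            )
            for f in dataclasses.fields(self)
        })


a = BatchSketch(pos=(1.0, 2.0, 3.0), vel=(4.0, 5.0, 6.0))
b = BatchSketch(pos=(0.0, 0.0, 0.0), vel=(9.0, 9.0, 9.0))
out = a.select(b, (True, False, True))
print(out)  # BatchSketch(pos=(1.0, 0.0, 3.0), vel=(4.0, 9.0, 6.0))
```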
slice(beg: int, end: int) -> Any
Slice the dataclass.
Parameters:
- beg (int) – The beginning of the slice.
- end (int) – The end of the slice.
Returns:
- Any – The sliced dataclass.
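A pure-Python sketch of `slice`, assuming Python-style half-open slicing `[beg, end)` along the leading axis of every leaf, as in brax's `x[beg:end]`. `BatchSketch` is a hypothetical name.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class BatchSketch:
    """Hypothetical Base-like dataclass; each leaf has a leading batch axis."""

    pos: Any
    vel: Any

    def slice(self, beg: int, end: int) -> "BatchSketch":
        # Half-open range [beg, end) along the leading axis of every leaf.
        return dataclasses.replace(
            self, **{f.name: getattr(self, f.name)[beg:end] for f in dataclasses.fields(self)}
        )


batch = BatchSketch(pos=(0.0, 1.0, 2.0, 3.0), vel=(4.0, 5.0, 6.0, 7.0))
print(batch.slice(1, 3))  # BatchSketch(pos=(1.0, 2.0), vel=(5.0, 6.0))
```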
take(i: int, axis: int = 0) -> Any
Take elements from the dataclass.
Parameters:
- i (int) – The index of the elements to take.
- axis (int, default: 0) – The axis to take the elements from.
Returns:
- Any – The taken elements from the dataclass.
concatenate(*others: Any, axis: int = 0) -> Any
Concatenate the dataclass with other dataclasses.
Parameters:
- *others (Any, default: ()) – The other dataclasses to concatenate.
- axis (int, default: 0) – The axis along which to concatenate the dataclasses.
Returns:
- Any – The concatenated dataclass.
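A pure-Python sketch of `concatenate` along axis 0 only, assuming each leaf of `self` is joined with the matching leaves of `others` in order (brax does this per leaf with `jnp.concatenate`). `BatchSketch` is a hypothetical name, and the `axis` parameter is not modeled.

```python
import dataclasses
import itertools
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class BatchSketch:
    """Hypothetical Base-like dataclass; each leaf has a leading batch axis."""

    pos: Any
    vel: Any

    def concatenate(self, *others: "BatchSketch") -> "BatchSketch":
        # Join each leaf with the matching leaves of `others` (axis 0 only).
        return dataclasses.replace(self, **{
            f.name: tuple(
                itertools.chain(getattr(self, f.name), *(getattr(o, f.name) for o in others))
            )
            for f in dataclasses.fields(self)
        })


a = BatchSketch(pos=(0.0, 1.0), vel=(2.0, 3.0))
b = BatchSketch(pos=(4.0,), vel=(5.0,))
c = BatchSketch(pos=(6.0,), vel=(7.0,))
print(a.concatenate(b, c))  # BatchSketch(pos=(0.0, 1.0, 4.0, 6.0), vel=(2.0, 3.0, 5.0, 7.0))
```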
index_set(idx: Union[jax.Array, Sequence[jax.Array]], o: Any) -> Any
Set elements in the dataclass based on an index.
Parameters:
- idx (Union[Array, Sequence[Array]]) – The index at which to set the elements.
- o (Any) – The elements to set.
Returns:
- Any – The dataclass with the elements set at idx.
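Brax implements `index_set` per leaf with JAX's functional indexed update `x.at[idx].set(y)`; the pure-Python sketch below assumes rex does the same. It handles only 1-D tuple leaves and a list of integer positions, and returns a new object while leaving the original untouched. `BatchSketch` is a hypothetical name.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass(frozen=True)
class BatchSketch:
    """Hypothetical Base-like dataclass with 1-D tuple leaves."""

    pos: Any
    vel: Any

    def index_set(self, idx: Sequence[int], o: "BatchSketch") -> "BatchSketch":
        def set_leaf(x: Any, y: Any) -> tuple:
            # Copy the leaf, then overwrite the positions in `idx` with values from `y`.
            out = list(x)
            for i, v in zip(idx, y):
                out[i] = v
            return tuple(out)

        return dataclasses.replace(self, **{
            f.name: set_leaf(getattr(self, f.name), getattr(o, f.name))
            for f in dataclasses.fields(self)
        })


base = BatchSketch(pos=(0.0, 0.0, 0.0), vel=(1.0, 1.0, 1.0))
upd = BatchSketch(pos=(5.0, 6.0), vel=(7.0, 8.0))
print(base.index_set([0, 2], upd))  # BatchSketch(pos=(5.0, 0.0, 6.0), vel=(7.0, 1.0, 8.0))
```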
index_sum(idx: Union[jax.Array, Sequence[jax.Array]], o: Any) -> Any
Sum elements in the dataclass based on an index.
Parameters:
- idx (Union[Array, Sequence[Array]]) – The index at which to sum the elements.
- o (Any) – The elements to sum.
Returns:
- Any – The dataclass with the summed elements.
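Brax implements `index_sum` per leaf with `x.at[idx].add(y)`, and JAX's scatter-add accumulates contributions at repeated indices. The pure-Python sketch below assumes rex behaves the same way and shows that accumulation; like the `index_set` case, it handles only 1-D tuple leaves and integer positions. `BatchSketch` is a hypothetical name.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass(frozen=True)
class BatchSketch:
    """Hypothetical Base-like dataclass with 1-D tuple leaves."""

    pos: Any
    vel: Any

    def index_sum(self, idx: Sequence[int], o: "BatchSketch") -> "BatchSketch":
        def add_leaf(x: Any, y: Any) -> tuple:
            # Copy the leaf, then add each value of `y` at its position in `idx`;
            # repeated positions accumulate.
            out = list(x)
            for i, v in zip(idx, y):
                out[i] += v
            return tuple(out)

        return dataclasses.replace(self, **{
            f.name: add_leaf(getattr(self, f.name), getattr(o, f.name))
            for f in dataclasses.fields(self)
        })


base = BatchSketch(pos=(0.0, 0.0, 0.0), vel=(0.0, 0.0, 0.0))
upd = BatchSketch(pos=(1.0, 2.0, 3.0), vel=(1.0, 1.0, 1.0))
print(base.index_sum([0, 0, 2], upd))  # BatchSketch(pos=(3.0, 0.0, 3.0), vel=(2.0, 0.0, 1.0))
```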