Asynchronous
rex.asynchronous.AsyncGraph
max_steps
property
The maximum number of steps.
max_eps
property
The maximum number of episodes.
__init__(nodes: Dict[str, BaseNode], supervisor: BaseNode, clock: Clock = Clock.WALL_CLOCK, real_time_factor: Union[float, int] = RealTimeFactor.REAL_TIME)
Creates an interface around all nodes in the graph.
As a mental model, it helps to think of the graph as dividing the nodes into two groups:
- Supervisor Node: The designated node that controls the graph's execution flow.
- All Other Nodes: These nodes form the environment the supervisor interacts with.
This partitioning essentially creates an agent-environment interface, where the supervisor node acts as the agent and the remaining nodes represent the environment. Alongside .init, the graph provides gym-like .reset and .step methods that mirror reinforcement learning interfaces:
- .init: Initializes the graph state, which includes the state of all nodes.
- .reset: Resets the system and returns the initial observation as it would be seen by the supervisor node.
- .step: Advances the graph by one step (i.e., steps all nodes except the supervisor) and returns the next observation.
As a result, the timestep of graph.step is determined by the rate of the supervisor node (i.e., 1/supervisor.rate).
Parameters:
- nodes (Dict[str, BaseNode]) – Dictionary of nodes that make up the graph.
- supervisor (BaseNode) – The designated node that controls the graph's execution flow.
- clock (Clock, default: WALL_CLOCK) – Determines how time is managed in the graph. Choices include Clock.SIMULATED for virtual simulations and Clock.WALL_CLOCK for real-time applications.
- real_time_factor (Union[float, int], default: REAL_TIME) – Sets the speed of the simulation. It can simulate as fast as possible (RealTimeFactor.FAST_AS_POSSIBLE), in real time (RealTimeFactor.REAL_TIME), or at any custom speed relative to real time.
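As a rough illustration of how the supervisor rate and real_time_factor relate, the sketch below computes the simulated duration of one graph.step (1/supervisor.rate) and the wall-clock time it would take at a given speed. It assumes that a real-time factor of k means "k times faster than real time" (wall time = simulated time / k) and models the constants with hypothetical values (REAL_TIME = 1.0, FAST_AS_POSSIBLE = math.inf); these encodings are assumptions, not rex's actual RealTimeFactor values.

```python
import math

# Hypothetical stand-ins for RealTimeFactor; the numeric encoding is an assumption.
REAL_TIME = 1.0
FAST_AS_POSSIBLE = math.inf


def sim_step_duration(supervisor_rate: float) -> float:
    """Simulated seconds covered by one graph.step, i.e. 1 / supervisor.rate."""
    return 1.0 / supervisor_rate


def wall_step_duration(supervisor_rate: float, real_time_factor: float) -> float:
    """Wall-clock seconds per step, assuming wall time = simulated time / real_time_factor."""
    return sim_step_duration(supervisor_rate) / real_time_factor


rate = 20.0  # hypothetical supervisor rate in Hz
print(sim_step_duration(rate))                     # 0.05
print(wall_step_duration(rate, REAL_TIME))         # 0.05
print(wall_step_duration(rate, 2.0))               # 0.025 (twice as fast as real time)
print(wall_step_duration(rate, FAST_AS_POSSIBLE))  # 0.0 (no waiting between steps)
```

Under this model, a simulated clock with FAST_AS_POSSIBLE never waits, while a wall clock at REAL_TIME paces each step to one supervisor period.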
init(rng: jax.typing.ArrayLike = None, params: Dict[str, base.Params] = None, order: Tuple[str, ...] = None) -> base.GraphState
Initializes the graph state, optionally using a given RNG and predefined params.
Nodes are initialized in a specified order, with the option to override the params of (a subset of) the nodes. This is useful for setting up the graph state before running the graph with .run, .rollout, or .reset.
Parameters:
- rng (ArrayLike, default: None) – Random number generator seed or state.
- params (Dict[str, Params], default: None) – Predefined params for (a subset of) the nodes.
- order (Tuple[str, ...], default: None) – The order in which nodes are initialized.
Returns:
- GraphState – The initialized graph state.
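The interplay between params overrides and the initialization order can be pictured with the toy resolver below. It is a hypothetical sketch, not rex's implementation: it assumes nodes named in `order` are initialized first and the remaining nodes follow in dictionary order, and that an override simply replaces a node's default params. The `Params` alias and `init_params` helper are illustrative names only.

```python
from typing import Dict, Optional, Sequence, Tuple

Params = Dict[str, float]  # hypothetical stand-in for base.Params


def init_params(
    defaults: Dict[str, Params],
    overrides: Optional[Dict[str, Params]] = None,
    order: Optional[Sequence[str]] = None,
) -> Tuple[Tuple[str, ...], Dict[str, Params]]:
    """Resolve per-node params in init order, letting `overrides` replace a subset."""
    overrides = overrides or {}
    order = tuple(order or ())
    # Reject names that do not belong to the graph.
    unknown = (set(overrides) | set(order)) - set(defaults)
    if unknown:
        raise KeyError(f"unknown nodes: {sorted(unknown)}")
    # Assumption: explicitly ordered nodes first, the rest in dict order.
    resolved = order + tuple(name for name in defaults if name not in order)
    params = {name: overrides.get(name, defaults[name]) for name in resolved}
    return resolved, params


defaults = {"world": {"mass": 1.0}, "sensor": {"noise": 0.1}, "agent": {"gain": 2.0}}
resolved, params = init_params(defaults, overrides={"sensor": {"noise": 0.0}}, order=("agent",))
print(resolved)          # ('agent', 'world', 'sensor')
print(params["sensor"])  # {'noise': 0.0}
```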
set_record_settings(params: Union[Dict[str, bool], bool] = None, rng: Union[Dict[str, bool], bool] = None, inputs: Union[Dict[str, bool], bool] = None, state: Union[Dict[str, bool], bool] = None, output: Union[Dict[str, bool], bool] = None, max_records: Union[Dict[str, int], int] = None) -> None
Sets the record settings for the nodes in the graph.
Parameters:
- params (Union[Dict[str, bool], bool], default: None) – Whether to record the params of the nodes.
- rng (Union[Dict[str, bool], bool], default: None) – Whether to record the RNG states of the nodes.
- inputs (Union[Dict[str, bool], bool], default: None) – Whether to record the input states of the nodes.
- state (Union[Dict[str, bool], bool], default: None) – Whether to record the state of the nodes.
- output (Union[Dict[str, bool], bool], default: None) – Whether to record the output of the nodes.
- max_records (Union[Dict[str, int], int], default: None) – The maximum number of records to store for each node.
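Each setting accepts either a single value or a dictionary keyed by node name. One plausible way to read such an argument is sketched below: a single value applies to every node, a dictionary configures nodes individually, and nodes missing from the dictionary (or a value of None) fall back to a default. The `expand_setting` helper and the fallback rule are assumptions for illustration, not the library's documented behavior.

```python
from typing import Dict, Iterable, Optional, TypeVar, Union

T = TypeVar("T")


def expand_setting(
    value: Optional[Union[Dict[str, T], T]],
    node_names: Iterable[str],
    default: T,
) -> Dict[str, T]:
    """Expand a global or per-node setting into an explicit per-node dict (hypothetical)."""
    names = list(node_names)
    if value is None:
        # Nothing specified: every node gets the default.
        return {name: default for name in names}
    if isinstance(value, dict):
        # Per-node settings; unspecified nodes fall back to the default (assumption).
        return {name: value.get(name, default) for name in names}
    # A single value applies to all nodes.
    return {name: value for name in names}


nodes = ["world", "sensor", "agent"]
print(expand_setting(True, nodes, default=False))
# {'world': True, 'sensor': True, 'agent': True}
print(expand_setting({"agent": True}, nodes, default=False))
# {'world': False, 'sensor': False, 'agent': True}
```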
warmup(graph_state: base.GraphState, device_step: Union[Dict[str, jax.Device], jax.Device] = None, device_dist: Union[Dict[str, jax.Device], jax.Device] = None, jit_step: Union[Dict[str, bool], bool] = True, profile: Union[Dict[str, bool], bool] = False, verbose: bool = False)
Ahead-of-time compilation of step and I/O functions to avoid latency at runtime.
Parameters:
- graph_state (GraphState) – The graph state that is expected to be used during runtime.
- device_step (Union[Dict[str, Device], Device], default: None) – The device to compile the step functions on. It is also the device used to prepare the input states. If None, the default device is used.
- device_dist (Union[Dict[str, Device], Device], default: None) – The device to compile the sampling of the delay distribution functions on. If None, the default device is used. Only relevant when using a simulated clock.
- jit_step (Union[Dict[str, bool], bool], default: True) – Whether to compile the step functions with JIT. JIT-compiled step functions are faster, but by default they cannot have side effects. Either wrap the side-effecting code in a JAX callback, or set jit_step=False for those nodes. See here for more info.
- profile (Union[Dict[str, bool], bool], default: False) – Whether to compile the step functions with time profiling. IMPORTANT: profiling test-runs the step functions, which may trigger unexpected side effects.
- verbose (bool, default: False) – Whether to print time profiling information.
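To see why side effects need special care under jit_step=True, consider the plain JAX sketch below (it does not use rex). Inside a jitted function, a Python print() only runs once, during tracing; wrapping the side effect in jax.debug.callback instead runs it on the host every time the compiled function executes. The `step` function and `recorded` list are illustrative names.

```python
import jax
import jax.numpy as jnp

recorded = []  # filled on the host by the callback


def log_value(x):
    # Receives a concrete NumPy value each time the compiled step runs.
    recorded.append(float(x))


@jax.jit
def step(x):
    # A bare print(x) here would fire only once, at trace time.
    jax.debug.callback(log_value, x, ordered=True)
    return x + 1.0


y = jnp.array(0.0)
for _ in range(3):
    y = step(y)
y.block_until_ready()
jax.effects_barrier()  # make sure all pending callbacks have finished
print(recorded)  # [0.0, 1.0, 2.0]
```

The alternative mentioned above, disabling JIT for a node, avoids the callback at the cost of slower steps.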
stop(timeout: float = None) -> None
Stops the graph and all its nodes.
Parameters:
- timeout (float, default: None) – The maximum time to wait for the graph to stop. If None, it waits indefinitely.
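The timeout semantics (wait at most timeout seconds, or indefinitely when None) resemble stopping a set of worker threads. The sketch below is a hypothetical analogue using only the standard library, not how AsyncGraph is implemented: each "node" is a thread that loops until a shared event is set, and stop_all shares one overall deadline across all joins and reports any thread that is still alive.

```python
import threading
import time
from typing import List, Optional


def node_loop(stop_event: threading.Event, period: float) -> None:
    # Hypothetical node: "steps" every `period` seconds until asked to stop.
    while not stop_event.wait(period):
        pass


def stop_all(threads: List[threading.Thread], stop_event: threading.Event,
             timeout: Optional[float] = None) -> List[str]:
    """Signal all threads to stop and wait up to `timeout` seconds in total.

    timeout=None waits indefinitely. Returns names of threads still alive.
    """
    stop_event.set()
    deadline = None if timeout is None else time.monotonic() + timeout
    for t in threads:
        remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
        t.join(remaining)
    return [t.name for t in threads if t.is_alive()]


stop = threading.Event()
threads = [threading.Thread(target=node_loop, args=(stop, 0.01), name=n, daemon=True)
           for n in ("world", "sensor")]
for t in threads:
    t.start()
print(stop_all(threads, stop, timeout=1.0))  # []
```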
run(graph_state: base.GraphState, timeout: float = None) -> base.GraphState
Executes one step of the graph, including the supervisor node, and returns the updated graph state.
Unlike the .step method, it automatically progresses the graph state after the supervisor has executed. This method also differs from the gym API: it always uses the supervisor node's own .step method, whereas the .reset and .step methods allow the user to override the supervisor's step.
Parameters:
- graph_state (GraphState) – The current graph state, or the initial graph state from .init().
- timeout (float, default: None) – The maximum time to wait for the graph to complete a step.
Returns:
- GraphState – Updated graph state. It returns directly after the supervisor node's step() is run.
reset(graph_state: base.GraphState, timeout: float = None) -> Tuple[base.GraphState, base.StepState]
Prepares the graph for execution by resetting it to a state just before the supervisor node's execution.
It returns the graph state and the supervisor's step state just before the supervisor would step, mimicking the initial observation returned by a gym environment's reset method. The step state can be considered the initial observation.
Parameters:
- graph_state (GraphState) – The graph state from .init().
Returns:
- Tuple[GraphState, StepState] – Tuple of the new graph state and the supervisor node's step state before execution of the first step.
step(graph_state: base.GraphState, step_state: base.StepState = None, output: base.Output = None) -> Tuple[base.GraphState, base.StepState]
Executes one step of the graph, optionally overriding the supervisor node's execution.
If step_state and output are provided, they override the supervisor's step, allowing for custom step implementations. Otherwise, the supervisor's step() is executed as usual.
When providing the updated step_state and output, the provided output can be viewed as the action that the agent would take in a gym environment, which is sent to nodes connected to the supervisor node.
Start every episode with a call to reset() using the initial graph state from init(), then call step() repeatedly.
Parameters:
- graph_state (GraphState) – The current graph state.
- step_state (StepState, default: None) – Custom step state for the supervisor node.
- output (Output, default: None) – Custom output for the supervisor node.
Returns:
- Tuple[GraphState, StepState] – Tuple of the new graph state and the supervisor node's step state before execution of the next step.
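The init/reset/step contract can be illustrated with a deliberately tiny stand-in. In the sketch below, ToyGraph, ToyGraphState, and ToyStepState are hypothetical classes that only mimic the call pattern described above; they are not rex types. The environment is a single position that moves by the supervisor's output each step. When output is provided it acts as the agent's action; otherwise the toy falls back to a built-in "supervisor step" policy.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyStepState:
    # Hypothetical stand-in for the supervisor's StepState ("observation").
    position: float


@dataclass(frozen=True)
class ToyGraphState:
    # Hypothetical stand-in for GraphState.
    position: float
    step: int


class ToyGraph:
    """Mimics the init/reset/step call pattern only; NOT the rex implementation."""

    def init(self) -> ToyGraphState:
        return ToyGraphState(position=0.0, step=0)

    def reset(self, gs: ToyGraphState) -> Tuple[ToyGraphState, ToyStepState]:
        # Return the supervisor's step state just before its first step.
        return gs, ToyStepState(position=gs.position)

    def _supervisor_step(self, ss: ToyStepState) -> float:
        # The supervisor's own policy, used when no output is provided.
        return -0.5 * ss.position

    def step(self, gs: ToyGraphState, step_state: Optional[ToyStepState] = None,
             output: Optional[float] = None) -> Tuple[ToyGraphState, ToyStepState]:
        # step_state is accepted for signature parity but unused in this toy.
        action = output if output is not None else self._supervisor_step(ToyStepState(gs.position))
        # Environment nodes consume the action and advance one supervisor period.
        new_gs = ToyGraphState(position=gs.position + action, step=gs.step + 1)
        return new_gs, ToyStepState(position=new_gs.position)


graph = ToyGraph()
gs, ss = graph.reset(graph.init())  # start every episode with reset()
for _ in range(3):
    action = 1.0  # an external agent's output, overriding the supervisor
    gs, ss = graph.step(gs, ss, action)
print(ss.position, gs.step)  # 3.0 3
```

Calling graph.step(gs) without step_state and output would instead apply the toy supervisor's own policy, loosely mirroring the "executed as usual" branch described above.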
get_record() -> base.EpisodeRecord
Gets the episode record for all nodes in the graph.
Returns:
- EpisodeRecord – The episode record for all nodes in the graph.