
Compiled

rex.graph.Graph

max_steps property

The maximum number of steps.

This is usually the number of supervisor vertices in the raw computation graphs.

timings: base.Timings property

Timings of the supergraph.

Contains all predication masks to convert the supergraph to the correct partition given the current episode and step.
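The following pure-Python sketch shows what selecting a partition with masks means. It assumes a hypothetical layout in which `masks[episode][step]` holds one boolean per supergraph vertex. The actual structure of `base.Timings` is not described here and likely differs, and `active_vertices` is an invented helper.

```python
from typing import List, Sequence

# Hypothetical mask layout (NOT the actual base.Timings structure):
# masks[episode][step][i] is True if supergraph vertex i belongs to the
# partition executed at that episode and step.
Masks = Sequence[Sequence[Sequence[bool]]]


def active_vertices(
    supergraph_vertices: Sequence[str], masks: Masks, episode: int, step: int
) -> List[str]:
    """Select the supergraph vertices that are active for one partition."""
    mask = masks[episode][step]
    if len(mask) != len(supergraph_vertices):
        raise ValueError("mask length must match the number of supergraph vertices")
    return [vertex for vertex, active in zip(supergraph_vertices, mask) if active]


vertices = ["sensor", "agent", "actuator"]
masks = [[[True, True, False], [False, True, True]]]  # one episode, two steps
print(active_vertices(vertices, masks, episode=0, step=1))  # ['agent', 'actuator']
```

Under JAX, predicated execution usually keeps every vertex in the traced program and uses the mask to decide whether a vertex's result takes effect, instead of skipping it with Python control flow. This keeps the program structure fixed across episodes and steps.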

graphs: base.Graph property

Graphs after applying windows to the raw computation graphs.

graphs_raw: base.Graph property

Raw computation graphs.

Gs: List[nx.DiGraph] property

List of networkx graphs after applying windows to the raw computation graphs.

S: nx.DiGraph property

The supergraph.

__init__(nodes: Dict[str, BaseNode], supervisor: BaseNode, graphs_raw: base.Graph, skip: List[str] = None, supergraph: Supergraph = Supergraph.MCS, prune: bool = True, S_init: nx.DiGraph = None, backtrack: int = 20, debug: bool = False, progress_bar: bool = True, buffer_sizes: Dict[str, int] = None, extra_padding: int = 0)

Compile graph with nodes, supervisor, and target computation graphs.

This class finds a partitioning and a supergraph that efficiently represent all raw computation graphs. It exposes .step and .reset methods that resemble the gym API, as well as .run and .rollout methods. See the individual methods for details.

The supervisor node defines the boundary between partitions, and essentially dictates the timestep of every step call.

"Raw" computation graphs capture only the data flow of a system. They do not account for messages that are reused across multiple step calls when no new data is available, nor for messages that are discarded once they fall out of the buffer. Therefore, the raw computation graphs are first modified to account for the buffer size (i.e., window size) of every connection.
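The effect of a finite buffer on a single connection can be illustrated with a small pure-Python sketch. This is not how rex modifies the raw graphs. It only shows which messages a receiving node would see at each step for a hypothetical `buffer_size`, including one message that is reused and one that is dropped. The helper `windowed_inputs` is invented for illustration.

```python
from collections import deque
from typing import Deque, List, Sequence, Tuple


def windowed_inputs(
    arrivals: Sequence[Sequence[str]], buffer_size: int
) -> List[Tuple[str, ...]]:
    """Return the window of messages a receiving node sees at each step.

    arrivals[t] lists the messages that arrive on one connection during step t.
    """
    # A bounded deque keeps only the `buffer_size` most recent messages:
    # older messages are discarded once the buffer is full.
    buffer: Deque[str] = deque(maxlen=buffer_size)
    windows: List[Tuple[str, ...]] = []
    for new_messages in arrivals:
        buffer.extend(new_messages)
        # If nothing new arrived, the previous window is reused as-is.
        windows.append(tuple(buffer))
    return windows


print(windowed_inputs([["m0"], [], ["m1", "m2"], ["m3"]], buffer_size=2))
# [('m0',), ('m0',), ('m1', 'm2'), ('m2', 'm3')]
```

Here "m0" is reused at step 1 because nothing new arrives, and it falls out of the window at step 2 once two newer messages fill the buffer.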

Parameters:

  • nodes (Dict[str, BaseNode]) –

    Dictionary of nodes.

  • supervisor (BaseNode) –

    Supervisor node.

  • graphs_raw (Graph) –

    Raw computation graphs. Must be acyclic.

  • skip (List[str], default: None ) –

    List of nodes to skip during graph execution.

  • supergraph (Supergraph, default: MCS ) –

    Supergraph mode. Options are MCS, TOPOLOGICAL, and GENERATIONAL.

  • prune (bool, default: True ) –

    Prune nodes that are not ancestors of the supervisor node. Setting this to False ensures that all nodes up to the time of the last supervisor node are included.

  • S_init (DiGraph, default: None ) –

    Initial supergraph.

  • backtrack (int, default: 20 ) –

    Backtrack parameter for MCS supergraph mode.

  • debug (bool, default: False ) –

    Debug mode. Validates the partitioning and supergraph and times various compilation steps.

  • progress_bar (bool, default: True ) –

    Show progress bar during supergraph generation.

  • buffer_sizes (Dict[str, int], default: None ) –

    Dictionary of buffer sizes for each connection.

  • extra_padding (int, default: 0 ) –

    Extra padding for buffer sizes.
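The prune option can be pictured with the following pure-Python sketch. It assumes a raw computation graph given as an edge list of hypothetical vertex names (one vertex per node execution) and treats the supervisor's vertices as targets. It keeps only the targets and their ancestors. rex's own pruning may differ in details, and `prune_to_ancestors` is an invented helper.

```python
from collections import defaultdict, deque
from typing import Dict, Iterable, List, Set, Tuple


def prune_to_ancestors(
    edges: Iterable[Tuple[str, str]], targets: Iterable[str]
) -> Set[str]:
    """Keep the target vertices and every vertex that has a path to one of them."""
    parents: Dict[str, List[str]] = defaultdict(list)
    for src, dst in edges:
        parents[dst].append(src)
    keep: Set[str] = set(targets)
    frontier = deque(keep)
    # Breadth-first search over reversed edges collects all ancestors.
    while frontier:
        vertex = frontier.popleft()
        for parent in parents[vertex]:
            if parent not in keep:
                keep.add(parent)
                frontier.append(parent)
    return keep


edges = [
    ("sensor_0", "agent_0"),
    ("agent_0", "agent_1"),
    ("sensor_1", "agent_1"),
    ("sensor_1", "logger_1"),  # logger_1 is not an ancestor of any supervisor vertex
    ("sensor_1", "sensor_2"),  # neither is sensor_2
]
print(sorted(prune_to_ancestors(edges, targets=["agent_0", "agent_1"])))
# ['agent_0', 'agent_1', 'sensor_0', 'sensor_1']
```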

init(rng: jax.typing.ArrayLike = None, params: Dict[str, base.Params] = None, starting_step: Union[int, jax.typing.ArrayLike] = 0, starting_eps: jax.typing.ArrayLike = 0, randomize_eps: bool = False, order: Tuple[str, ...] = None) -> base.GraphState

Initializes the graph state with optional parameters for RNG and step states.

Nodes are initialized in a specified order, and their params can optionally be overridden. This is useful for setting up the graph state before running the graph with .run, .rollout, or .reset.
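Here is a minimal plain-Python sketch of how an initialization order and a partial params override could interact. The helper `init_params`, the per-node `default_inits` callables, and the rule that nodes missing from `order` are initialized afterwards in dictionary order are all assumptions for illustration. They are not rex's implementation.

```python
from typing import Any, Callable, Dict, Optional, Sequence


def init_params(
    default_inits: Dict[str, Callable[[], Any]],
    params: Optional[Dict[str, Any]] = None,
    order: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Initialize per-node params in a given order, honoring predefined params."""
    params = params or {}
    first = list(order or [])
    # Assumption: nodes not named in `order` follow in dictionary order.
    sequence = first + [name for name in default_inits if name not in first]
    result: Dict[str, Any] = {}
    for name in sequence:
        # Predefined params replace the node's default initialization entirely.
        result[name] = params[name] if name in params else default_inits[name]()
    return result


defaults = {"sensor": lambda: {"rate": 10.0}, "agent": lambda: {"gain": 1.0}}
print(init_params(defaults, params={"agent": {"gain": 2.5}}, order=("agent",)))
# {'agent': {'gain': 2.5}, 'sensor': {'rate': 10.0}}
```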

Parameters:

  • rng (ArrayLike, default: None ) –

    Random number generator seed or state.

  • params (Dict[str, Params], default: None ) –

    Predefined params for (a subset of) the nodes.

  • starting_step (Union[int, ArrayLike], default: 0 ) –

    The simulation's starting step.

  • starting_eps (ArrayLike, default: 0 ) –

    The starting episode.

  • randomize_eps (bool, default: False ) –

    If True, randomly selects the starting episode.

  • order (Tuple[str, ...], default: None ) –

    The order in which nodes are initialized.

Returns:

  • GraphState

    The initialized graph state.

init_record(graph_state: base.GraphState, params: Union[Dict[str, bool], bool] = None, rng: Union[Dict[str, bool], bool] = None, inputs: Union[Dict[str, bool], bool] = None, state: Union[Dict[str, bool], bool] = None, output: Union[Dict[str, bool], bool] = None) -> base.GraphState

Sets the record settings for the nodes in the graph.

Parameters:

  • graph_state (GraphState) –

    The initial graph state from .init().

  • params (Union[Dict[str, bool], bool], default: None ) –

    Whether to record params for each node. Logged once.

  • rng (Union[Dict[str, bool], bool], default: None ) –

    Whether to record rng for each node. Logged each step.

  • inputs (Union[Dict[str, bool], bool], default: None ) –

    Whether to record inputs for each node. Logged each step. Can become very large.

  • state (Union[Dict[str, bool], bool], default: None ) –

    Whether to record state for each node. Logged each step.

  • output (Union[Dict[str, bool], bool], default: None ) –

    Whether to record output for each node. Logged each step.

Returns:

  • GraphState

    The updated graph state with record settings.
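Each record argument accepts either a single bool or a per-node dict. The plain-Python sketch below shows how such a setting could be expanded into one flag per node. It assumes that a bool applies to every node and that `None` or a missing dict entry means "do not record". The helper name `resolve_record_flags` is hypothetical.

```python
from typing import Dict, Optional, Sequence, Union

RecordFlag = Optional[Union[Dict[str, bool], bool]]


def resolve_record_flags(node_names: Sequence[str], flag: RecordFlag) -> Dict[str, bool]:
    """Expand a bool-or-dict record setting into one bool per node."""
    if flag is None:
        # Assumption: no setting means nothing is recorded.
        return {name: False for name in node_names}
    if isinstance(flag, bool):
        # A single bool applies to every node.
        return {name: flag for name in node_names}
    # Assumption: nodes missing from the dict are not recorded.
    return {name: bool(flag.get(name, False)) for name in node_names}


nodes = ["sensor", "agent", "actuator"]
print(resolve_record_flags(nodes, True))
# {'sensor': True, 'agent': True, 'actuator': True}
print(resolve_record_flags(nodes, {"agent": True}))
# {'sensor': False, 'agent': True, 'actuator': False}
```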

run(graph_state: base.GraphState) -> base.GraphState

Executes one step of the graph including the supervisor node and returns the updated graph state.

Unlike .step, it automatically advances the graph state after the supervisor executes, which makes it suitable for use with jax.lax.scan or jax.lax.fori_loop. It also differs from the gym API: it always uses the supervisor node's own .step method, whereas .reset and .step allow the user to override it.

Parameters:

  • graph_state (GraphState) –

    The current graph state, or initial graph state from .init().

Returns:

  • GraphState

    The updated graph state, returned immediately after the supervisor node's step has run.
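Because .run maps a graph state to a new graph state, it can serve directly as a loop body. The sketch below replaces the graph with a hypothetical `toy_run` function on a hypothetical `ToyGraphState`. It uses a pure-Python `fori_loop` that follows the documented semantics of `jax.lax.fori_loop`, so it runs without JAX or rex installed.

```python
from typing import Callable, NamedTuple, TypeVar

T = TypeVar("T")


class ToyGraphState(NamedTuple):
    """Hypothetical stand-in for base.GraphState."""

    step: int
    total: float


def toy_run(graph_state: ToyGraphState) -> ToyGraphState:
    """Stand-in for Graph.run: a pure state -> state update for one step."""
    return ToyGraphState(
        step=graph_state.step + 1, total=graph_state.total + graph_state.step
    )


def fori_loop(lower: int, upper: int, body: Callable[[int, T], T], init: T) -> T:
    """Pure-Python equivalent of jax.lax.fori_loop's semantics."""
    val = init
    for i in range(lower, upper):
        val = body(i, val)
    return val


final = fori_loop(0, 5, lambda _, gs: toy_run(gs), ToyGraphState(step=0, total=0.0))
print(final)  # ToyGraphState(step=5, total=10.0)
```

With JAX and a compiled graph, the analogous call would be `jax.lax.fori_loop(0, num_steps, lambda i, gs: graph.run(gs), graph_state)`, where `graph`, `num_steps`, and `graph_state` are placeholders.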

reset(graph_state: base.GraphState) -> Tuple[base.GraphState, base.StepState]

Prepares the graph for execution by resetting it to a state before the supervisor node's execution.

It returns the graph state and the supervisor's step state just before the supervisor's first step. This mimics the initial observation returned by a gym environment's reset method: the step state plays the role of the initial observation.

Parameters:

  • graph_state (GraphState) –

    The graph state from .init().

Returns:

  • Tuple[GraphState, StepState]

    Tuple of the new graph state and the supervisor node's step state before execution of the first step.

step(graph_state: base.GraphState, step_state: base.StepState = None, output: base.Output = None) -> Tuple[base.GraphState, base.StepState]

Executes one step of the graph, optionally overriding the supervisor node's execution.

If step_state and output are provided, they override the supervisor's step, allowing for custom step implementations. Otherwise, the supervisor's step() is executed as usual.

When the updated step_state and output are provided, output can be viewed as the action an agent takes in a gym environment. It is sent to the nodes connected to the supervisor node.

Start every episode with a call to reset() using the initial graph state from init(), then call step() repeatedly.

Parameters:

  • graph_state (GraphState) –

    The current graph state.

  • step_state (StepState, default: None ) –

    Custom step state for the supervisor node.

  • output (Output, default: None ) –

    Custom output for the supervisor node.

Returns:

  • Tuple[GraphState, StepState]

    Tuple of the new graph state and the supervisor node's step state before execution of the next step.
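The reset/step pattern can be sketched with a hypothetical `ToyGraph` whose `step` mirrors the dispatch described above. When both `step_state` and `output` are given, they replace the supervisor's own step. Otherwise, the supervisor's step runs. All classes and fields here are invented for illustration and are not rex types.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyStepState:
    """Hypothetical supervisor step state."""

    observation: float  # plays the role of the gym observation
    actions_taken: int  # supervisor's own state, updated during its step


@dataclass(frozen=True)
class ToyGraphState:
    """Hypothetical graph state: step counter, a downstream value, supervisor state."""

    step: int
    plant: float
    supervisor: ToyStepState


class ToyGraph:
    """Stand-in that mirrors the reset/step control flow, not rex's implementation."""

    def supervisor_step(self, step_state: ToyStepState) -> Tuple[ToyStepState, float]:
        # Default supervisor behavior: update its own state and send a zero action.
        return replace(step_state, actions_taken=step_state.actions_taken + 1), 0.0

    def reset(self, graph_state: ToyGraphState) -> Tuple[ToyGraphState, ToyStepState]:
        # Return the supervisor's step state just before its first step.
        step_state = replace(graph_state.supervisor, observation=graph_state.plant)
        return replace(graph_state, supervisor=step_state), step_state

    def step(
        self,
        graph_state: ToyGraphState,
        step_state: Optional[ToyStepState] = None,
        output: Optional[float] = None,
    ) -> Tuple[ToyGraphState, ToyStepState]:
        if step_state is None or output is None:
            # No override: run the supervisor's own step.
            step_state, output = self.supervisor_step(graph_state.supervisor)
        # The output (the "action") is sent to the node downstream of the supervisor.
        plant = graph_state.plant + output
        next_step_state = replace(step_state, observation=plant)
        new_graph_state = ToyGraphState(
            step=graph_state.step + 1, plant=plant, supervisor=next_step_state
        )
        return new_graph_state, next_step_state


graph = ToyGraph()
gs, ss = graph.reset(ToyGraphState(step=0, plant=0.0, supervisor=ToyStepState(0.0, 0)))
for _ in range(3):
    # The agent updates the supervisor's step state and chooses an action (output).
    ss = replace(ss, actions_taken=ss.actions_taken + 1)
    gs, ss = graph.step(gs, ss, 1.0)
print(gs.step, gs.plant, ss)  # 3 3.0 ToyStepState(observation=3.0, actions_taken=3)
gs, ss = graph.step(gs)  # no override: supervisor's own step with action 0.0
print(gs.step, gs.plant, ss)  # 4 3.0 ToyStepState(observation=3.0, actions_taken=4)
```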

rollout(graph_state: base.GraphState, max_steps: int = None, carry_only: bool = True) -> base.GraphState

Executes the graph for a specified number of steps or until a condition is met, starting from a given step and episode.

Uses .run internally and can return either only the final graph state or the sequence of all graph states. Because it relies on .run, the supervisor node's step method cannot be overridden: the supervisor's own step method is used throughout the rollout.

Note

To record the rollout, use the init_record method on the graph_state before calling this method and set carry_only=True. Then, the record is available in graph_state.aux["record"].

Parameters:

  • graph_state (GraphState) –

    The initial graph state.

  • max_steps (int, default: None ) –

    The maximum number of steps to execute. If None, runs until a stop condition is met.

  • carry_only (bool, default: True ) –

    If True, returns only the final graph state; otherwise returns all states.

Returns:

  • GraphState

    The final graph state, or the sequence of all graph states if carry_only is False.
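The difference between carry_only=True and carry_only=False mirrors the carry and the stacked outputs of jax.lax.scan. The following is a plain-Python sketch with a hypothetical integer state and a fixed step count. Several details are assumptions not stated in this documentation: whether the real method includes the initial state in the returned sequence, how it handles a stop condition, and how it stacks states (presumably as a batched pytree in JAX rather than a list).

```python
from typing import Callable, List, TypeVar, Union

T = TypeVar("T")


def rollout(
    run: Callable[[T], T], graph_state: T, max_steps: int, carry_only: bool = True
) -> Union[T, List[T]]:
    """Apply `run` max_steps times; return the final state or every intermediate state."""
    states: List[T] = []
    for _ in range(max_steps):
        graph_state = run(graph_state)
        if not carry_only:
            # Like scan's stacked outputs: keep the state after each step.
            states.append(graph_state)
    return graph_state if carry_only else states


def increment(state: int) -> int:
    """Hypothetical run function: a state is just a step counter."""
    return state + 1


print(rollout(increment, 0, max_steps=3))  # 3
print(rollout(increment, 0, max_steps=3, carry_only=False))  # [1, 2, 3]
```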