Node
rex.node.BaseNode
info: base.NodeInfo
property
Get the node info.
__init__(name: str, rate: float, delay: float = None, delay_dist: Union[base.DelayDistribution, distrax.Distribution] = None, advance: bool = False, scheduling: Scheduling = Scheduling.FREQUENCY, color: str = None, order: int = None)
Base node class. All nodes should inherit from this class.
Basic template for a node class:
class MyNode(BaseNode):
    def __init__(self, *args, extra_arg, **kwargs):  # Optional
        super().__init__(*args, **kwargs)
        self.extra_arg = extra_arg

    def init_params(self, rng=None, graph_state=None):  # Optional
        return MyParams(param1=1.0, param2=2.0)

    def init_state(self, rng=None, graph_state=None):  # Optional
        return MyState(state1=1.0, state2=2.0)

    def init_output(self, rng=None, graph_state=None):  # Required
        return MyOutput(output1=1.0, output2=2.0)

    def init_delays(self, rng=None, graph_state=None):  # Optional
        # Set trainable delays to values from params
        params = graph_state.params[self.name]
        return {"some_node": params.param1}  # Key: name of the connected node

    def startup(self, graph_state, timeout=None):  # Optional
        # Move the robot to a starting position
        return True

    def step(self, step_state):  # Required
        # Unpack step state
        params = step_state.params
        state = step_state.state
        inputs = step_state.inputs
        # Calculate output
        output = MyOutput(...)
        # Update state
        new_state = MyState(...)
        return step_state.replace(state=new_state), output

    def stop(self, timeout=None):  # Optional
        # Safely stop the robot at the end of the episode
        return True
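To isolate just the step contract from the template (return an updated step state plus an output, and leave seq and ts untouched), here is a self-contained stand-in. It does not use rex: StepState, MyState, MyOutput and the counter logic are hypothetical, and the immutable .replace(...) call is emulated with dataclasses.replace.

```python
import dataclasses
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MyState:  # hypothetical node state
    count: int


@dataclass(frozen=True)
class MyOutput:  # hypothetical message sent to connected nodes
    value: float


@dataclass(frozen=True)
class StepState:  # hypothetical stand-in for rex's StepState
    seq: int
    ts: float
    state: MyState

    def replace(self, **changes) -> "StepState":
        # Returns a modified copy, mirroring the immutable update used in the template
        return dataclasses.replace(self, **changes)


def step(step_state: StepState) -> Tuple[StepState, MyOutput]:
    state = step_state.state
    output = MyOutput(value=2.0 * state.count)
    new_state = MyState(count=state.count + 1)
    # Only the state is replaced; seq and ts are left for the framework to advance
    return step_state.replace(state=new_state), output


ss = StepState(seq=0, ts=0.0, state=MyState(count=3))
new_ss, out = step(ss)
print(new_ss.state.count, out.value, new_ss.seq)  # 4 6.0 0
```

Because the step state is immutable, the original ss still holds count=3 after the call.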
Parameters:
- name (str) – The name of the node (unique).
- rate (float) – The rate of the node (Hz).
- delay (float, default: None) – The expected computation delay of the node (s). Used to calculate the phase shift.
- delay_dist (Union[DelayDistribution, Distribution], default: None) – The computation delay distribution of the node for simulation.
- advance (bool, default: False) – Whether the node's step triggers when all inputs are ready, or throttles until the scheduled time.
- scheduling (Scheduling, default: FREQUENCY) – The scheduling of the node. If FREQUENCY, the node is scheduled at a fixed rate, while ignoring any phase shift w.r.t. the clock. If PHASE, the node steps are scheduled at a fixed rate and phase w.r.t. the clock.
- color (str, default: None) – The color of the node (for visualization).
- order (int, default: None) – The order of the node (for visualization).
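To make the difference between the two scheduling modes concrete, the sketch below computes nominal step start times. It assumes, as a simplification not taken from rex's source, that PHASE adds a constant phase offset to every tick of a 1/rate grid while FREQUENCY uses the bare grid. The Scheduling enum here is a stand-in, and scheduled_times and the phase value are hypothetical.

```python
from enum import Enum
from typing import List


class Scheduling(Enum):  # hypothetical stand-in for rex's Scheduling enum
    FREQUENCY = 0
    PHASE = 1


def scheduled_times(rate: float, n: int, scheduling: Scheduling, phase: float = 0.0) -> List[float]:
    """Nominal start times (s) of the first n steps of a node running at `rate` Hz."""
    dt = 1.0 / rate
    # FREQUENCY ignores the phase shift; PHASE keeps a fixed offset w.r.t. the clock
    offset = phase if scheduling is Scheduling.PHASE else 0.0
    return [round(offset + k * dt, 6) for k in range(n)]


print(scheduled_times(10.0, 3, Scheduling.FREQUENCY, phase=0.02))  # [0.0, 0.1, 0.2]
print(scheduled_times(10.0, 3, Scheduling.PHASE, phase=0.02))      # [0.02, 0.12, 0.22]
```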
connect(output_node: BaseNode, blocking: bool = False, delay: float = None, delay_dist: Union[base.DelayDistribution, distrax.Distribution] = None, window: int = 1, skip: bool = False, jitter: Jitter = Jitter.LATEST, name: str = None)
Connects the node to another node, so that the outputs of output_node are received as inputs by this node.
Parameters:
- output_node (BaseNode) – The node to connect to.
- blocking (bool, default: False) – Whether the connection is blocking.
- delay (float, default: None) – The expected communication delay of the connection.
- delay_dist (Union[DelayDistribution, Distribution], default: None) – The communication delay distribution of the connection for simulation.
- window (int, default: 1) – The window size of the connection. It determines how many output messages are used as input to the .step() function.
- skip (bool, default: False) – Whether to skip the connection. It resolves cyclic dependencies by skipping the output if it arrives at the same time as the start of the .step() function (i.e. step_state.ts).
- jitter (Jitter, default: LATEST) – How to deal with jitter of the connection. If LATEST, the latest messages are used. If BUFFER, the messages are buffered and used in accordance with the expected delay.
- name (str, default: None) – A shadow name for the connected node. If None, the name of the output node is used.
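The window parameter can be pictured as a fixed-size buffer of the most recent messages, where index -1 is the newest (matching window_index in the step template further below). The stand-in below uses collections.deque with a drop-oldest policy; Message is a hypothetical record, and this is an illustration rather than rex's internal buffer implementation.

```python
from collections import deque
from dataclasses import dataclass


@dataclass(frozen=True)
class Message:  # hypothetical record mirroring the data/seq/ts_sent fields
    data: float
    seq: int
    ts_sent: float


window = 3
buffer = deque(maxlen=window)  # once full, appending drops the oldest message

for seq in range(5):
    buffer.append(Message(data=seq * 10.0, seq=seq, ts_sent=seq * 0.1))

print([m.seq for m in buffer])  # [2, 3, 4]
print(buffer[-1].data)          # 40.0 (most recent message)
```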
init_params(rng: jax.Array = None, graph_state: base.GraphState = None) -> base.Params
Init params of the node.
The params are composed of values that remain constant during an episode (e.g. network weights).
At this point, the graph state may contain the params of other nodes required to get the default params. The order of node initialization can be specified in Graph.init(... order=[node1, node2, ...]).
Parameters:
- rng (Array, default: None) – Random number generator.
- graph_state (GraphState, default: None) – The graph state that may be used to get the default params.
Returns:
- Params – The default params of the node.
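Initialization order matters when one node's default params are derived from another node's params. The sketch mimics that dependency with plain dictionaries; the graph state layout {"params": {...}}, the node names, the init functions and the init_all loop are hypothetical stand-ins for what Graph.init does, not its actual code.

```python
from typing import Callable, Dict, List


def init_sensor(graph_state: Dict) -> Dict:
    return {"rate_hz": 50.0}


def init_controller(graph_state: Dict) -> Dict:
    # Depends on the sensor's params, so "sensor" must be initialized first
    sensor = graph_state["params"]["sensor"]
    return {"dt": 1.0 / sensor["rate_hz"]}


init_fns: Dict[str, Callable[[Dict], Dict]] = {"sensor": init_sensor, "controller": init_controller}


def init_all(order: List[str]) -> Dict:
    graph_state: Dict = {"params": {}}
    for name in order:
        # Each node sees the params of every node initialized before it
        graph_state["params"][name] = init_fns[name](graph_state)
    return graph_state


print(init_all(["sensor", "controller"])["params"]["controller"])  # {'dt': 0.02}
```

With the reverse order, init_controller looks up params that do not exist yet and raises a KeyError.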
init_state(rng: jax.Array = None, graph_state: base.GraphState = None) -> base.State
Init state of the node.
The state is composed of values that are updated during the episode in the .step() function (e.g. position, velocity).
At this point, the params of all nodes are already initialized and present in the graph state (if specified).
Moreover, the state of other nodes required to get the default state may also be present in the graph state.
The order of node initialization can be specified in Graph.init(... order=[node1, node2, ...]).
Parameters:
- rng (Array, default: None) – Random number generator.
- graph_state (GraphState, default: None) – The graph state that may be used to get the default state.
Returns:
- State – The default state of the node.
init_inputs(rng: jax.Array = None, graph_state: base.GraphState = None) -> FrozenDict[str, base.InputState]
Initialize default inputs for the node.
Fills input buffers with default outputs from connected nodes. Used during the initial steps of an episode when input buffers are not yet filled.
Parameters:
- rng (Array, default: None) – Random number generator.
- graph_state (GraphState, default: None) – The graph state that may be used to get the default inputs.
Returns:
- FrozenDict[str, InputState] – The default inputs of the node.
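The idea of filling input buffers with default outputs can be sketched as pre-filling each connection's window with copies of the connected node's default output, so that window index -1 is valid from the very first step. The function name and the plain-dict layout are assumptions for illustration; rex returns a FrozenDict of InputState objects instead.

```python
from typing import Dict, List


def default_inputs(windows: Dict[str, int], default_outputs: Dict[str, float]) -> Dict[str, List[float]]:
    """Pre-fill each input window with the connected node's default output (hypothetical layout)."""
    return {name: [default_outputs[name]] * size for name, size in windows.items()}


inputs = default_inputs({"sensor": 3, "camera": 1}, {"sensor": 0.0, "camera": -1.0})
print(inputs)  # {'sensor': [0.0, 0.0, 0.0], 'camera': [-1.0]}
```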
init_delays(rng: jax.Array = None, graph_state: base.GraphState = None) -> Dict[str, Union[float, jax.typing.ArrayLike]]
Initialize trainable communication delays.
Note
These delays include only trainable connections. To make a delay trainable, replace the parameters in the delay distribution with trainable parameters.
A rough template for the init_delays function is as follows:
def init_delays(self, rng=None, graph_state=None):
    # Assumes graph_state contains the params of the node
    params = graph_state.params[self.name]
    trainable_delays = {"world": params.delay_param}  # Key: name of the connected node
    return trainable_delays
Parameters:
- rng (Array, default: None) – Random number generator.
- graph_state (GraphState, default: None) – The graph state that may be used to get the trainable delays.
Returns:
- Dict[str, Union[float, ArrayLike]] – Trainable delays. Can be an incomplete dictionary. Entries for non-trainable delays or non-existent connections are ignored.
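The merge rule stated in the Returns entry (incomplete dictionaries are allowed, and entries for non-trainable or non-existent connections are ignored) can be illustrated with a small helper. merge_trainable_delays, the nominal delays and the connection names are hypothetical; this is not rex's code.

```python
from typing import Dict, Set


def merge_trainable_delays(nominal: Dict[str, float], trainable: Set[str],
                           overrides: Dict[str, float]) -> Dict[str, float]:
    """Apply delay overrides only to trainable connections; ignore all other entries."""
    merged = dict(nominal)  # connections missing from `overrides` keep their nominal delay
    for name, value in overrides.items():
        if name in trainable:  # non-trainable or non-existent connections are skipped
            merged[name] = value
    return merged


nominal = {"world": 0.01, "camera": 0.03}
print(merge_trainable_delays(nominal, {"world"}, {"world": 0.015, "camera": 0.5, "ghost": 1.0}))
# {'world': 0.015, 'camera': 0.03}
```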
init_step_state(rng: jax.Array = None, graph_state: base.GraphState = None) -> base.StepState
Initializes the step state of the node, which is used to run the seq'th step of the node at time ts.
The step state is assembled with:
- BaseNode.init_params
- BaseNode.init_state
- BaseNode.init_inputs, using BaseNode.init_output of connected nodes (to fill the input buffers)
Note
If a node's initialization depends on the params, state, or inputs of other nodes, this may fail. In such cases, the user can provide a graph state with the necessary information to get the default step state.
Parameters:
- rng (Array, default: None) – Random number generator.
- graph_state (GraphState, default: None) – The graph state that may be used to get the default step state.
Returns:
- StepState – The default step state of the node.
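Conceptually, init_step_state bundles the results of the individual init methods into one object. The self-contained sketch below shows that composition. Its field names follow the step template (params, state, inputs, eps, seq, ts, rng), but the dataclass, the zero-initialized counters, the integer seed standing in for a JAX key, and the minimal Node class are assumptions.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class StepState:  # hypothetical stand-in; field names follow the step template
    params: Any
    state: Any
    inputs: Dict[str, Any]
    eps: int = 0
    seq: int = 0
    ts: float = 0.0
    rng: int = 0  # an int seed stands in for a jax.random key


class Node:  # hypothetical minimal node with the three init methods
    def init_params(self) -> Dict[str, float]:
        return {"gain": 2.0}

    def init_state(self) -> Dict[str, float]:
        return {"pos": 0.0}

    def init_inputs(self) -> Dict[str, list]:
        return {"sensor": [0.0]}

    def init_step_state(self, rng: int = 0) -> StepState:
        # Compose params, state and inputs into a single step state
        return StepState(params=self.init_params(), state=self.init_state(),
                         inputs=self.init_inputs(), rng=rng)


ss = Node().init_step_state(rng=42)
print(ss.params, ss.state, ss.seq)  # {'gain': 2.0} {'pos': 0.0} 0
```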
startup(graph_state: base.GraphState, timeout: float = None) -> bool
Initializes the node to the state specified by graph_state.
This method is called just before an episode starts.
It can be used to move a real robot to a starting position as specified by the graph_state.
Note
Only called when running asynchronously.
Parameters:
- graph_state (GraphState) – The graph state.
- timeout (float, default: None) – The timeout of the startup.
Returns:
- bool – Whether the node has started successfully.
stop(timeout: float = None) -> bool
Stopping routine that is called after the episode is done.
Note
Only called when running asynchronously.
Warning
It may happen that stop is called before the final .step call of an episode returns, which may cause unsafe behavior if that final step undoes the work of the .stop method.
This should be handled by the user, for example, by continuing the stopping routine for longer before returning from .stop.
Parameters:
- timeout (float, default: None) – The timeout of the stop.
Returns:
- bool – Whether the node has stopped successfully.
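One way to follow the advice in the warning is to keep re-issuing the stop command for a grace period, so that a command from a late final .step call is overridden. The sketch uses a hypothetical FakeRobot and hypothetical grace and period values, timed with time.monotonic. Neither comes from rex, and a suitable grace period depends on the hardware.

```python
import time


class FakeRobot:  # hypothetical actuator interface
    def __init__(self) -> None:
        self.velocity = 0.0

    def command(self, velocity: float) -> None:
        self.velocity = velocity


def stop(robot: FakeRobot, grace: float = 0.05, period: float = 0.01) -> bool:
    """Re-send the stop command for `grace` seconds to override a late final step."""
    deadline = time.monotonic() + grace
    while time.monotonic() < deadline:
        robot.command(0.0)
        time.sleep(period)
    robot.command(0.0)  # one last stop command after the grace period
    return robot.velocity == 0.0


robot = FakeRobot()
robot.command(1.0)  # e.g. the last command issued by .step
print(stop(robot))  # True
```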
step(step_state: base.StepState) -> Tuple[base.StepState, base.Output]
Execute the node for the seq-th time step at time ts.
This function updates the node's state and generates an output, which is sent to connected nodes. It is called at the node's rate.
Users are expected to update the state (and rng if used), but not the seq and ts, as they are automatically updated.
Wrapping side-effecting code
Side-effecting code should be wrapped to ensure execution on the host machine when using jax.jit.
A rough template for the step function is as follows:
def step(self, step_state: base.StepState) -> Tuple[base.StepState, base.Output]:
    # Per input with `input_name`, the following information is available:
    step_state.inputs[input_name][window_index].data  # A window_index of -1 leads to the most recent message.
    step_state.inputs[input_name][window_index].seq  # The sequence number of the message.
    step_state.inputs[input_name][window_index].ts_sent  # The time the message was sent.
    step_state.inputs[input_name][window_index].ts_recv  # The time the message was received.
    # The following information is available for the node:
    step_state.params  # The parameters of the node.
    step_state.state  # The state of the node.
    step_state.eps  # The episode number.
    step_state.seq  # The sequence number.
    step_state.ts  # The time of the step within the episode.
    step_state.rng  # The random number generator.
    # Calculate output and updated state
    new_rng, rng_step = jax.random.split(step_state.rng)
    output = ...
    new_state = ...
    # Update the state and rng of the node (seq and ts are updated automatically)
    new_ss = step_state.replace(rng=new_rng, state=new_state)
    return new_ss, output
Parameters:
- step_state (StepState) – The step state of the node.
Returns:
- Tuple[StepState, Output] – The updated step state and the output of the node.
now() -> float
Get the time passed since the start of the episode, according to the clock in use (simulated or wall clock).
Returns:
- float – Time since the start of the episode. Only returns timestamps > 0 when running asynchronously.
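The documented behavior of now() (elapsed episode time when running asynchronously, otherwise no positive timestamps) can be sketched with a small helper. EpisodeClock, its asynchronous flag and the use of time.monotonic as the wall clock are assumptions, not rex's implementation.

```python
import time


class EpisodeClock:  # hypothetical helper mirroring the documented behavior of now()
    def __init__(self, asynchronous: bool) -> None:
        self.asynchronous = asynchronous
        self._start = time.monotonic()

    def reset(self) -> None:
        self._start = time.monotonic()  # marks the start of a new episode

    def now(self) -> float:
        if not self.asynchronous:
            return 0.0  # no running clock outside asynchronous mode
        return time.monotonic() - self._start


print(EpisodeClock(asynchronous=False).now())  # 0.0
```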
set_delay(delay_dist: Union[base.DelayDistribution, distrax.Distribution] = None, delay: float = None)
Set the delay distribution and delay for the computation delay of the node.
Parameters:
- delay_dist (Union[DelayDistribution, Distribution], default: None) – The delay distribution to simulate.
- delay (float, default: None) – The delay to take into account for the phase shift.
from_info(info: base.NodeInfo, **kwargs: Dict[str, Any])
classmethod
Re-instantiates a Node from a NodeInfo object.
Make sure to call connect_from_info() on the resulting subclass object to restore the connections.
Note
This method attempts to restore the subclass object from the BaseNode object. Hence, it requires any additional arguments to be passed as keyword arguments. Moreover, the signature of the subclass must be the same as the BaseNode, except for the additional *args and **kwargs.
Parameters:
- info (NodeInfo) – Node info object.
- **kwargs (Dict[str, Any], default: {}) – Additional keyword arguments for the subclass.
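The pattern behind from_info (re-creating a subclass from the base-class constructor arguments plus extra keyword arguments) can be sketched without rex. Here NodeInfo is a hypothetical, reduced dataclass with only name and rate (the real NodeInfo may carry more fields). MyNode's keyword-only extra_arg mirrors the template at the top of this page and shows why subclass extras must be passed as keyword arguments.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class NodeInfo:  # hypothetical, reduced stand-in for base.NodeInfo
    name: str
    rate: float


class BaseNode:
    def __init__(self, name: str, rate: float):
        self.name, self.rate = name, rate

    @property
    def info(self) -> NodeInfo:
        return NodeInfo(name=self.name, rate=self.rate)

    @classmethod
    def from_info(cls, info: NodeInfo, **kwargs):
        # Base-class arguments come from info; subclass extras arrive as keyword arguments
        return cls(**asdict(info), **kwargs)


class MyNode(BaseNode):
    def __init__(self, *args, extra_arg, **kwargs):
        super().__init__(*args, **kwargs)
        self.extra_arg = extra_arg


node = MyNode.from_info(MyNode("sensor", 50.0, extra_arg=1).info, extra_arg=1)
print(node.name, node.rate, node.extra_arg)  # sensor 50.0 1
```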
rex.node.BaseWorld
Bases: BaseNode
__init__(name: str, rate: float, color: str = None, order: int = None, **kwargs)
Base node class for world (i.e. simulator) nodes.
A convenience class that pre-sets parameters for nodes that simulate real-world processes, that is, nodes that simulate continuous processes in a discrete manner.
- The delay distribution is set to the time step of the node (~1/rate). It is currently set slightly below the time step to ensure numerical stability, as otherwise we may unavoidably introduce more delay.
- advance is set to False, as the world node should adhere to the rate of the node.
- scheduling is set to FREQUENCY, as the world node should adhere to the rate of the node.
Parameters:
- name (str) – The name of the node (unique).
- rate (float) – The rate of the node (Hz).
- color (str, default: None) – The color of the node (for visualization).
- order (int, default: None) – The order of the node (for visualization).
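The reasoning behind setting the delay slightly below the time step can be checked with a little arithmetic. If step k starts at k*dt and its output arrives after the delay, then a delay strictly below dt leaves a margin so every output is ready before step k+1 starts. With a delay of exactly 1/rate, floating-point rounding may tip individual comparisons either way. The 0.99 factor below is a hypothetical choice, since the text only says "slightly below", and late_arrivals is an illustrative helper.

```python
def late_arrivals(rate: float, delay: float, n_steps: int = 1000) -> int:
    """Count steps whose output (sent at k*dt + delay) arrives after step k+1 starts."""
    dt = 1.0 / rate
    return sum(1 for k in range(n_steps) if k * dt + delay > (k + 1) * dt)


rate = 30.0
print(late_arrivals(rate, delay=0.99 / rate))  # 0: every output is ready for the next step
```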