# Defining Nodes in rex (Robotic Environments with jaX)
This notebook offers an introductory tutorial for rex (Robotic Environments with jaX), a JAX-based framework for creating graph-based environments tailored for sim2real robotics.
In this tutorial, we will guide you through the process of defining nodes, which are the fundamental building blocks for constructing graph-based simulations and real-world systems within rex. Specifically, we will demonstrate how to define the nodes used in the sim2real.ipynb notebook.
```python
# @title Install Necessary Libraries
# @markdown This cell installs the required libraries for the project.
# @markdown If you are running this notebook in Google Colab, most libraries should already be installed.
try:
    import rex  # noqa: F401

    print("Rex already installed")
except ImportError:
    print(
        "Installing rex via `pip install rex-lib[examples]`. "
        "If you are running this in a Colab notebook, you can ignore this message."
    )
    !pip install rex-lib[examples]
```
## Introduction to Nodes in Rex
In Rex, a node represents a fundamental computational unit within a graph-based system. Nodes encapsulate specific functionality and interact by passing data through connections, forming a network that can model complex systems. This tutorial introduces how to define nodes, specify their properties like rates and delays, and manage their interactions within a graph.
## Defining Nodes

Nodes are defined by creating subclasses of the `BaseNode` class. This base class provides a standardized API and essential functionality that all nodes inherit. When defining a node, you can specify several parameters directly in the `__init__` method:

- `name`: A unique identifier for the node.
- `rate`: The frequency at which the node's `step` method is called (in Hz).
- `delay` (optional): The expected computation delay of the node (in seconds).
- `delay_dist`: A distribution representing variability in the node's computation delay, useful for simulations.
- `advance`: If `True`, the node's `step` method triggers when all inputs are ready; if `False`, it throttles until the scheduled time.
- `scheduling`: Determines how the node's execution is scheduled. Options include `Scheduling.FREQUENCY` and `Scheduling.PHASE`.
- `color`: Used for visualization purposes.
- `order`: Determines the node's order in visualizations.
Here's a basic example of a node definition:
```python
class MyNode(BaseNode):
    def __init__(
        self,
        name: str,
        rate: float,
        delay: float = None,  # Expected computation delay (used for phase-shifting)
        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,  # Sim. computation delay
        advance: bool = False,
        scheduling: Scheduling = Scheduling.FREQUENCY,
        color: str = None,
        order: int = None,
    ):
        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)
        # Additional initialization if needed

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None):
        # Initialize parameters
        return MyParams()

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None):
        # Initialize state
        return MyState()

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None):
        # Initialize default output
        return MyOutput()

    def step(self, step_state: StepState):
        # Node's computation logic
        new_state = ...
        output = ...
        return step_state.replace(state=new_state), output
```
## Connecting Nodes

Nodes interact by passing outputs from one node to the inputs of another. This is achieved through the `connect` method, which is called on the receiving node and takes the sending node as `output_node`.
### Connection API

When connecting nodes, you can specify several parameters that control the nature of the connection:

- `output_node`: The node whose output will be connected as an input.
- `blocking`: If `True`, the receiving node waits for the input before proceeding. This can create dependencies between nodes.
- `delay`: An additional delay introduced in the connection, which can control the phase shift between nodes.
- `delay_dist`: Used in simulation to model communication delays between nodes.
- `window`: Determines how many past messages are stored and accessible in the input buffer.
- `skip`: If `True`, the connection is skipped when messages arrive simultaneously, helping resolve cyclic dependencies.
- `jitter`: Controls how irregularities in message timing are handled (e.g., `Jitter.LATEST` uses the most recent message).
- `name`: A shadow name for the input; defaults to the output node's name.
### Including `delay_dist` in Connections

The `delay_dist` parameter allows you to specify a distribution that models the variability in communication delay between nodes. This is particularly useful in simulations where network latency or message-passing delays are significant.
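As a rough illustration of what `delay_dist` models, the sketch below samples one communication delay per message and computes receive timestamps as `ts_recv = ts_sent + delay`. It uses plain `jax.random.normal` with a mean of 10 ms and a standard deviation of 5 ms instead of a `distrax` distribution, and `sample_receive_times` is a hypothetical helper; this is not how rex samples delays internally. Clipping at zero is an added assumption, since a normal distribution can produce negative delays.

```python
import jax
import jax.numpy as jnp


def sample_receive_times(rng, ts_sent, loc=0.01, scale=0.005):
    """Toy model: ts_recv = ts_sent + delay, with delay ~ Normal(loc, scale) clipped at 0."""
    delay = loc + scale * jax.random.normal(rng, shape=ts_sent.shape)
    # A normal distribution has unbounded support, so clip to avoid negative delays
    delay = jnp.maximum(delay, 0.0)
    return ts_sent + delay


rate = 20.0  # Hz, rate of the sending node (illustrative value)
ts_sent = jnp.arange(5) / rate  # Send times: 0.00, 0.05, ..., 0.20 s
ts_recv = sample_receive_times(jax.random.PRNGKey(0), ts_sent)
print(ts_recv - ts_sent)  # One sampled delay per message
```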
### Resolving Cyclic Dependencies with `skip`

In graphs where nodes depend on each other's outputs (creating a cycle), the `skip` parameter can be used to resolve the dependency. By setting `skip=True` on a connection, you instruct the receiving node to proceed without waiting for the current message if it arrives simultaneously. This breaks the cycle and allows the system to function.
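A toy way to see why `skip` breaks a cycle: suppose node A reads B's output and B reads A's output, and both step at the same timestamp. If each must wait for the other's same-timestamp message, neither can go first. The sketch below uses a hypothetical helper `step_order` (not rex's scheduler) that resolves the per-timestamp step order with a simple topological sort and treats `skip=True` on B's input as removing B's same-timestamp dependency on A.

```python
def step_order(skip_b_input: bool) -> list:
    """Order in which toy nodes A and B can step when both fire at the same timestamp.

    A reads B's output and B reads A's output. deps[n] holds the nodes whose
    same-timestamp message node n must wait for.
    """
    deps = {
        "A": {"B"},
        # With skip=True on B's input from A, B does not wait for A's message
        # stamped at the same time, so B has no same-timestamp dependency.
        "B": set() if skip_b_input else {"A"},
    }
    order, done = [], set()
    while len(done) < len(deps):
        ready = sorted(n for n in deps if n not in done and deps[n] <= done)
        if not ready:
            raise RuntimeError("Cyclic dependency: no node can step")
        order.extend(ready)
        done.update(ready)
    return order


print(step_order(skip_b_input=True))  # ['B', 'A']
```

Calling `step_order(skip_b_input=False)` raises a `RuntimeError`, which is the toy equivalent of the deadlock that `skip` avoids.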
### Example Connection

```python
node_a.connect(
    output_node=node_b,
    blocking=True,
    delay=0.01,  # Expected communication delay (used for phase-shifting)
    delay_dist=distrax.Normal(loc=0.01, scale=0.005),  # Sim. communication delay
    window=5,
    skip=False,
    jitter=Jitter.LATEST,
    name="input_from_b",
)
```
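In this example, `node_a` receives `node_b`'s output as an input named `input_from_b`. The connection is blocking, adds an expected delay of 0.01 seconds, and uses a normal delay distribution (mean 0.01 s, standard deviation 0.005 s) to simulate communication delays. The `window` size is set to 5, meaning the last five messages are stored. Because `blocking=True`, `node_a` waits for this input; because `skip=False`, a message from `node_b` that arrives simultaneously is not skipped.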
## Node Data Structure

Nodes manage four main types of data. Each is a pytree, typically implemented as an immutable dataclass for efficiency and safety:

- Parameters: Static configurations that usually remain constant during execution.
- State: Dynamic data that evolves over time with each `step`.
- Outputs: Data produced by a node's `step` method and sent to connected nodes.
- Inputs: Buffers that hold incoming data from other nodes, respecting the specified window size.
### Immutable Dataclasses

Using immutable dataclasses (e.g., via `@struct.dataclass` from Flax) ensures that the data structures are compatible with JAX's JIT compilation and functional programming paradigm. Additionally, dataclasses allow you to define methods specific to the data structure, providing encapsulation and clarity.
```python
import jax
from flax import struct


@struct.dataclass
class MyParams:
    some_parameter: float

    def adjust_parameter(self, factor: float):
        return self.replace(some_parameter=self.some_parameter * factor)


@struct.dataclass
class MyState:
    some_state_variable: jax.Array

    def update_state(self, delta: jax.Array):
        return self.replace(some_state_variable=self.some_state_variable + delta)


@struct.dataclass
class MyOutput:
    some_output_data: jax.Array
```
In this example, `MyParams` and `MyState` include methods to adjust parameters and update state, respectively. This encapsulation enhances code organization and readability.
### Initialization

Node data is initialized using specific methods that you should override:

- `init_params`: Initializes the node's parameters.
- `init_state`: Initializes the node's state.
- `init_output`: Provides a default output, useful for initializing input buffers in connected nodes.

These methods are typically called during the graph's initialization phase using `graph.init()`.
## The `step` Method in Detail

The `step` method defines how a node processes inputs and updates its state at each timestep. It receives a `StepState` object with all necessary information.
### `StepState` Attributes

- `rng`: Random number generator key (split it and return the updated key if you use it).
- `state`: The node's current state.
- `params`: Static parameters influencing behavior.
- `inputs`: Dictionary of `InputState` instances (keyed by input names).
- `eps`: Episode number, which refers to the computation graph currently used for simulation (unrelated to the RL episode number).
- `seq`: Current step number (auto-increments with each step).
- `ts`: Timestamp at the start of the step.
### Accessing Inputs

Each `InputState` in `step_state.inputs` contains:

- `data`: Messages from the connected node.
- `seq`: Sequence numbers of the received messages.
- `ts_sent`: Timestamps when messages were sent.
- `ts_recv`: Timestamps when messages were received.
For example, accessing the most recent message:
```python
latest_sensor_input = step_state.inputs['sensor'][-1].data
```
### Implementing the `step` Method

The typical steps to implement the `step` method can be condensed into the following block:
```python
def step(self, step_state: StepState):
    # Unpack StepState
    rng, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs

    # Access latest input
    control_signal = inputs['controller'][-1].data

    # Update state
    new_state_variable = state.some_state_variable + control_signal * params.gain
    new_state = state.replace(some_state_variable=new_state_variable)

    # Produce output
    output = MyOutput(some_output_data=new_state_variable)

    # Update RNG if randomness is involved
    rng, _ = jax.random.split(rng)

    # Return updated StepState and output
    return step_state.replace(state=new_state, rng=rng), output
```
### Working with Time and Sequence

Use `eps`, `ts`, and `seq` for time-dependent logic:
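```python
# When step is jit-compiled, step_state.ts is a traced value, so a Python `if` on it fails.
# Select between results with jnp.where (or jax.lax.cond) instead, e.g. gate an action:
is_active = step_state.ts > params.activation_time
action = jnp.where(is_active, action, jnp.zeros_like(action))
```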
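Because `step` is typically JIT-compiled, `step_state.ts` is a traced value, and using a comparison with it in a Python `if` raises an error at trace time. The runnable sketch below shows two JAX-compatible alternatives, `jnp.where` and `jax.lax.cond`; `gated_torque` and `activation_time` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def gated_torque(ts, activation_time, torque):
    # `if ts > activation_time:` would fail here, because ts is a tracer under jit
    active = ts > activation_time
    # Option 1: evaluate both alternatives and select elementwise
    via_where = jnp.where(active, torque, 0.0)
    # Option 2: structured control flow; both branches must return the same shape/dtype
    via_cond = jax.lax.cond(active, lambda t: t, lambda t: jnp.zeros_like(t), torque)
    return via_where, via_cond


print(gated_torque(jnp.float32(0.5), 1.0, jnp.array([2.0])))  # Both zero before activation
print(gated_torque(jnp.float32(1.5), 1.0, jnp.array([2.0])))  # Both pass the torque through
```

`jnp.where` is usually the simplest choice for small selections; `jax.lax.cond` fits better when the two branches perform substantially different computations.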
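### Handling Input Windows

If the input window size is greater than 1, you can access past messages. For example, retrieving the three most recent messages requires a window of at least 3:

```python
recent_sensor_data = inputs['sensor_input'][-3:].data
```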
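Rex manages input buffers for you; the sketch below is only a hypothetical stand-in (`WindowBuffer` and `push` are illustrative names) that shows the indexing convention used above, assuming the newest message sits at index `-1`. Each push drops the oldest entry, so `[-1]` is the latest message and `[-3:]` holds the three most recent.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class WindowBuffer(NamedTuple):
    """Hypothetical stand-in for an input buffer; every leaf has a leading window axis."""

    data: jax.Array  # [window, ...] messages, oldest first, newest last
    seq: jax.Array  # [window] sequence numbers


def push(buf: WindowBuffer, data: jax.Array, seq: jax.Array) -> WindowBuffer:
    # Drop the oldest entry (index 0) and append the new message at index -1
    return jax.tree_util.tree_map(
        lambda leaf, new: jnp.concatenate([leaf[1:], new[None]], axis=0),
        buf,
        WindowBuffer(data=data, seq=seq),
    )


window = 5
buf = WindowBuffer(data=jnp.zeros((window, 2)), seq=jnp.full((window,), -1))
for k in range(7):
    buf = push(buf, jnp.array([k, 10.0 * k]), jnp.array(k))

print(buf.seq)  # [2 3 4 5 6]
print(buf.seq[-3:])  # [4 5 6], the three most recent messages
```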
## JIT Compilation and Side-Effect Handling with External Callbacks

Rex advocates JIT-compiling the `step` method of each node to enhance performance. However, interfacing with real hardware often involves side effects that JAX's JIT compilation doesn't handle natively.

To include side-effecting code (e.g., sending commands to actuators, reading sensor data), you must use JAX's external callback mechanism. This involves wrapping side-effecting functions with `jax.experimental.io_callback` to ensure compatibility with JIT compilation.

Refer to the JAX documentation on external callbacks for detailed guidance.
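```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def step(self, step_state: StepState):
    # Compute outputs
    output = ...

    # Side-effecting function (runs on the host, outside the compiled computation)
    def _apply_action(action):
        # Code that interacts with hardware
        return np.array(1.0, dtype=np.float32)  # Dummy return value

    # Wrap side-effecting code. The second argument declares the shape/dtype of the
    # callback's return value, which must match what the callback actually returns.
    _ = io_callback(
        _apply_action,
        jax.ShapeDtypeStruct((), jnp.float32),
        output.some_output_data,
    )

    # Update state and return
    return step_state, output
```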
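A runnable sketch of the same pattern, assuming a recent JAX version: `_send_to_hardware` appends to a Python list as a stand-in for a real hardware call (both names are hypothetical). `ordered=True` asks JAX to keep successive callbacks in program order, which matters when streaming actuator commands.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

sent_commands = []  # Stand-in for a hardware interface (hypothetical)


def _send_to_hardware(action):
    # Runs on the host with NumPy inputs, so ordinary side effects are allowed here
    sent_commands.append(float(action[0]))
    return np.array(1.0, dtype=np.float32)  # Acknowledgement; must match the spec below


@jax.jit
def actuator_step(action):
    ack = io_callback(
        _send_to_hardware,
        jax.ShapeDtypeStruct((), jnp.float32),  # Shape/dtype of the callback's return value
        action,
        ordered=True,  # Keep successive commands in program order
    )
    return action, ack


for a in (0.5, -0.25):
    _, ack = actuator_step(jnp.array([a], dtype=jnp.float32))
    ack.block_until_ready()  # Ensure the callback has run before continuing

print(sent_commands)  # [0.5, -0.25]
```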
## Real-World Nodes and Lifecycle Methods

When nodes interface with real hardware or external systems, additional lifecycle management is necessary. The `BaseNode` API accommodates this through:

- `startup`: Called before an episode starts, allowing the node to prepare (e.g., initialize hardware).
- `stop`: Called after an episode ends, enabling the node to clean up resources or safely shut down hardware.
```python
class RealWorldNode(BaseNode):
    def __init__(
        self,
        name: str,
        rate: float,
        delay: float = None,
        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,
        advance: bool = False,
        scheduling: Scheduling = Scheduling.FREQUENCY,
        color: str = None,
        order: int = None,
    ):
        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)
        # Additional initialization if needed

    def startup(self, graph_state: GraphState, timeout: float = None):
        # Initialize hardware connections
        return True  # Return True if successful

    def stop(self, timeout: float = None):
        # Safely shut down hardware
        return True
```
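The `Actuator.stop` docstring in the examples warns that, when running asynchronously, `stop` may be called before the final `step` of an episode returns, so a late command could undo a safe shutdown. One way to guard against this, sketched below with a hypothetical `MotorDriver` wrapper that is not part of rex, is to have `stop` set a flag under a lock and have the command path refuse commands once the flag is set.

```python
import threading


class MotorDriver:
    """Hypothetical hardware wrapper that ignores commands once stopped."""

    def __init__(self):
        self._lock = threading.Lock()
        self._stopped = False
        self.torque = 0.0

    def apply(self, torque: float) -> bool:
        # Called from the step callback; refuse commands after stop()
        with self._lock:
            if self._stopped:
                return False
            self.torque = torque
            return True

    def startup(self) -> bool:
        with self._lock:
            self._stopped = False
            self.torque = 0.0
        return True

    def stop(self) -> bool:
        with self._lock:
            self._stopped = True
            self.torque = 0.0  # Put the actuator in a safe state
        return True


driver = MotorDriver()
driver.startup()
driver.apply(1.5)
driver.stop()
print(driver.apply(2.0), driver.torque)  # False 0.0: the late command is ignored
```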
## Summary
By following these guidelines, you can define robust and efficient nodes within the Rex framework. Nodes can be customized extensively through their parameters and state, connected flexibly to form complex graphs, and optimized using JIT compilation. Proper handling of side effects ensures that nodes interfacing with real-world systems remain performant and reliable.
In the following examples, we'll implement specific nodes that illustrate these concepts in practice.
```python
# @title Example: Agent
from typing import Tuple, Union

import jax
from flax import struct
from flax.core import FrozenDict
from jax import numpy as jnp

from rex import base
from rex.base import GraphState, StepState
from rex.node import BaseNode
from rex.ppo import Policy


@struct.dataclass
class AgentOutput(base.Base):
    """Agent's output"""

    action: jax.typing.ArrayLike  # Torque to apply to the pendulum


@struct.dataclass
class AgentParams(base.Base):
    # Policy
    policy: Policy
    # Observations
    num_act: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Action history length
    num_obs: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Observation history length
    # Action
    max_torque: Union[float, jax.typing.ArrayLike]
    # Initial state
    init_method: str = struct.field(pytree_node=False)  # "random", "parametrized"
    parametrized: jax.typing.ArrayLike
    max_th: Union[float, jax.typing.ArrayLike]
    max_thdot: Union[float, jax.typing.ArrayLike]
    # Train
    gamma: Union[float, jax.typing.ArrayLike]
    tmax: Union[float, jax.typing.ArrayLike]

    @staticmethod
    def process_inputs(inputs: FrozenDict[str, base.InputState]) -> jax.Array:
        th, thdot = inputs["sensor"][-1].data.th, inputs["sensor"][-1].data.thdot
        obs = jnp.array([jnp.cos(th), jnp.sin(th), thdot])
        return obs

    @staticmethod
    def get_observation(step_state: StepState) -> jax.Array:
        # Unpack StepState
        inputs, state = step_state.inputs, step_state.state
        # Convert inputs to single observation
        single_obs = AgentParams.process_inputs(inputs)
        # Concatenate with previous observations
        obs = jnp.concatenate([single_obs, state.history_obs.flatten(), state.history_act.flatten()])
        return obs

    @staticmethod
    def update_state(step_state: StepState, action: jax.Array) -> "AgentState":
        # Unpack StepState
        state, params, inputs = step_state.state, step_state.params, step_state.inputs
        # Convert inputs to observation
        single_obs = AgentParams.process_inputs(inputs)
        # Update obs history
        if params.num_obs > 0:
            history_obs = jnp.roll(state.history_obs, shift=1, axis=0)
            history_obs = history_obs.at[0].set(single_obs)
        else:
            history_obs = state.history_obs
        # Update act history
        if params.num_act > 0:
            history_act = jnp.roll(state.history_act, shift=1, axis=0)
            history_act = history_act.at[0].set(action)
        else:
            history_act = state.history_act
        new_state = state.replace(history_obs=history_obs, history_act=history_act)
        return new_state

    @staticmethod
    def to_output(action: jax.Array) -> AgentOutput:
        return AgentOutput(action=action)


@struct.dataclass
class AgentState(base.Base):
    history_act: jax.typing.ArrayLike  # History of actions
    history_obs: jax.typing.ArrayLike  # History of observations
    init_th: Union[float, jax.typing.ArrayLike]  # Initial angle of the pendulum
    init_thdot: Union[float, jax.typing.ArrayLike]  # Initial angular velocity of the pendulum


class Agent(BaseNode):
    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentParams:
        return AgentParams(
            policy=None,  # Policy must be set by the user
            num_act=4,  # Number of actions to keep in history
            num_obs=4,  # Number of observations to keep in history
            max_torque=2.0,  # Maximum torque that can be applied to the pendulum
            init_method="parametrized",  # "random" or "parametrized"
            parametrized=jnp.array([jnp.pi, 0.0]),  # [th, thdot]
            max_th=jnp.pi,  # Maximum initial angle of the pendulum
            max_thdot=9.0,  # Maximum initial angular velocity of the pendulum
            gamma=0.99,  # Discount factor (used during training)
            tmax=3.0,  # Maximum time for an episode (used during training)
        )

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentState:
        graph_state = graph_state or base.GraphState()
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        history_act = jnp.zeros((params.num_act, 1), dtype=jnp.float32)  # [torque]
        history_obs = jnp.zeros((params.num_obs, 3), dtype=jnp.float32)  # [cos(th), sin(th), thdot]

        # Set the initial state of the pendulum
        if params.init_method == "parametrized":
            init_th, init_thdot = params.parametrized
        elif params.init_method == "random":
            rng = rng if rng is not None else jax.random.PRNGKey(0)
            rngs = jax.random.split(rng, num=2)
            init_th = jax.random.uniform(rngs[0], shape=(), minval=-params.max_th, maxval=params.max_th)
            init_thdot = jax.random.uniform(rngs[1], shape=(), minval=-params.max_thdot, maxval=params.max_thdot)
        else:
            raise ValueError(f"Invalid init_method: {params.init_method}")
        return AgentState(history_act=history_act, history_obs=history_obs, init_th=init_th, init_thdot=init_thdot)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentOutput:
        """Default output of the node."""
        rng = jax.random.PRNGKey(0) if rng is None else rng
        graph_state = graph_state or base.GraphState()
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        action = jax.random.uniform(rng, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)
        return AgentOutput(action=action)

    def step(self, step_state: StepState) -> Tuple[StepState, AgentOutput]:
        """Step the node."""
        # Unpack StepState
        rng, params = step_state.rng, step_state.params

        # Prepare output
        rng, rng_net = jax.random.split(rng)
        if params.policy is not None:  # Use policy to get action
            obs = AgentParams.get_observation(step_state)
            action = params.policy.get_action(obs, rng=None)  # Supply rng for stochastic policies
        else:  # Random action if no policy is set
            action = jax.random.uniform(rng_net, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)
        output = AgentParams.to_output(action)  # Convert action to output message

        # Update step_state (observation and action history)
        new_state = params.update_state(step_state, action)  # Update state
        new_step_state = step_state.replace(rng=rng, state=new_state)  # Update step_state
        return new_step_state, output
```
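As a quick check of the observation layout produced by `AgentParams.get_observation`, the sketch below reproduces its concatenation with dummy arrays, using the defaults from `Agent.init_params` (`num_obs=4`, `num_act=4`). The policy therefore sees a vector of 3 + 4 × 3 + 4 × 1 = 19 entries. The sketch also repeats the history update from `update_state`, which keeps the newest entry at index 0.

```python
import jax.numpy as jnp

num_obs, num_act = 4, 4  # Defaults from Agent.init_params
single_obs = jnp.array([jnp.cos(jnp.pi), jnp.sin(jnp.pi), 0.0])  # [cos(th), sin(th), thdot]
history_obs = jnp.zeros((num_obs, 3))  # [num_obs, 3]
history_act = jnp.zeros((num_act, 1))  # [num_act, 1]

# Same concatenation as AgentParams.get_observation
obs = jnp.concatenate([single_obs, history_obs.flatten(), history_act.flatten()])
print(obs.shape)  # (19,)

# Same history update as AgentParams.update_state: shift by one, newest at index 0
history_obs = jnp.roll(history_obs, shift=1, axis=0).at[0].set(single_obs)
```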
```python
# @title Example: Actuator
from typing import Tuple, Union

import jax
import numpy as onp
from flax import struct
from jax import numpy as jnp

from rex import base
from rex.base import GraphState, StepState
from rex.jax_utils import tree_dynamic_slice
from rex.node import BaseNode


@struct.dataclass
class ActuatorOutput(base.Base):
    """Pendulum actuator output"""

    action: jax.typing.ArrayLike  # Torque to apply to the pendulum


@struct.dataclass
class ActuatorParams(base.Base):
    """Pendulum actuator param definition"""

    actuator_delay: Union[float, jax.typing.ArrayLike]


class Actuator(BaseNode):
    """This is a simple actuator node definition that could interface a real actuator.

    When interfacing real hardware, you would send the action to the real hardware in the .step method.
    Optionally, you could also specify a startup routine that is called right before an episode starts.
    Finally, a stop routine is called after the episode is done.
    """

    def __init__(self, *args, **kwargs):
        """No special initialization needed."""
        super().__init__(*args, **kwargs)

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorParams:
        """Default params of the node."""
        actuator_delay = 0.05
        return ActuatorParams(actuator_delay=actuator_delay)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorOutput:
        """Default output of the node."""
        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))

    def startup(self, graph_state: base.GraphState, timeout: float = None) -> bool:
        """Starts the node in the state specified by graph_state.

        This method is called right before an episode starts.
        It can be used to move (a real) robot to a starting position as specified by the graph_state.
        Not used when running in compiled mode.
        :param graph_state: The graph state.
        :param timeout: The timeout of the startup.
        :return: Whether the node has started successfully.
        """
        # Move robot to starting position specified by graph_state (e.g. graph_state.state["agent"].init_th)
        return True  # Not doing anything here

    def step(self, step_state: StepState) -> Tuple[StepState, ActuatorOutput]:
        """If we were to control a real robot, you would send the action to the robot here."""
        # Prepare output
        output = step_state.inputs["agent"][-1].data
        output = ActuatorOutput(action=output.action)

        def _apply_action(action):
            """Dummy implementation that does not actually do anything.

            Put side-effecting code here (e.g., sending the action to a real robot).
            Because the .step method may be jit-compiled, any side-effecting code must be wrapped
            in an external callback such as jax.experimental.io_callback.
            See the JAX documentation for more information on how to do this:
            https://jax.readthedocs.io/en/latest/external-callbacks.html
            """
            # print(f"Applying action: {action}")  # Apply action to the robot
            return onp.array(1.0, dtype=onp.float32)  # Must match the dtype and shape of return_shape

        # Apply action to the robot
        return_shape = jnp.array(1.0, dtype=jnp.float32)  # Shape/dtype of the callback's return value
        _ = jax.experimental.io_callback(_apply_action, return_shape, output)

        # Update state
        new_step_state = step_state
        return new_step_state, output

    def stop(self, timeout: float = None) -> bool:
        """Stopping routine that is called after the episode is done.

        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,
        which may cause unsafe behavior when the final step undoes the work of the .stop method.
        This should be handled by the user. For example, by stopping "longer" before returning here.
        Only run when running asynchronously.
        :param timeout: The timeout of the stop.
        :return: Whether the node has stopped successfully.
        """
        # Stop the robot (e.g. set the torque to 0)
        return True


class SimActuator(BaseNode):
    """This is a simple simulated actuator node definition that can either

    1. Feed through the agent's action (for normal operation, e.g., training).
       Optionally, you could include some noise or other modifications to the action.
    2. Reapply the recorded actuator outputs for system identification if available.
    """

    def __init__(self, *args, outputs: ActuatorOutput = None, **kwargs):
        """Initialize Actuator for system identification.

        Here, we will reapply the recorded actuator outputs for system identification if available.
        :param outputs: Recorded actuator Outputs to be used for system identification.
        """
        super().__init__(*args, **kwargs)
        self._outputs = outputs

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorParams:
        """Default params of the node."""
        actuator_delay = 0.05
        return ActuatorParams(actuator_delay=actuator_delay)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorOutput:
        """Default output of the node."""
        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))

    def step(self, step_state: StepState) -> Tuple[StepState, ActuatorOutput]:
        # Get action from dataset if available, else use the one provided by the agent
        if self._outputs is not None:  # Use the recorded action (for system identification)
            output = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))
            output = jax.tree_util.tree_map(lambda _o: _o[0, 0], output)
        else:  # Feed through the agent's action (for normal operation, e.g., training)
            output = step_state.inputs["agent"][-1].data
        output = ActuatorOutput(action=output.action)
        new_step_state = step_state
        return new_step_state, output
```
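`SimActuator` (and `SimSensor`) look up recorded outputs with `tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))` followed by `[0, 0]`, which suggests each recorded leaf is laid out as `[episode, step, ...]`. Under that assumption, the sketch below performs the same lookup with plain indexing, which also works with traced indices under `jax.jit`. `RecordedOutput` and `replay` are hypothetical names, and this is not the implementation of `rex.jax_utils.tree_dynamic_slice`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RecordedOutput(NamedTuple):
    action: jax.Array  # Assumed layout: [num_episodes, num_steps, action_dim]


@jax.jit
def replay(outputs: RecordedOutput, eps, seq) -> RecordedOutput:
    # Pick the message recorded at (episode eps, step seq) from every leaf
    return jax.tree_util.tree_map(lambda leaf: leaf[eps, seq], outputs)


num_eps, num_steps = 2, 3
recorded = RecordedOutput(
    action=jnp.arange(num_eps * num_steps, dtype=jnp.float32).reshape(num_eps, num_steps, 1)
)
print(replay(recorded, 1, 2).action)  # [5.]
```

Note that JAX clamps out-of-range indices in retrieval operations instead of raising an error, so indexing past the end of a recording silently returns the last recorded step.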
```python
# @title Example: Sensor
from typing import Dict, Tuple, Union

import jax
import numpy as onp
from flax import struct
from jax import numpy as jnp

from rex import base
from rex.base import GraphState, StepState
from rex.jax_utils import tree_dynamic_slice
from rex.node import BaseNode


@struct.dataclass
class SensorOutput(base.Base):
    """Output message definition of the sensor node."""

    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]


@struct.dataclass
class SensorParams(base.Base):
    """
    Other than the sensor delay, we don't have any other parameters.
    You could add more parameters here if needed, such as noise levels etc.
    """

    sensor_delay: Union[float, jax.typing.ArrayLike]


@struct.dataclass
class SensorState:
    """We use this state to record the reconstruction loss."""

    loss_th: Union[float, jax.typing.ArrayLike]
    loss_thdot: Union[float, jax.typing.ArrayLike]


class Sensor(BaseNode):
    """This is a simple sensor node definition that interfaces a real sensor.

    When interfacing real hardware, you would grab the sensor measurement in the .step method.
    Optionally, you could also specify a startup routine that is called right before an episode starts.
    Finally, a stop routine is called after the episode is done.
    """

    def __init__(self, *args, **kwargs):
        """No special initialization needed."""
        super().__init__(*args, **kwargs)

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorParams:
        """Default params of the node."""
        sensor_delay = 0.05
        return SensorParams(sensor_delay=sensor_delay)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorOutput:
        """Default output of the node."""
        # Define fixed initial sensor values
        th = jnp.pi
        thdot = 0.0
        return SensorOutput(th=th, thdot=thdot)

    def startup(self, graph_state: base.GraphState, timeout: float = None) -> bool:
        """Starts the node in the state specified by graph_state.

        This method is called right before an episode starts.
        It can be used to move (a real) robot to a starting position as specified by the graph_state.
        Not used when running in compiled mode.
        :param graph_state: The graph state.
        :param timeout: The timeout of the startup.
        :return: Whether the node has started successfully.
        """
        return True  # Not doing anything here

    def step(self, step_state: StepState) -> Tuple[StepState, SensorOutput]:
        """If we were to interface real hardware, you would grab the sensor measurement here.

        As the .step method may be jit-compiled, it is important to wrap any side-effecting code in an
        external callback (e.g., jax.experimental.io_callback).
        See the JAX documentation for more information on how to do this:
        https://jax.readthedocs.io/en/latest/external-callbacks.html
        """
        world = step_state.inputs["world"][-1].data

        def _grab_measurement():
            """Dummy implementation that does not actually do anything.

            Put side-effecting code here (e.g., grabbing a measurement from the sensor).
            """
            # print("Grabbing sensor measurement")
            sensor_msg = onp.array(1.0, dtype=onp.float32)  # Dummy sensor measurement (not actually used)
            return sensor_msg  # Must match the dtype and shape of return_shape

        # Grab sensor measurement
        return_shape = jnp.array(1.0, dtype=jnp.float32)  # Shape/dtype of the callback's return value
        _ = jax.experimental.io_callback(_grab_measurement, return_shape)

        # Prepare output
        output = SensorOutput(th=world.th, thdot=world.thdot)

        # Update state (NOOP)
        new_step_state = step_state
        return new_step_state, output

    def stop(self, timeout: float = None) -> bool:
        """Stopping routine that is called after the episode is done.

        **IMPORTANT** It may happen that stop is called *before* the final .step call of an episode returns,
        which may cause unsafe behavior when the final step undoes the work of the .stop method.
        This should be handled by the user. For example, by stopping "longer" before returning here.
        Only run when running asynchronously.
        :param timeout: The timeout of the stop.
        :return: Whether the node has stopped successfully.
        """
        return True  # Not doing anything here


class SimSensor(BaseNode):
    """This is a simple simulated sensor node definition that can either

    1. Convert the world state into a realistic sensor measurement (for normal operation, e.g., training).
       Optionally, you could include some noise or other modifications to the sensor measurement.
    2. Calculate a reconstruction loss based on the sensor measurement and the recorded sensor outputs.
       By calculating and aggregating the reconstruction loss here, we take time-scale differences and delays into account.
    """

    def __init__(self, *args, outputs: SensorOutput = None, **kwargs):
        """Initialize a simulated sensor for system identification.

        If outputs are provided, we will calculate the reconstruction loss based on the recorded sensor outputs.
        :param outputs: Recorded sensor Outputs to be used for system identification.
        """
        super().__init__(*args, **kwargs)
        self._outputs = outputs

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorParams:
        """Default params of the node."""
        sensor_delay = 0.05
        return SensorParams(sensor_delay=sensor_delay)

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorState:
        """Default state of the node."""
        return SensorState(loss_th=0.0, loss_thdot=0.0)  # Initialize reconstruction loss to zero at the start of the episode

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorOutput:
        """Default output of the node."""
        # Define fixed initial sensor values
        th = jnp.pi
        thdot = 0.0
        return SensorOutput(th=th, thdot=thdot)

    def init_delays(
        self, rng: jax.Array = None, graph_state: base.GraphState = None
    ) -> Dict[str, Union[float, jax.typing.ArrayLike]]:
        """Initialize trainable communication delays.

        **Note** These only include trainable delays that were specified while connecting the nodes.
        :param rng: Random number generator.
        :param graph_state: The graph state that may be used to get the default output.
        :return: Trainable delays (e.g., {input_name: delay}). Can be an incomplete dictionary.
                 Entries for non-trainable delays or non-existent connections are ignored.
        """
        graph_state = graph_state or GraphState()
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        delays = {"world": params.sensor_delay}
        return delays

    def step(self, step_state: StepState) -> Tuple[StepState, SensorOutput]:
        # Determine output
        data = step_state.inputs["world"][-1].data
        output = SensorOutput(th=data.th, thdot=data.thdot)

        # Calculate loss
        if self._outputs is not None:  # Calculate reconstruction loss and aggregate in state
            output_rec = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))
            output_rec = jax.tree_util.tree_map(lambda _o: _o[0, 0], output_rec)
            th_rec, thdot_rec = output_rec.th, output_rec.thdot
            state = step_state.state
            loss_th = state.loss_th + (jnp.sin(output.th) - jnp.sin(th_rec)) ** 2 + (jnp.cos(output.th) - jnp.cos(th_rec)) ** 2
            loss_thdot = state.loss_thdot + (output.thdot - thdot_rec) ** 2
            new_state = state.replace(loss_th=loss_th, loss_thdot=loss_thdot)
        else:  # NOOP
            new_state = step_state.state

        # Update step_state
        new_step_state = step_state.replace(state=new_state)
        return new_step_state, output
```
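The angle term in `SimSensor` compares `sin` and `cos` values rather than raw angles. Expanding the squares gives (sin a − sin b)² + (cos a − cos b)² = 2 − 2 cos(a − b), so the loss is zero for angles a full turn apart and peaks at 4 for opposite angles. The sketch below checks this numerically; `angle_loss` is a hypothetical helper that mirrors the expression in `SimSensor.step`.

```python
import jax.numpy as jnp


def angle_loss(th, th_rec):
    # Same angle term that SimSensor.step accumulates at every step
    return (jnp.sin(th) - jnp.sin(th_rec)) ** 2 + (jnp.cos(th) - jnp.cos(th_rec)) ** 2


a, b = 0.3, 2.0
print(angle_loss(a, b), 2 - 2 * jnp.cos(a - b))  # Equal up to float rounding
print(angle_loss(0.1, 0.1 + 2 * jnp.pi))  # Close to 0: a full turn apart is not an error
print(angle_loss(0.0, jnp.pi))  # Close to 4: opposite angles give the maximum
```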
# @title Example: ODE simulation node
from math import ceil
from typing import Dict, Tuple, Union
import jax
from flax import struct
from jax import numpy as jnp
from rex import base
from rex.base import GraphState, StepState
from rex.node import BaseWorld
@struct.dataclass
class OdeParams(base.Base):
    """Pendulum ode param definition"""
    max_speed: Union[float, jax.typing.ArrayLike]
    J: Union[float, jax.typing.ArrayLike]
    mass: Union[float, jax.typing.ArrayLike]
    length: Union[float, jax.typing.ArrayLike]
    b: Union[float, jax.typing.ArrayLike]
    K: Union[float, jax.typing.ArrayLike]
    R: Union[float, jax.typing.ArrayLike]
    c: Union[float, jax.typing.ArrayLike]
    dt_substeps_min: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)
    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)
    @property
    def substeps(self) -> int:
        substeps = ceil(self.dt / self.dt_substeps_min)
        return int(substeps)
    @property
    def dt_substeps(self) -> float:
        substeps = self.substeps
        dt_substeps = self.dt / substeps
        return dt_substeps
    def step(
        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: "OdeState", us: jax.typing.ArrayLike
    ) -> Tuple["OdeState", "OdeState"]:
        """Step the pendulum ode."""
        def _scan_fn(_x, _u):
            next_x = self._runge_kutta4(dt_substeps, _x, _u)
            # Clip velocity
            clip_thdot = jnp.clip(next_x.thdot, -self.max_speed, self.max_speed)
            next_x = next_x.replace(thdot=clip_thdot)
            return next_x, next_x
        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)
        return x_final, x_substeps
    def _runge_kutta4(self, dt: jax.typing.ArrayLike, x: "OdeState", u: jax.typing.ArrayLike) -> "OdeState":
        k1 = self._ode(x, u)
        k2 = self._ode(x + k1 * dt * 0.5, u)
        k3 = self._ode(x + k2 * dt * 0.5, u)
        k4 = self._ode(x + k3 * dt, u)
        return x + (k1 + k2 * 2 + k3 * 2 + k4) * (dt / 6)
    def _ode(self, x: "OdeState", u: jax.typing.ArrayLike) -> "OdeState":
        """dx function for the pendulum ode"""
        # Downward := [pi, 0], Upward := [0, 0]
        g, J, m, l, b, K, R, c = 9.81, self.J, self.mass, self.length, self.b, self.K, self.R, self.c  # noqa: E741
        th, thdot = x.th, x.thdot
        activation = jnp.sign(thdot)
        ddx = (u * K / R + m * g * l * jnp.sin(th) - b * thdot - thdot * K * K / R - c * activation) / J
        return OdeState(th=thdot, thdot=ddx, loss_task=0.0)  # No derivative for loss_task
@struct.dataclass
class OdeState(base.Base):
    """Pendulum state definition"""
    loss_task: Union[float, jax.typing.ArrayLike]
    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]
@struct.dataclass
class OdeOutput(base.Base):
    """World output definition"""
    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]
class OdeWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want
    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeParams:
        """Default params of the node."""
        return OdeParams(
            max_speed=40.0,  # Clip angular velocity to this value
            J=0.000159931461600856,  # 0.000159931461600856,
            mass=0.0508581731919534,  # 0.0508581731919534,
            length=0.0415233722862552,  # 0.0415233722862552,
            b=1.43298488e-05,  # 1.43298488358436e-05,
            K=0.03333912,  # 0.0333391179016334,
            R=7.73125142,  # 7.73125142447252,
            c=0.000975041213361349,  # 0.000975041213361349,
            # Backend parameters
            dt_substeps_min=1 / 100,  # Minimum substep size for ode integration
            dt=1 / self.rate,  # Time step per .step() call
        )
    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeState:
        """Default state of the node."""
        graph_state = graph_state or GraphState()
        # Try to grab state from graph_state
        state = graph_state.state.get("agent", None)
        init_th = state.init_th if state is not None else jnp.pi
        init_thdot = state.init_thdot if state is not None else 0.0
        return OdeState(th=init_th, thdot=init_thdot, loss_task=0.0)
    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeOutput:
        """Default output of the node."""
        graph_state = graph_state or GraphState()
        # Grab output from state
        world_state = graph_state.state.get(self.name, self.init_state(rng, graph_state))
        return OdeOutput(th=world_state.th, thdot=world_state.thdot)
    def init_delays(
        self, rng: jax.Array = None, graph_state: base.GraphState = None
    ) -> Dict[str, Union[float, jax.typing.ArrayLike]]:
        graph_state = graph_state or GraphState()
        # If the actuator's params define `actuator_delay`, use it as the delay of the "actuator" input
        params = graph_state.params.get("actuator")
        delays = {}
        if hasattr(params, "actuator_delay"):
            delays["actuator"] = params.actuator_delay
        return delays
    def step(self, step_state: StepState) -> Tuple[StepState, OdeOutput]:
        """Step the node."""
        # Unpack StepState
        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs
        # Apply dynamics
        u = inputs["actuator"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar
        us = jnp.array([u] * params.substeps)
        new_state = params.step(params.substeps, params.dt_substeps, state, us)[0]
        next_th, next_thdot = new_state.th, new_state.thdot
        output = OdeOutput(th=next_th, thdot=next_thdot)  # Prepare output
        # Calculate cost (penalize angle error, angular velocity and input voltage)
        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))
        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2
        # Update state
        new_state = new_state.replace(loss_task=loss_task)
        new_step_state = step_state.replace(state=new_state)
        return new_step_state, output
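`OdeParams.step` integrates the pendulum ODE with a fixed number of RK4 substeps, `ceil(dt / dt_substeps_min)`, chained together by `jax.lax.scan`. The angular velocity is clipped after each substep, and the same action is held over every substep.

The sketch below reproduces that scheme with a plain `[th, thdot]` array instead of rex's `base.Base` pytrees, so it runs without rex. A few assumptions apply:

* The parameter values are copied from `OdeWorld.init_params`.
* The function names are illustrative.
* The 20 Hz node rate (`dt = 1/20`) is assumed for the example.

```python
from math import ceil

import jax
import jax.numpy as jnp

# Parameter values copied from OdeWorld.init_params above
J = 0.000159931461600856
MASS = 0.0508581731919534
LENGTH = 0.0415233722862552
B = 1.43298488e-05
K = 0.03333912
R = 7.73125142
C = 0.000975041213361349
G = 9.81
MAX_SPEED = 40.0

def ode(x, u):
    """Time derivative of x = [th, thdot]; same right-hand side as OdeParams._ode."""
    th, thdot = x[0], x[1]
    ddx = (u * K / R + MASS * G * LENGTH * jnp.sin(th) - B * thdot - thdot * K * K / R - C * jnp.sign(thdot)) / J
    return jnp.array([thdot, ddx])

def rk4(dt, x, u):
    """One classic fourth-order Runge-Kutta step."""
    k1 = ode(x, u)
    k2 = ode(x + k1 * dt * 0.5, u)
    k3 = ode(x + k2 * dt * 0.5, u)
    k4 = ode(x + k3 * dt, u)
    return x + (k1 + 2 * k2 + 2 * k3 + k4) * (dt / 6)

def step(x, u, dt, dt_substeps_min=1 / 100):
    """Advance the state by dt using RK4 substeps of at most dt_substeps_min."""
    substeps = int(ceil(dt / dt_substeps_min))  # Python int, so the scan length is static
    dt_sub = dt / substeps  # substeps evenly divide dt

    def scan_fn(x_, u_):
        nx = rk4(dt_sub, x_, u_)
        nx = nx.at[1].set(jnp.clip(nx[1], -MAX_SPEED, MAX_SPEED))  # clip angular velocity
        return nx, nx

    us = jnp.full((substeps,), u)  # hold the same action for every substep
    x_final, xs = jax.lax.scan(scan_fn, x, us)
    return x_final, xs

x0 = jnp.array([jnp.pi, 0.0])  # hanging down at rest, OdeWorld's default initial state
x1, xs = step(x0, 0.0, dt=1 / 20)
print(xs.shape)  # (5, 2): five 0.01 s substeps for one 0.05 s step
```

Because `substeps` is derived from `dt`, which the source marks `pytree_node=False`, it stays a Python integer under `jit`. This is what lets `jax.lax.scan` use it as a fixed loop length.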
# @title Example: Brax simulation node
from math import ceil
from typing import Tuple, Union
import jax
import jax.numpy as jnp
from flax import struct
from rex import base
from rex.base import GraphState, StepState
from rex.node import BaseWorld
try:
    from brax.generalized import pipeline as gen_pipeline
    from brax.io import mjcf
    from brax.positional import pipeline as pos_pipeline
    from brax.spring import pipeline as spring_pipeline
    Systems = Union[gen_pipeline.System, spring_pipeline.System, pos_pipeline.System]
    Pipelines = Union[gen_pipeline.State, spring_pipeline.State, pos_pipeline.State]
except ModuleNotFoundError as e:
    print("Brax not installed. Install it with `pip install brax`")
    raise e
@struct.dataclass
class BraxParams(base.Base):
    max_speed: Union[float, jax.typing.ArrayLike]
    damping: Union[float, jax.typing.ArrayLike]
    armature: Union[float, jax.typing.ArrayLike]
    gear: Union[float, jax.typing.ArrayLike]
    mass_weight: Union[float, jax.typing.ArrayLike]
    radius_weight: Union[float, jax.typing.ArrayLike]
    offset: Union[float, jax.typing.ArrayLike]
    friction_loss: Union[float, jax.typing.ArrayLike]
    backend: str = struct.field(pytree_node=False)
    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)
    @property
    def substeps(self) -> int:
        dt_substeps_per_backend = {"generalized": 1 / 100, "spring": 1 / 100, "positional": 1 / 100}[self.backend]
        substeps = ceil(self.dt / dt_substeps_per_backend)
        return int(substeps)
    @property
    def dt_substeps(self) -> float:
        substeps = self.substeps
        dt_substeps = self.dt / substeps
        return dt_substeps
    @property
    def pipeline(self) -> Pipelines:
        return {"generalized": gen_pipeline, "spring": spring_pipeline, "positional": pos_pipeline}[self.backend]
    @property
    def sys(self) -> Systems:
        base_sys = mjcf.loads(DISK_PENDULUM_XML)
        # Appropriately replace parameters for the disk pendulum
        itransform = base_sys.link.inertia.transform.replace(pos=jnp.array([[0.0, self.offset, 0.0]]))
        i = base_sys.link.inertia.i.at[0, 0, 0].set(
            0.5 * self.mass_weight * self.radius_weight**2
        )  # inertia of cylinder in local frame.
        inertia = base_sys.link.inertia.replace(transform=itransform, mass=jnp.array([self.mass_weight]), i=i)
        link = base_sys.link.replace(inertia=inertia)
        actuator = base_sys.actuator.replace(gear=jnp.array([self.gear]))
        dof = base_sys.dof.replace(armature=jnp.array([self.armature]), damping=jnp.array([self.damping]))
        opt = base_sys.opt.replace(timestep=self.dt_substeps)
        new_sys = base_sys.replace(link=link, actuator=actuator, dof=dof, opt=opt)
        return new_sys
    def step(
        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: Pipelines, us: jax.typing.ArrayLike
    ) -> Tuple[Pipelines, Pipelines]:
        """Step the disk pendulum with the Brax pipeline selected by `backend`."""
        # Build the system once and set its timestep to the substep size
        base_sys = self.sys
        sys = base_sys.replace(opt=base_sys.opt.replace(timestep=dt_substeps))
        pipeline = self.pipeline  # generalized, spring, or positional, matching init_state
        def _scan_fn(_x, _u):
            # Add friction loss, using the velocity of the current substep
            thdot = _x.qd[0]
            activation = jnp.sign(thdot)
            friction = self.friction_loss * activation / sys.actuator.gear[0]
            _u_friction = _u - friction
            # Step
            next_x = pipeline.step(sys, _x, jnp.array(_u_friction)[None])
            # Clip velocity
            next_x = next_x.replace(qd=jnp.clip(next_x.qd, -self.max_speed, self.max_speed))
            return next_x, next_x
        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)
        return x_final, x_substeps
@struct.dataclass
class BraxState(base.Base):
    """Pendulum state definition"""
    loss_task: Union[float, jax.typing.ArrayLike]
    pipeline_state: Pipelines
    @property
    def th(self):
        return self.pipeline_state.q[..., 0]
    @property
    def thdot(self):
        return self.pipeline_state.qd[..., 0]
@struct.dataclass
class BraxOutput(base.Base):
    """World output definition"""
    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]
class BraxWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want
    def __init__(self, *args, backend: str = "generalized", **kwargs):
        super().__init__(*args, **kwargs)
        self.backend = backend
    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxParams:
        """Default params of the node."""
        return BraxParams(
            # Realistic parameters for the disk pendulum
            max_speed=40.0,
            damping=0.00015877,
            armature=6.4940527e-06,
            gear=0.00428677,
            mass_weight=0.05076142,
            radius_weight=0.05121992,
            offset=0.04161447,
            friction_loss=0.00097525,
            # Backend parameters
            dt=1 / self.rate,
            backend=self.backend,
        )
    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxState:
        """Default state of the node."""
        graph_state = graph_state or GraphState()
        # Try to grab state from graph_state
        state = graph_state.state.get("agent", None)
        init_th = state.init_th if state is not None else jnp.pi
        init_thdot = state.init_thdot if state is not None else 0.0
        # Set the initial state of the disk pendulum
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        sys = params.sys
        q = sys.init_q.at[0].set(init_th)
        qd = jnp.array([init_thdot])
        pipeline_state = params.pipeline.init(sys, q, qd)
        return BraxState(pipeline_state=pipeline_state, loss_task=0.0)
    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxOutput:
        """Default output of the node."""
        graph_state = graph_state or GraphState()
        # Grab output from state
        state = graph_state.state.get(self.name, self.init_state(rng, graph_state))
        return BraxOutput(th=state.pipeline_state.q[0], thdot=state.pipeline_state.qd[0])
    def step(self, step_state: StepState) -> Tuple[StepState, BraxOutput]:
        """Step the node."""
        # Unpack StepState
        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs
        # Apply dynamics
        u = inputs["actuator"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar
        us = jnp.array([u] * params.substeps)
        x = state.pipeline_state
        next_x = params.step(params.substeps, params.dt_substeps, x, us)[0]
        new_state = state.replace(pipeline_state=next_x)
        next_th, next_thdot = new_state.th, new_state.thdot
        output = BraxOutput(th=next_th, thdot=next_thdot)  # Prepare output
        # Calculate cost (penalize angle error, angular velocity and input voltage)
        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))
        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2
        # Update state
        new_state = new_state.replace(loss_task=loss_task)
        new_step_state = step_state.replace(state=new_state)
        return new_step_state, output
DISK_PENDULUM_XML = """
<mujoco model="disk_pendulum">
<compiler inertiafromgeom="auto" angle="radian" coordinate="local" eulerseq="xyz" autolimits="true"/>
<option gravity="0 0 -9.81" timestep="0.01" iterations="10"/>
<custom>
<numeric data="10" name="constraint_ang_damping"/> <!-- positional & spring -->
<numeric data="1" name="spring_inertia_scale"/> <!-- positional & spring -->
<numeric data="0" name="ang_damping"/> <!-- positional & spring -->
<numeric data="0" name="spring_mass_scale"/> <!-- positional & spring -->
<numeric data="0.5" name="joint_scale_pos"/> <!-- positional -->
<numeric data="0.1" name="joint_scale_ang"/> <!-- positional -->
<numeric data="3000" name="constraint_stiffness"/> <!-- spring -->
<numeric data="10000" name="constraint_limit_stiffness"/> <!-- spring -->
<numeric data="50" name="constraint_vel_damping"/> <!-- spring -->
<numeric data="10" name="solver_maxls"/> <!-- generalized -->
</custom>
<asset>
<texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
<material name="geom" texture="texgeom" texuniform="true"/>
</asset>
<default>
<geom contype="0" friction="1 0.1 0.1" material="geom"/>
</default>
<worldbody>
<light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
<geom name="table" type="plane" pos="0 0.0 -0.1" size="1 1 0.1" contype="8" conaffinity="11" condim="3"/>
<body name="disk" pos="0.0 0.0 0.0" euler="1.5708 0.0 0.0">
<joint name="hinge_joint" type="hinge" axis="0 0 1" range="-180 180" armature="0.00022993" damping="0.0001" limited="false"/>
<geom name="disk_geom" type="cylinder" size="0.06 0.001" contype="0" conaffinity="0" condim="3" mass="0.0"/>
<geom name="mass_geom" type="cylinder" size="0.02 0.005" contype="0" conaffinity="0" condim="3" rgba="0.04 0.04 0.04 1"
pos="0.0 0.04 0." mass="0.05085817"/>
</body>
</worldbody>
<actuator>
<motor joint="hinge_joint" ctrllimited="false" ctrlrange="-3.0 3.0" gear="0.01"/>
</actuator>
</mujoco>
"""