Defining Nodes in rex (Robotic Environments with jaX)

This notebook offers an introductory tutorial for rex (Robotic Environments with jaX), a JAX-based framework for creating graph-based environments tailored for sim2real robotics.

In this tutorial, we will guide you through the process of defining nodes, which are the fundamental building blocks for constructing graph-based simulations and real-world systems within rex. Specifically, we will demonstrate how to define the nodes used in the sim2real.ipynb notebook.

# @title Install Necessary Libraries
# @markdown This cell installs the required libraries for the project.
# @markdown If you are running this notebook in Google Colab, most libraries should already be installed.

try:
    import rex  # noqa: F401

    print("Rex already installed")
except ImportError:
    print(
        "Installing rex via `pip install rex-lib[examples]`. "
        "If you are running this in a Colab notebook, you can ignore this message."
    )
    !pip install rex-lib[examples]

Introduction to Nodes in Rex

In Rex, a node represents a fundamental computational unit within a graph-based system. Nodes encapsulate specific functionality and interact by passing data through connections, forming a network that can model complex systems. This tutorial introduces how to define nodes, specify their properties like rates and delays, and manage their interactions within a graph.

Defining Nodes

Nodes are defined by creating subclasses of the BaseNode class. This base class provides a standardized API and essential functionality that all nodes inherit. When defining a node, you can specify several parameters directly in the __init__ method:

  • name: A unique identifier for the node.
  • rate: The frequency at which the node's step method is called (in Hz).
  • delay (optional): The expected computation delay of the node (in seconds), used for phase-shifting.
  • delay_dist (optional): A distribution modeling the variability in the node's computation delay; used in simulation.
  • advance: If True, the node's step method triggers when all inputs are ready; if False, it throttles until the scheduled time.
  • scheduling: Determines how the node's execution is scheduled. Options include Scheduling.FREQUENCY and Scheduling.PHASE.
  • color: Used for visualization purposes.
  • order: Determines the node's order in visualizations.
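The arithmetic below shows how rate and delay relate. It is not rex's scheduler. The hypothetical helper nominal_schedule lists when a node running at a given rate would nominally start each step and when its output would be ready. It assumes the computation always takes exactly delay seconds.

```python
from typing import List, Tuple


def nominal_schedule(rate: float, delay: float, num_steps: int) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: nominal step start times and output-ready times."""
    period = 1.0 / rate  # e.g. 20 Hz -> 0.05 s between steps
    start_times = [k * period for k in range(num_steps)]
    # Output becomes available once the (assumed constant) computation delay has elapsed
    ready_times = [t + delay for t in start_times]
    return start_times, ready_times


start, ready = nominal_schedule(rate=20.0, delay=0.01, num_steps=3)
print(start)  # approximately [0.0, 0.05, 0.1]
print(ready)  # approximately [0.01, 0.06, 0.11]
```

Suppose the delay exceeded the period (0.05 s here). A step would then still be computing when the next one is due. In simulation, delay_dist replaces the single delay value with samples that vary from step to step.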

Here's a basic example of a node definition:

class MyNode(BaseNode):
    def __init__(
        self,
        name: str,
        rate: float,
        delay: float = None,  # Expected computation delay (used for phase-shifting)
        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,  # Sim. computation delay
        advance: bool = False,
        scheduling: Scheduling = Scheduling.FREQUENCY,
        color: str = None,
        order: int = None
    ):
        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)
        # Additional initialization if needed

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None):
        # Initialize parameters
        return MyParams()

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None):
        # Initialize state
        return MyState()

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None):
        # Initialize default output
        return MyOutput()

    def step(self, step_state: StepState):
        # Node's computation logic
        new_state = ...
        output = ...
        return step_state.replace(state=new_state), output

Connecting Nodes

Nodes interact by passing outputs from one node to the inputs of another. This is achieved through the connect method, which establishes a connection between two nodes.

Connection API

When connecting nodes, you can specify several parameters that control the nature of the connection:

  • output_node: The node whose output will be connected as an input.
  • blocking: If True, the receiving node waits for the input before proceeding. This can create dependencies between nodes.
  • delay: The expected communication delay of the connection (in seconds), used to control the phase shift between nodes.
  • delay_dist: Used in simulation to model communication delays between nodes.
  • window: Determines how many past messages are stored and accessible in the input buffer.
  • skip: If True, the connection is skipped when messages arrive simultaneously, helping resolve cyclic dependencies.
  • jitter: Controls how to handle irregularities in message timing (e.g., Jitter.LATEST uses the most recent message).
  • name: A shadow name for the input; defaults to the output node's name.

Including delay_dist in a Connection

The delay_dist parameter allows you to specify a distribution that models the variability in communication delay between nodes. This is particularly useful in simulations where network latency or message passing delays are significant.
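You can sketch the effect of a delay distribution outside rex: sample one delay per message and add it to the send time. The NumPy code below stands in for the distrax.Normal(loc=0.01, scale=0.005) used in the connection example further down. The send times are hypothetical. Clipping negative samples to zero is an assumption of this sketch, not documented rex behavior.

```python
import numpy as np

rng = np.random.default_rng(0)
ts_sent = np.arange(5) * 0.05  # hypothetical: five messages sent at 20 Hz

# Sample one communication delay per message from N(0.01, 0.005)
delays = rng.normal(loc=0.01, scale=0.005, size=ts_sent.shape)
delays = np.clip(delays, 0.0, None)  # a Normal can go negative; a physical delay cannot

ts_recv = ts_sent + delays  # when each message would reach the receiving node
print(np.round(ts_recv, 4))
```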

Resolving Cyclic Dependencies with skip

In graphs where nodes depend on each other's outputs (creating a cycle), the skip parameter can be used to resolve the dependency. By setting skip=True on a connection, you instruct the receiving node to proceed without waiting for the current message if it arrives simultaneously. This breaks the cycle and allows the system to function.
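To see why skipping breaks the cycle, treat same-timestamp messages as ordering constraints and check whether the nodes can be sorted. The sketch below runs Kahn's algorithm on a hypothetical controller/plant pair that feed each other. It illustrates the idea only and is not rex's internal scheduler.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple


def topological_order(nodes: List[str], edges: List[Tuple[str, str]]) -> Optional[List[str]]:
    """Kahn's algorithm: return an execution order, or None if the edges form a cycle."""
    indegree: Dict[str, int] = {n: 0 for n in nodes}
    children: Dict[str, List[str]] = {n: [] for n in nodes}
    for sender, receiver in edges:
        children[sender].append(receiver)
        indegree[receiver] += 1
    ready = deque(n for n in nodes if indegree[n] == 0)
    order = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for c in children[n]:
            indegree[c] -= 1
            if indegree[c] == 0:
                ready.append(c)
    return order if len(order) == len(nodes) else None


nodes = ["controller", "plant"]
# (sender, receiver, skip) for messages that arrive at the same timestamp
connections = [("controller", "plant", False), ("plant", "controller", True)]

all_edges = [(s, r) for s, r, _ in connections]
waited_edges = [(s, r) for s, r, skip in connections if not skip]

print(topological_order(nodes, all_edges))     # None: each node waits on the other
print(topological_order(nodes, waited_edges))  # ['controller', 'plant']
```

With skip=True on the plant-to-controller edge, the controller no longer waits for the plant's simultaneous message. The controller can therefore step first.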

Example Connection

node_a.connect(
    output_node=node_b,
    blocking=True,
    delay=0.01,  # Expected communication delay (used for phase-shifting)
    delay_dist=distrax.Normal(loc=0.01, scale=0.005), # Sim. communication delay
    window=5,
    skip=False,
    jitter=Jitter.LATEST,
    name="input_from_b"
)

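In this example, node_a receives node_b's output through a blocking input named "input_from_b", so node_a waits for node_b's message before stepping. The expected communication delay of 0.01 seconds is used for phase-shifting. The delay distribution models the variable communication delay in simulation. The window size is set to 5, so the five most recent messages are stored. Because skip is False, the connection is not skipped when messages arrive simultaneously.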
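As an analogy for window=5, a collections.deque with maxlen=5 keeps only the five most recent messages. Index -1 is the newest, matching the [-1] indexing used with step_state.inputs. This is plain Python for illustration; rex does not necessarily implement its input buffers this way.

```python
from collections import deque

window = deque(maxlen=5)  # older messages are dropped once 5 are stored
for seq in range(8):
    window.append({"seq": seq, "data": seq * 0.1})

latest = window[-1]  # most recent message
print([m["seq"] for m in window])  # [3, 4, 5, 6, 7]
print(latest["seq"])               # 7
```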

Node Data Structure

Nodes manage four main types of data. Each is a pytree and is typically defined with immutable dataclasses for efficiency and safety:

  1. Parameters: Static configurations that usually remain constant during execution.
  2. State: Dynamic data that evolves over time with each step.
  3. Outputs: Data produced by a node's step method and sent to connected nodes.
  4. Inputs: Buffers that hold incoming data from other nodes, respecting the specified window size.

Immutable Dataclasses

Using immutable dataclasses (e.g., via @struct.dataclass from Flax) ensures that the data structures are compatible with JAX's JIT compilation and functional programming paradigms. Additionally, dataclasses allow you to define specific methods related to the data structure, providing encapsulation and clarity.

import jax
from flax import struct

@struct.dataclass
class MyParams:
    some_parameter: float

    def adjust_parameter(self, factor: float):
        return self.replace(some_parameter=self.some_parameter * factor)

@struct.dataclass
class MyState:
    some_state_variable: jax.Array

    def update_state(self, delta: jax.Array):
        return self.replace(some_state_variable=self.some_state_variable + delta)

@struct.dataclass
class MyOutput:
    some_output_data: jax.Array

In this example, MyParams and MyState include methods to adjust parameters and update state, respectively. This encapsulation enhances code organization and readability.
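flax's struct.dataclass produces frozen dataclasses that are also registered as JAX pytrees. The stdlib-only sketch below shows just the immutability part. Updates go through a copy: dataclasses.replace here plays the role of the .replace method used above. Direct assignment raises an error.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MyParams:
    some_parameter: float

    def adjust_parameter(self, factor: float) -> "MyParams":
        # Return a modified copy instead of mutating in place
        return dataclasses.replace(self, some_parameter=self.some_parameter * factor)


p = MyParams(some_parameter=2.0)
q = p.adjust_parameter(1.5)
print(p.some_parameter, q.some_parameter)  # 2.0 3.0

try:
    p.some_parameter = 10.0  # attempted in-place mutation
    is_frozen = False
except dataclasses.FrozenInstanceError:
    is_frozen = True
print(is_frozen)  # True
```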

Initialization

Node data is initialized using specific methods that you should override:

  • init_params: Initializes the node's parameters.
  • init_state: Initializes the node's state.
  • init_output: Provides a default output, useful for initializing input buffers in connected nodes.

These methods are typically called during the graph's initialization phase using graph.init().
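In the Agent example later in this notebook, init_state looks up the node's parameters in the graph state and falls back to init_params. The stdlib sketch below reproduces that fallback pattern with a hypothetical ToyNode, using a plain dict in place of GraphState. It is not the rex API.

```python
import dataclasses
from typing import Dict, Optional


@dataclasses.dataclass(frozen=True)
class ToyParams:
    gain: float


@dataclasses.dataclass(frozen=True)
class ToyState:
    value: float


class ToyNode:
    """Hypothetical stand-in for a BaseNode subclass; not the rex API."""

    def __init__(self, name: str):
        self.name = name

    def init_params(self, graph_params: Optional[Dict[str, ToyParams]] = None) -> ToyParams:
        return ToyParams(gain=2.0)

    def init_state(self, graph_params: Optional[Dict[str, ToyParams]] = None) -> ToyState:
        graph_params = graph_params or {}
        # Use params already stored for this node if present, else fall back to defaults
        params = graph_params.get(self.name, self.init_params(graph_params))
        return ToyState(value=10.0 * params.gain)


node = ToyNode("toy")
default_state = node.init_state()
override_state = node.init_state({"toy": ToyParams(gain=0.5)})
print(default_state)   # ToyState(value=20.0)
print(override_state)  # ToyState(value=5.0)
```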

The step Method in Detail

The step method defines how a node processes inputs and updates its state at each timestep. It receives a StepState object with all necessary information.

StepState Attributes

  • rng: Random number generator key (if you use it, split it and return the new key in the updated StepState).
  • state: Node's current state.
  • params: Static parameters influencing behavior.
  • inputs: Dictionary of InputState instances (keyed by input names).
  • eps: Episode number, which identifies the computation graph currently used for simulation (unrelated to the RL episode number).
  • seq: Current step number (auto-increments with each step).
  • ts: Timestamp at the start of the step.

Accessing Inputs

Each InputState in step_state.inputs contains:

  • data: Messages from the connected node.
  • seq: Sequence numbers of the received messages.
  • ts_sent: Timestamps when messages were sent.
  • ts_recv: Timestamps when messages were received.
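Together, these timestamps let a node reason about latency. The NumPy sketch below uses made-up values for a window of three messages. It computes each message's communication delay (ts_recv minus ts_sent) and how old each message is when the step starts at ts. The variable names mirror the InputState fields, but the arrays are hypothetical.

```python
import numpy as np

# Hypothetical timestamps (seconds) for a window of 3 messages, oldest first
ts_sent = np.array([0.00, 0.05, 0.10])
ts_recv = np.array([0.012, 0.061, 0.113])
ts_step = 0.12  # start time of the current step (plays the role of step_state.ts)

comm_delay = ts_recv - ts_sent  # per-message communication delay
age = ts_step - ts_recv         # staleness of each message when the step starts
print(np.round(comm_delay, 3))  # [0.012 0.011 0.013]
print(np.round(age, 3))         # [0.108 0.059 0.007]
```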

For example, accessing the most recent message:

latest_sensor_input = step_state.inputs['sensor'][-1].data

Implementing the step Method

The typical steps to implement the step method can be condensed into the following block:

def step(self, step_state: StepState):
    # Unpack StepState
    rng, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs

    # Access latest input
    control_signal = inputs['controller'][-1].data

    # Update state
    new_state_variable = state.some_state_variable + control_signal * params.gain
    new_state = state.replace(some_state_variable=new_state_variable)

    # Produce output
    output = MyOutput(some_output_data=new_state_variable)

    # Update RNG if randomness is involved
    rng, _ = jax.random.split(rng)

    # Return updated StepState and output
    return step_state.replace(state=new_state, rng=rng), output
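The sketch below shows the pure-function pattern end to end without rex. It replaces StepState with a hypothetical three-field ToyStepState NamedTuple and JIT-compiles a step that mirrors the block above: read the input, update the state, emit an output, and advance the RNG key. rex's StepState carries more fields (inputs, eps, seq, ts) and provides its own replace method.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyStepState(NamedTuple):
    """Hypothetical stand-in for rex's StepState; NamedTuples are JAX pytrees."""
    rng: jax.Array
    state: jax.Array  # plays the role of state.some_state_variable
    gain: jax.Array   # plays the role of params.gain


@jax.jit
def toy_step(step_state: ToyStepState, control_signal: jax.Array):
    # Update state from the latest input
    new_state = step_state.state + control_signal * step_state.gain
    # Produce output
    output = new_state
    # Advance the RNG key so a later step does not reuse it
    rng, _ = jax.random.split(step_state.rng)
    # Return an updated copy of the step state together with the output
    return step_state._replace(state=new_state, rng=rng), output


s0 = ToyStepState(rng=jax.random.PRNGKey(0), state=jnp.array(1.0), gain=jnp.array(0.5))
s1, out = toy_step(s0, jnp.array(2.0))
print(float(out))  # 2.0
```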

Working with Time and Sequence

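Use eps, ts and seq for time-dependent logic. When step is JIT-compiled, these fields are traced values, so branch with jnp.where or jax.lax.cond rather than a Python if:

is_active = step_state.ts > params.activation_time
# Perform time-based logic, e.g. zero the action until activation_time is reached
action = jnp.where(is_active, action, jnp.zeros_like(action))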

Handling Input Windows

If the input window size is greater than 1, you can access past messages:

recent_sensor_data = inputs['sensor_input'][-3:].data
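Assume each input stores its window as arrays with a leading window axis, oldest first. Then slicing the last three messages amounts to slicing every leaf of the data pytree. The sketch below applies jax.tree_util.tree_map to a hypothetical window of pendulum readings held in a plain dict. It is not rex's InputState, which supports this slicing directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical window of 5 sensor messages, stored leaf-wise with a leading window axis
window_data = {
    "th": jnp.array([0.0, 0.1, 0.2, 0.3, 0.4]),
    "thdot": jnp.array([1.0, 1.1, 1.2, 1.3, 1.4]),
}

# Keep the 3 most recent entries of every leaf, analogous to inputs['sensor_input'][-3:].data
recent = jax.tree_util.tree_map(lambda x: x[-3:], window_data)
mean_th = jnp.mean(recent["th"])  # e.g. average the recent angle readings
print(recent["th"], mean_th)
```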

JIT Compilation and Handling Side Effects with External Callbacks

Rex advocates for JIT-compiling the step method of each node to enhance performance. However, interfacing with real hardware often involves side effects that JAX's JIT compilation doesn't handle natively.

To include side-effecting code (e.g., sending commands to actuators, reading sensor data), you must use JAX's external callback mechanism. This involves wrapping side-effecting functions with jax.experimental.io_callback to ensure compatibility with JIT compilation.

Refer to the JAX documentation on external callbacks for detailed guidance.

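import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

def step(self, step_state: StepState):
    # Compute outputs
    output = ...

    # Side-effecting function; runs on the host and receives concrete NumPy arrays
    def _apply_action(action):
        # Code that interacts with hardware
        return np.array(1.0, dtype=np.float32)  # Dummy return value

    # Wrap side-effecting code; the second argument declares the return value's shape and dtype
    _ = io_callback(
        _apply_action,
        jax.ShapeDtypeStruct((), jnp.float32),
        output.some_output_data,
        ordered=True,  # execute hardware calls in program order
    )

    # Update state and return
    return step_state, output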
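The self-contained sketch below JIT-compiles a toy step that hands its action to a host-side function through jax.experimental.io_callback. The sent_commands list is a hypothetical stand-in for a hardware interface. The second argument to io_callback declares the shape and dtype of the callback's return value. ordered=True asks JAX to execute the callbacks in program order.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

sent_commands = []  # hypothetical stand-in for a hardware interface


def _apply_action(action):
    # Runs on the host with a concrete NumPy value, outside the compiled program
    sent_commands.append(float(action))
    return np.array(1.0, dtype=np.float32)  # e.g. an acknowledgement flag


@jax.jit
def toy_step(action):
    ack = io_callback(_apply_action, jax.ShapeDtypeStruct((), jnp.float32), action, ordered=True)
    return action * 2.0, ack


out, ack = toy_step(jnp.float32(0.5))
jax.effects_barrier()  # wait until pending callbacks have run
print(sent_commands)  # [0.5]
```

The callback's return value must match the declared ShapeDtypeStruct. For that reason the dummy value is created as float32 rather than NumPy's default float64.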

Real-World Nodes and Lifecycle Methods

When nodes interface with real hardware or external systems, additional lifecycle management is necessary. The BaseNode API accommodates this through:

  • startup: Called before an episode starts, allowing the node to prepare (e.g., initialize hardware).
  • stop: Called after an episode ends, enabling the node to clean up resources or safely shut down hardware.

class RealWorldNode(BaseNode):
    def __init__(
        self,
        name: str,
        rate: float,
        delay: float = None,
        delay_dist: Union[DelayDistribution, distrax.Distribution] = None,
        advance: bool = False,
        scheduling: Scheduling = Scheduling.FREQUENCY,
        color: str = None,
        order: int = None
    ):
        super().__init__(name, rate, delay, delay_dist, advance, scheduling, color, order)
        # Additional initialization if needed

    def startup(self, graph_state: GraphState, timeout: float = None):
        # Initialize hardware connections
        return True  # Return True if successful

    def stop(self, timeout: float = None):
        # Safely shut down hardware
        return True
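Because stop is where hardware gets shut down, make sure it runs even when an episode fails partway. The sketch below is a hypothetical driver loop; rex manages startup and stop itself. FakeMotor and run_episode are made up for illustration, and the driver uses try/finally so stop is always called.

```python
class FakeMotor:
    """Hypothetical hardware node exposing startup/stop hooks like RealWorldNode."""

    def __init__(self):
        self.events = []

    def startup(self, graph_state=None, timeout: float = None) -> bool:
        self.events.append("startup")
        return True

    def step(self, k: int) -> None:
        self.events.append(f"step {k}")
        if k == 1:
            raise RuntimeError("sensor read failed")  # simulated hardware fault

    def stop(self, timeout: float = None) -> bool:
        self.events.append("stop")
        return True


def run_episode(node, num_steps: int) -> None:
    if not node.startup(timeout=5.0):
        raise RuntimeError("startup failed")
    try:
        for k in range(num_steps):
            node.step(k)
    finally:
        node.stop(timeout=5.0)  # always release hardware, even if a step raised


motor = FakeMotor()
try:
    run_episode(motor, num_steps=3)
except RuntimeError:
    pass
print(motor.events)  # ['startup', 'step 0', 'step 1', 'stop']
```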

Summary

By following these guidelines, you can define robust and efficient nodes within the Rex framework. Nodes can be customized extensively through their parameters and state, connected flexibly to form complex graphs, and optimized using JIT compilation. Proper handling of side effects ensures that nodes interfacing with real-world systems remain performant and reliable.

In the following examples, we'll implement specific nodes that illustrate these concepts in practice.

# @title Example: Agent

from typing import Tuple, Union

import jax
from flax import struct
from flax.core import FrozenDict
from jax import numpy as jnp

from rex import base
from rex.base import GraphState, StepState
from rex.node import BaseNode
from rex.ppo import Policy


@struct.dataclass
class AgentOutput(base.Base):
    """Agent's output"""

    action: jax.typing.ArrayLike  # Torque to apply to the pendulum


@struct.dataclass
class AgentParams(base.Base):
    # Policy
    policy: Policy
    # Observations
    num_act: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Action history length
    num_obs: Union[int, jax.typing.ArrayLike] = struct.field(pytree_node=False)  # Observation history length
    # Action
    max_torque: Union[float, jax.typing.ArrayLike]
    # Initial state
    init_method: str = struct.field(pytree_node=False)  # "random", "parametrized"
    parametrized: jax.typing.ArrayLike
    max_th: Union[float, jax.typing.ArrayLike]
    max_thdot: Union[float, jax.typing.ArrayLike]
    # Train
    gamma: Union[float, jax.typing.ArrayLike]
    tmax: Union[float, jax.typing.ArrayLike]

    @staticmethod
    def process_inputs(inputs: FrozenDict[str, base.InputState]) -> jax.Array:
        th, thdot = inputs["sensor"][-1].data.th, inputs["sensor"][-1].data.thdot
        obs = jnp.array([jnp.cos(th), jnp.sin(th), thdot])
        return obs

    @staticmethod
    def get_observation(step_state: StepState) -> jax.Array:
        # Unpack StepState
        inputs, state = step_state.inputs, step_state.state

        # Convert inputs to single observation
        single_obs = AgentParams.process_inputs(inputs)

        # Concatenate with previous observations
        obs = jnp.concatenate([single_obs, state.history_obs.flatten(), state.history_act.flatten()])
        return obs

    @staticmethod
    def update_state(step_state: StepState, action: jax.Array) -> "AgentState":
        # Unpack StepState
        state, params, inputs = step_state.state, step_state.params, step_state.inputs

        # Convert inputs to observation
        single_obs = AgentParams.process_inputs(inputs)

        # Update obs history
        if params.num_obs > 0:
            history_obs = jnp.roll(state.history_obs, shift=1, axis=0)
            history_obs = history_obs.at[0].set(single_obs)
        else:
            history_obs = state.history_obs

        # Update act history
        if params.num_act > 0:
            history_act = jnp.roll(state.history_act, shift=1, axis=0)
            history_act = history_act.at[0].set(action)
        else:
            history_act = state.history_act

        new_state = state.replace(history_obs=history_obs, history_act=history_act)
        return new_state

    @staticmethod
    def to_output(action: jax.Array) -> AgentOutput:
        return AgentOutput(action=action)


@struct.dataclass
class AgentState(base.Base):
    history_act: jax.typing.ArrayLike  # History of actions
    history_obs: jax.typing.ArrayLike  # History of observations
    init_th: Union[float, jax.typing.ArrayLike]  # Initial angle of the pendulum
    init_thdot: Union[float, jax.typing.ArrayLike]  # Initial angular velocity of the pendulum


class Agent(BaseNode):
    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentParams:
        return AgentParams(
            policy=None,  # Policy must be set by the user
            num_act=4,  # Number of actions to keep in history
            num_obs=4,  # Number of observations to keep in history
            max_torque=2.0,  # Maximum torque that can be applied to the pendulum
            init_method="parametrized",  # "random" or "parametrized"
            parametrized=jnp.array([jnp.pi, 0.0]),  # [th, thdot]
            max_th=jnp.pi,  # Maximum initial angle of the pendulum
            max_thdot=9.0,  # Maximum initial angular velocity of the pendulum
            gamma=0.99,  # Discount factor (used during training)
            tmax=3.0,  # Maximum time for an episode (used during training)
        )

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentState:
        graph_state = graph_state or base.GraphState()
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        history_act = jnp.zeros((params.num_act, 1), dtype=jnp.float32)  # [torque]
        history_obs = jnp.zeros((params.num_obs, 3), dtype=jnp.float32)  # [cos(th), sin(th), thdot]

        # Set the initial state of the pendulum
        if params.init_method == "parametrized":
            init_th, init_thdot = params.parametrized
        elif params.init_method == "random":
            rng = rng if rng is not None else jax.random.PRNGKey(0)
            rngs = jax.random.split(rng, num=2)
            init_th = jax.random.uniform(rngs[0], shape=(), minval=-params.max_th, maxval=params.max_th)
            init_thdot = jax.random.uniform(rngs[1], shape=(), minval=-params.max_thdot, maxval=params.max_thdot)
        else:
            raise ValueError(f"Invalid init_method: {params.init_method}")
        return AgentState(history_act=history_act, history_obs=history_obs, init_th=init_th, init_thdot=init_thdot)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> AgentOutput:
        """Default output of the node."""
        rng = jax.random.PRNGKey(0) if rng is None else rng
        graph_state = graph_state or base.GraphState()
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        action = jax.random.uniform(rng, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)
        return AgentOutput(action=action)

    def step(self, step_state: StepState) -> Tuple[StepState, AgentOutput]:
        """Step the node."""
        # Unpack StepState
        rng, params = step_state.rng, step_state.params

        # Prepare output
        rng, rng_net = jax.random.split(rng)
        if params.policy is not None:  # Use policy to get action
            obs = AgentParams.get_observation(step_state)
            action = params.policy.get_action(obs, rng=None)  # Pass an rng instead of None for stochastic policies
        else:  # Random action if no policy is set
            action = jax.random.uniform(rng_net, shape=(1,), minval=-params.max_torque, maxval=params.max_torque)
        output = AgentParams.to_output(action)  # Convert action to output message

        # Update step_state (observation and action history)
        new_state = params.update_state(step_state, action)  # Update state
        new_step_state = step_state.replace(rng=rng, state=new_state)  # Update step_state
        return new_step_state, output
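`Agent` keeps fixed-length histories of past observations and actions and feeds them, together with the current observation, to the policy. The standalone sketch below reproduces that buffer logic with plain JAX arrays. The helper names `push_history` and `stacked_observation` are hypothetical, and the 3-dimensional observation `[cos(th), sin(th), thdot]` is assumed from the shape used in `Agent.init_state`.

```python
import jax.numpy as jnp


def push_history(history, new_entry):
    """Shift the history by one row and store new_entry at index 0 (newest first)."""
    if history.shape[0] == 0:  # A zero-length history stores nothing
        return history
    history = jnp.roll(history, shift=1, axis=0)  # The oldest row wraps around to index 0...
    return history.at[0].set(new_entry)  # ...where it is overwritten by the newest entry


def stacked_observation(single_obs, history_obs, history_act):
    """Current observation followed by the flattened observation and action histories."""
    return jnp.concatenate([single_obs, history_obs.flatten(), history_act.flatten()])


num_obs, num_act = 4, 4  # Defaults from Agent.init_params
history_obs = jnp.zeros((num_obs, 3), dtype=jnp.float32)  # [cos(th), sin(th), thdot]
history_act = jnp.zeros((num_act, 1), dtype=jnp.float32)  # [torque]

single_obs = jnp.array([1.0, 0.0, 0.5], dtype=jnp.float32)  # Hypothetical observation
action = jnp.array([0.3], dtype=jnp.float32)  # Hypothetical action

obs = stacked_observation(single_obs, history_obs, history_act)  # Policy input at this step
history_obs = push_history(history_obs, single_obs)  # Afterwards, update the histories
history_act = push_history(history_act, action)
print(obs.shape)  # → (19,)
```

With the default `num_obs=4` and `num_act=4`, the policy input has 3 + 4*3 + 4*1 = 19 entries, so the policy's input layer has to match that size. As in `Agent.step`, the observation is built from the histories *before* they are updated with the current step.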
# @title Example: Actuator

from typing import Tuple, Union

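import jax
import jax.experimental  # noqa: F401  # provides jax.experimental.io_callback
import jax.numpy as jnp
import numpy as onp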
from flax import struct

from rex import base
from rex.base import GraphState, StepState
from rex.jax_utils import tree_dynamic_slice
from rex.node import BaseNode


@struct.dataclass
class ActuatorOutput(base.Base):
    """Pendulum actuator output"""

    action: jax.typing.ArrayLike  # Torque to apply to the pendulum


@struct.dataclass
class ActuatorParams(base.Base):
    """Pendulum actuator param definition"""

    actuator_delay: Union[float, jax.typing.ArrayLike]


class Actuator(BaseNode):
    """This is a simple actuator node definition that could interface a real actuator.

    When interfacing real hardware, you would send the action to real hardware in the .step method.
    Optionally, you could also specify a startup routine that is called right before an episode starts.
    Finally, a stop routine is called after the episode is done.
    """

    def __init__(self, *args, **kwargs):
        """No special initialization needed."""
        super().__init__(*args, **kwargs)

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorParams:
        """Default params of the node."""
        actuator_delay = 0.05
        return ActuatorParams(actuator_delay=actuator_delay)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorOutput:
        """Default output of the node."""
        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))

    def startup(self, graph_state: base.GraphState, timeout: float = None) -> bool:
        """Starts the node in the state specified by graph_state.

        This method is called right before an episode starts.
        It can be used to move (a real) robot to a starting position as specified by the graph_state.

        Not used when running in compiled mode.
        :param graph_state: The graph state.
        :param timeout: The timeout of the startup.
        :return: Whether the node has started successfully.
        """
        # Move robot to starting position specified by graph_state (e.g. graph_state.state["agent"].init_th)
        return True  # Not doing anything here

    def step(self, step_state: StepState) -> Tuple[StepState, ActuatorOutput]:
        """If we were to control a real robot, you would send the action to the robot here."""
        # Prepare output
        output = step_state.inputs["agent"][-1].data
        output = ActuatorOutput(action=output.action)

        def _apply_action(action):
            """Dummy implementation that does not actually do anything.

            Put side-effecting code here (e.g., sending the action to a real robot).
            Because the .step method may be jit-compiled, this function is called via jax.experimental.io_callback.
            See the JAX documentation for more information on external callbacks:
            https://jax.readthedocs.io/en/latest/external-callbacks.html
            """
            # print(f"Applying action: {action}") # Apply action to the robot
            return onp.array(1.0, dtype=onp.float32)  # Must match the shape and dtype declared in return_shape

        # Apply action to the robot
        return_shape = jax.ShapeDtypeStruct((), jnp.float32)  # Shape and dtype returned by _apply_action
        _ = jax.experimental.io_callback(_apply_action, return_shape, output)

        # Update state
        new_step_state = step_state
        return new_step_state, output

    def stop(self, timeout: float = None) -> bool:
        """Stopping routine that is called after the episode is done.

        **IMPORTANT** .stop may be called *before* the final .step call of an episode returns,
        which may cause unsafe behavior if that final step undoes the work of the .stop method.
        The user must handle this, for example by keeping the robot stopped for a while before returning here.

        Only runs when running asynchronously.
        :param timeout: The timeout of the stop.
        :return: Whether the node has stopped successfully.
        """
        # Stop the robot (e.g. set the torque to 0)
        return True


class SimActuator(BaseNode):
    """This is a simple simulated actuator node definition that can either
    1. Feed through the agent's action (for normal operation, e.g., training).
       Optionally, you could add noise or other modifications to the action.
    2. Replay the recorded actuator outputs for system identification, if available.
    """

    def __init__(self, *args, outputs: ActuatorOutput = None, **kwargs):
        """Initialize Actuator for system identification.

        Here, we will reapply the recorded actuator outputs for system identification if available.

        :param outputs: Recorded actuator outputs to be used for system identification.
        """
        super().__init__(*args, **kwargs)
        self._outputs = outputs

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorParams:
        """Default params of the node."""
        actuator_delay = 0.05
        return ActuatorParams(actuator_delay=actuator_delay)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> ActuatorOutput:
        """Default output of the node."""
        return ActuatorOutput(action=jnp.array([0.0], dtype=jnp.float32))

    def step(self, step_state: StepState) -> Tuple[StepState, ActuatorOutput]:
        # Get action from dataset if available, else use the one provided by the agent
        if self._outputs is not None:  # Use the recorded action (for system identification)
            output = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))
            output = jax.tree_util.tree_map(lambda _o: _o[0, 0], output)
        else:  # Feed through the agent's action (for normal operation, e.g., training)
            output = step_state.inputs["agent"][-1].data
            output = ActuatorOutput(action=output.action)
        new_step_state = step_state
        return new_step_state, output
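The Actuator and Sensor nodes route hardware access through `jax.experimental.io_callback`, because plain Python side effects inside a jit-compiled `.step` would only run once, during tracing. Below is a minimal standalone sketch of that pattern, assuming a hypothetical `_send_to_hardware` function that simply records the actions in a Python list instead of talking to a robot. It declares the callback's return value with `jax.ShapeDtypeStruct` rather than a dummy array.

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jax.experimental import io_callback

sent_actions = []  # Hypothetical stand-in for a hardware interface


def _send_to_hardware(action):
    # Runs on the host with a concrete NumPy array, also when the caller is jit-compiled
    sent_actions.append(float(action[0]))
    return onp.array(1.0, dtype=onp.float32)  # Acknowledgement; must match ACK below


ACK = jax.ShapeDtypeStruct((), jnp.float32)  # Declared shape/dtype of the callback result


@jax.jit
def actuator_step(action):
    # ordered=True asks JAX to execute the callbacks in program order
    ack = io_callback(_send_to_hardware, ACK, action, ordered=True)
    return action, ack


for u in (0.1, -0.2):
    out = actuator_step(jnp.array([u], dtype=jnp.float32))
    jax.block_until_ready(out)  # Wait until the callback has actually run

print(sent_actions)
```

The callback runs once per call of the compiled function, not once per trace, which is what a node talking to real hardware needs.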
# @title Example: Sensor

from typing import Dict, Tuple, Union

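import jax
import jax.experimental  # noqa: F401  # provides jax.experimental.io_callback
import jax.numpy as jnp
import numpy as onp
from flax import struct

from rex import base
from rex.base import GraphState, StepState
from rex.jax_utils import tree_dynamic_slice
from rex.node import BaseNode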


@struct.dataclass
class SensorOutput(base.Base):
    """Output message definition of the sensor node."""

    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]


@struct.dataclass
class SensorParams(base.Base):
    """
    Other than the sensor delay, we don't have any other parameters.
    You could add more parameters here if needed, such as noise levels etc.
    """

    sensor_delay: Union[float, jax.typing.ArrayLike]


@struct.dataclass
class SensorState:
    """We use this state to record the reconstruction loss."""

    loss_th: Union[float, jax.typing.ArrayLike]
    loss_thdot: Union[float, jax.typing.ArrayLike]


class Sensor(BaseNode):
    """This is a simple sensor node definition that interfaces a real sensor.

    When interfacing real hardware, you would grab the sensor measurement in the .step method.
    Optionally, you could also specify a startup routine that is called right before an episode starts.
    Finally, a stop routine is called after the episode is done.
    """

    def __init__(self, *args, **kwargs):
        """No special initialization needed."""
        super().__init__(*args, **kwargs)

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorParams:
        """Default params of the node."""
        sensor_delay = 0.05
        return SensorParams(sensor_delay=sensor_delay)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorOutput:
        """Default output of the node."""
        # Define fixed initial sensor values
        th = jnp.pi
        thdot = 0.0
        return SensorOutput(th=th, thdot=thdot)

    def startup(self, graph_state: base.GraphState, timeout: float = None) -> bool:
        """Starts the node in the state specified by graph_state.

        This method is called right before an episode starts.
        It can be used to move (a real) robot to a starting position as specified by the graph_state.

        Not used when running in compiled mode.
        :param graph_state: The graph state.
        :param timeout: The timeout of the startup.
        :return: Whether the node has started successfully.
        """
        return True  # Not doing anything here

    def step(self, step_state: StepState) -> Tuple[StepState, SensorOutput]:
        """If we were interfacing real hardware, we would grab the sensor measurement here.

        As the .step method may be jit-compiled, any side-effecting code must be wrapped in a host callback
        (here, jax.experimental.io_callback). See the JAX documentation for more information:
        https://jax.readthedocs.io/en/latest/external-callbacks.html
        """
        world = step_state.inputs["world"][-1].data

        def _grab_measurement():
            """Dummy implementation that does not actually do anything.

            Put side-effecting code here (e.g., grabbing a measurement from the sensor).
            Because the .step method may be jit-compiled, this function is called via jax.experimental.io_callback.
            """
            # print("Grabbing sensor measurement")
            sensor_msg = onp.array(1.0, dtype=onp.float32)  # Dummy sensor measurement (not actually used)
            return sensor_msg  # Must match the shape and dtype declared in return_shape

        # Grab sensor measurement
        return_shape = jax.ShapeDtypeStruct((), jnp.float32)  # Shape and dtype returned by _grab_measurement
        _ = jax.experimental.io_callback(_grab_measurement, return_shape)

        # Prepare output
        output = SensorOutput(th=world.th, thdot=world.thdot)

        # Update state (NOOP)
        new_step_state = step_state

        return new_step_state, output

    def stop(self, timeout: float = None) -> bool:
        """Stopping routine that is called after the episode is done.

        **IMPORTANT** .stop may be called *before* the final .step call of an episode returns,
        which may cause unsafe behavior if that final step undoes the work of the .stop method.
        The user must handle this, for example by keeping the robot stopped for a while before returning here.

        Only runs when running asynchronously.
        :param timeout: The timeout of the stop.
        :return: Whether the node has stopped successfully.
        """
        return True  # Not doing anything here


class SimSensor(BaseNode):
    """This is a simple simulated sensor node definition that can either
    1. Convert the world state into a realistic sensor measurement (for normal operation, e.g., training).
       Optionally, you could include some noise or other modifications to the sensor measurement.
    2. Calculate a reconstruction loss based on the sensor measurement and the recorded sensor outputs.

    By calculating and aggregating the reconstruction loss here, we take time-scale differences and delays into account.
    """

    def __init__(self, *args, outputs: SensorOutput = None, **kwargs):
        """Initialize a simulated sensor for system identification.

        If outputs are provided, we will calculate the reconstruction loss based on the recorded sensor outputs.

        :param outputs: Recorded sensor outputs to be used for system identification.
        """
        super().__init__(*args, **kwargs)
        self._outputs = outputs

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorParams:
        """Default params of the node."""
        sensor_delay = 0.05
        return SensorParams(sensor_delay=sensor_delay)

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorState:
        """Default state of the node."""
        return SensorState(loss_th=0.0, loss_thdot=0.0)  # Initialize reconstruction loss to zero at the start of the episode

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> SensorOutput:
        """Default output of the node."""
        # Define fixed initial sensor values
        th = jnp.pi
        thdot = 0.0
        return SensorOutput(th=th, thdot=thdot)  # Fix the initial sensor values

    def init_delays(
        self, rng: jax.Array = None, graph_state: base.GraphState = None
    ) -> Dict[str, Union[float, jax.typing.ArrayLike]]:
        """Initialize trainable communication delays.

        **Note** These only include trainable delays that were specified while connecting the nodes.

        :param rng: Random number generator.
        :param graph_state: The graph state that may be used to get the default output.
        :return: Trainable delays (e.g., {input_name: delay}). Can be an incomplete dictionary.
                 Entries for non-trainable delays or non-existent connections are ignored.
        """
        graph_state = graph_state or GraphState()
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        delays = {"world": params.sensor_delay}
        return delays

    def step(self, step_state: StepState) -> Tuple[StepState, SensorOutput]:
        # Determine output
        data = step_state.inputs["world"][-1].data
        output = SensorOutput(th=data.th, thdot=data.thdot)

        # Calculate loss
        if self._outputs is not None:  # Calculate reconstruction loss and aggregate in state
            output_rec = tree_dynamic_slice(self._outputs, jnp.array([step_state.eps, step_state.seq]))
            output_rec = jax.tree_util.tree_map(lambda _o: _o[0, 0], output_rec)
            th_rec, thdot_rec = output_rec.th, output_rec.thdot
            state = step_state.state
            loss_th = state.loss_th + (jnp.sin(output.th) - jnp.sin(th_rec)) ** 2 + (jnp.cos(output.th) - jnp.cos(th_rec)) ** 2
            loss_thdot = state.loss_thdot + (output.thdot - thdot_rec) ** 2
            new_state = state.replace(loss_th=loss_th, loss_thdot=loss_thdot)
        else:  # NOOP
            new_state = step_state.state

        # Update step_state
        new_step_state = step_state.replace(state=new_state)
        return new_step_state, output
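`SimSensor` compares each simulated measurement with the recorded one at the same episode and sequence index and accumulates the error in its state. A minimal sketch of that bookkeeping follows. It uses plain JAX indexing in place of rex's `tree_dynamic_slice`, and the recording format (a dict of `[num_episodes, num_steps]` arrays) is a hypothetical stand-in for the recorded `SensorOutput`. The angle term has the same sin/cos form as in `SimSensor.step`, which equals 2 - 2cos(th - th_rec) and is therefore zero for angles that differ by a multiple of 2π.

```python
import jax
import jax.numpy as jnp


def angle_error(th, th_rec):
    # Squared chord distance on the unit circle: equals 2 - 2 * cos(th - th_rec)
    return (jnp.sin(th) - jnp.sin(th_rec)) ** 2 + (jnp.cos(th) - jnp.cos(th_rec)) ** 2


@jax.jit
def reconstruction_step(loss_th, loss_thdot, th, thdot, recorded, eps, seq):
    # Select the recorded sample for (episode, step); traced integer indices work under jit
    rec = jax.tree_util.tree_map(lambda a: a[eps, seq], recorded)
    loss_th = loss_th + angle_error(th, rec["th"])
    loss_thdot = loss_thdot + (thdot - rec["thdot"]) ** 2
    return loss_th, loss_thdot


# Hypothetical recording: 2 episodes x 3 steps
recorded = {
    "th": jnp.array([[0.0, 0.1, 0.2], [jnp.pi, 3.0, 2.9]]),
    "thdot": jnp.array([[0.0, 1.0, 2.0], [0.0, -1.0, -2.0]]),
}

# Simulated measurements for episode 0; the second angle is off by one full turn
simulated = [(0.0, 0.0), (0.1 + 2 * jnp.pi, 1.5), (0.2, 2.0)]

loss_th, loss_thdot = 0.0, 0.0
for seq, (th, thdot) in enumerate(simulated):
    loss_th, loss_thdot = reconstruction_step(loss_th, loss_thdot, th, thdot, recorded, 0, seq)

print(float(loss_th), float(loss_thdot))
```

The full-turn offset contributes (numerically) nothing to `loss_th`, while a plain squared angle difference would have penalized it by about (2π)².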
# @title Example: ODE simulation node

from math import ceil
from typing import Dict, Tuple, Union

import jax
import jax.numpy as jnp
from flax import struct

from rex import base
from rex.base import GraphState, StepState
from rex.node import BaseWorld


@struct.dataclass
class OdeParams(base.Base):
    """Pendulum ode param definition"""

    max_speed: Union[float, jax.typing.ArrayLike]
    J: Union[float, jax.typing.ArrayLike]
    mass: Union[float, jax.typing.ArrayLike]
    length: Union[float, jax.typing.ArrayLike]
    b: Union[float, jax.typing.ArrayLike]
    K: Union[float, jax.typing.ArrayLike]
    R: Union[float, jax.typing.ArrayLike]
    c: Union[float, jax.typing.ArrayLike]
    dt_substeps_min: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)
    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)

    @property
    def substeps(self) -> int:
        substeps = ceil(self.dt / self.dt_substeps_min)
        return int(substeps)

    @property
    def dt_substeps(self) -> float:
        substeps = self.substeps
        dt_substeps = self.dt / substeps
        return dt_substeps

    def step(
        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: "OdeState", us: jax.typing.ArrayLike
    ) -> Tuple["OdeState", "OdeState"]:
        """Step the pendulum ode."""

        def _scan_fn(_x, _u):
            next_x = self._runge_kutta4(dt_substeps, _x, _u)
            # Clip velocity
            clip_thdot = jnp.clip(next_x.thdot, -self.max_speed, self.max_speed)
            next_x = next_x.replace(thdot=clip_thdot)
            return next_x, next_x

        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)
        return x_final, x_substeps

    def _runge_kutta4(self, dt: jax.typing.ArrayLike, x: "OdeState", u: jax.typing.ArrayLike) -> "OdeState":
        k1 = self._ode(x, u)
        k2 = self._ode(x + k1 * dt * 0.5, u)
        k3 = self._ode(x + k2 * dt * 0.5, u)
        k4 = self._ode(x + k3 * dt, u)
        return x + (k1 + k2 * 2 + k3 * 2 + k4) * (dt / 6)

    def _ode(self, x: "OdeState", u: jax.typing.ArrayLike) -> "OdeState":
        """dx function for the pendulum ode"""
        # Downward := [pi, 0], Upward := [0, 0]
        g, J, m, l, b, K, R, c = 9.81, self.J, self.mass, self.length, self.b, self.K, self.R, self.c  # noqa: E741
        th, thdot = x.th, x.thdot
        activation = jnp.sign(thdot)
        ddx = (u * K / R + m * g * l * jnp.sin(th) - b * thdot - thdot * K * K / R - c * activation) / J
        return OdeState(th=thdot, thdot=ddx, loss_task=0.0)  # No derivative for loss_task


@struct.dataclass
class OdeState(base.Base):
    """Pendulum state definition"""

    loss_task: Union[float, jax.typing.ArrayLike]
    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]


@struct.dataclass
class OdeOutput(base.Base):
    """World output definition"""

    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]


class OdeWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want
    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeParams:
        """Default params of the node."""
        return OdeParams(
            max_speed=40.0,  # Clip angular velocity to this value
            J=0.000159931461600856,  # 0.000159931461600856,
            mass=0.0508581731919534,  # 0.0508581731919534,
            length=0.0415233722862552,  # 0.0415233722862552,
            b=1.43298488e-05,  # 1.43298488358436e-05,
            K=0.03333912,  # 0.0333391179016334,
            R=7.73125142,  # 7.73125142447252,
            c=0.000975041213361349,  # 0.000975041213361349,
            # Backend parameters
            dt_substeps_min=1 / 100,  # Minimum substep size for ode integration
            dt=1 / self.rate,  # Time step per .step() call
        )

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeState:
        """Default state of the node."""
        graph_state = graph_state or GraphState()

        # Try to grab state from graph_state
        state = graph_state.state.get("agent", None)
        init_th = state.init_th if state is not None else jnp.pi
        init_thdot = state.init_thdot if state is not None else 0.0
        return OdeState(th=init_th, thdot=init_thdot, loss_task=0.0)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> OdeOutput:
        """Default output of the node."""
        graph_state = graph_state or GraphState()
        # Grab output from state
        world_state = graph_state.state.get(self.name, self.init_state(rng, graph_state))
        return OdeOutput(th=world_state.th, thdot=world_state.thdot)

    def init_delays(
        self, rng: jax.Array = None, graph_state: base.GraphState = None
    ) -> Dict[str, Union[float, jax.typing.ArrayLike]]:
        graph_state = graph_state or GraphState()
        params = graph_state.params.get("actuator")
        delays = {}
        if hasattr(params, "actuator_delay"):
            delays["actuator"] = params.actuator_delay
        return delays

    def step(self, step_state: StepState) -> Tuple[StepState, OdeOutput]:
        """Step the node."""
        # Unpack StepState
        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs

        # Apply dynamics
        u = inputs["actuator"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar
        us = jnp.array([u] * params.substeps)
        new_state = params.step(params.substeps, params.dt_substeps, state, us)[0]
        next_th, next_thdot = new_state.th, new_state.thdot
        output = OdeOutput(th=next_th, thdot=next_thdot)  # Prepare output

        # Calculate cost (penalize angle error, angular velocity and input voltage)
        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))
        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2

        # Update state
        new_state = new_state.replace(loss_task=loss_task)
        new_step_state = step_state.replace(state=new_state)
        return new_step_state, output
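`OdeParams.step` integrates the pendulum ODE with fixed-step RK4: one `.step()` call of length `dt = 1 / rate` is split into `ceil(dt / dt_substeps_min)` equal substeps, the input is held constant over them, and the velocity is clipped after every substep. The sketch below reproduces this with a plain `[th, thdot]` array instead of `OdeState`. The parameter values are copied from `OdeWorld.init_params`, and the 30 Hz rate in the usage line is an arbitrary choice.

```python
from math import ceil

import jax
import jax.numpy as jnp

# Physical parameters copied from OdeWorld.init_params
J, M, L = 0.000159931461600856, 0.0508581731919534, 0.0415233722862552
B, K, R, C = 1.43298488e-05, 0.03333912, 7.73125142, 0.000975041213361349
G = 9.81


def pendulum_ode(x, u):
    """Time derivative of x = [th, thdot]; downward is [pi, 0], upward is [0, 0]."""
    th, thdot = x[0], x[1]
    thddot = (u * K / R + M * G * L * jnp.sin(th) - B * thdot - thdot * K * K / R - C * jnp.sign(thdot)) / J
    return jnp.stack([thdot, thddot])


def rk4_step(x, u, h):
    """One classical fourth-order Runge-Kutta step of size h with constant input u."""
    k1 = pendulum_ode(x, u)
    k2 = pendulum_ode(x + 0.5 * h * k1, u)
    k3 = pendulum_ode(x + 0.5 * h * k2, u)
    k4 = pendulum_ode(x + h * k3, u)
    return x + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def integrate(x, u, dt, dt_substeps_min=1 / 100, max_speed=40.0):
    """Advance x by dt using ceil(dt / dt_substeps_min) equal RK4 substeps."""
    substeps = ceil(dt / dt_substeps_min)  # Python int, so the scan length is static
    h = dt / substeps

    def _scan_fn(_x, _u):
        next_x = rk4_step(_x, _u, h)
        next_x = next_x.at[1].set(jnp.clip(next_x[1], -max_speed, max_speed))  # Clip velocity
        return next_x, next_x

    us = jnp.full((substeps,), u)  # Hold the input constant over all substeps
    return jax.lax.scan(_scan_fn, x, us)  # (final state, state after every substep)


def wrap_angle(th):
    """Map th to [-pi, pi), as done before computing the task loss."""
    return th - 2 * jnp.pi * jnp.floor((th + jnp.pi) / (2 * jnp.pi))


x0 = jnp.array([jnp.pi, 0.0])  # Hanging downward at rest
x_final, x_substeps = integrate(x0, u=0.0, dt=1 / 30)
print(x_substeps.shape)  # → (4, 2)
```

At 30 Hz, `dt = 1/30` s is split into ceil(3.33) = 4 substeps of 1/120 s each, so the actual substep is never longer than `dt_substeps_min`. Because `substeps` must be a Python integer for `jax.lax.scan`, `dt` and `dt_substeps_min` are marked as static (`pytree_node=False`) fields in `OdeParams`.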
# @title Example: Brax simulation node
from math import ceil
from typing import Tuple, Union

import jax
import jax.numpy as jnp
from flax import struct

from rex import base
from rex.base import GraphState, StepState
from rex.node import BaseWorld


try:
    from brax.generalized import pipeline as gen_pipeline
    from brax.io import mjcf
    from brax.positional import pipeline as pos_pipeline
    from brax.spring import pipeline as spring_pipeline

    Systems = Union[gen_pipeline.System, spring_pipeline.System, pos_pipeline.System]
    Pipelines = Union[gen_pipeline.State, spring_pipeline.State, pos_pipeline.State]
except ModuleNotFoundError as e:
    print("Brax not installed. Install it with `pip install brax`")
    raise e


@struct.dataclass
class BraxParams(base.Base):
    max_speed: Union[float, jax.typing.ArrayLike]
    damping: Union[float, jax.typing.ArrayLike]
    armature: Union[float, jax.typing.ArrayLike]
    gear: Union[float, jax.typing.ArrayLike]
    mass_weight: Union[float, jax.typing.ArrayLike]
    radius_weight: Union[float, jax.typing.ArrayLike]
    offset: Union[float, jax.typing.ArrayLike]
    friction_loss: Union[float, jax.typing.ArrayLike]
    backend: str = struct.field(pytree_node=False)
    dt: Union[float, jax.typing.ArrayLike] = struct.field(pytree_node=False)

    @property
    def substeps(self) -> int:
        dt_substeps_per_backend = {"generalized": 1 / 100, "spring": 1 / 100, "positional": 1 / 100}[self.backend]
        substeps = ceil(self.dt / dt_substeps_per_backend)
        return int(substeps)

    @property
    def dt_substeps(self) -> float:
        substeps = self.substeps
        dt_substeps = self.dt / substeps
        return dt_substeps

    @property
    def pipeline(self) -> Pipelines:
        return {"generalized": gen_pipeline, "spring": spring_pipeline, "positional": pos_pipeline}[self.backend]

    @property
    def sys(self) -> Systems:
        base_sys = mjcf.loads(DISK_PENDULUM_XML)
        # Appropriately replace parameters for the disk pendulum
        itransform = base_sys.link.inertia.transform.replace(pos=jnp.array([[0.0, self.offset, 0.0]]))
        i = base_sys.link.inertia.i.at[0, 0, 0].set(
            0.5 * self.mass_weight * self.radius_weight**2
        )  # inertia of cylinder in local frame.
        inertia = base_sys.link.inertia.replace(transform=itransform, mass=jnp.array([self.mass_weight]), i=i)
        link = base_sys.link.replace(inertia=inertia)
        actuator = base_sys.actuator.replace(gear=jnp.array([self.gear]))
        dof = base_sys.dof.replace(armature=jnp.array([self.armature]), damping=jnp.array([self.damping]))
        opt = base_sys.opt.replace(timestep=self.dt_substeps)
        new_sys = base_sys.replace(link=link, actuator=actuator, dof=dof, opt=opt)
        return new_sys

    def step(
        self, substeps: int, dt_substeps: jax.typing.ArrayLike, x: Pipelines, us: jax.typing.ArrayLike
    ) -> Tuple[Pipelines, Pipelines]:
        """Step the Brax pipeline for `substeps` substeps with the inputs `us`."""
        # Appropriately replace timestep for the disk pendulum
        base_sys = self.sys  # Build the system from the XML only once
        sys = base_sys.replace(opt=base_sys.opt.replace(timestep=dt_substeps))

        def _scan_fn(_x, _u):
            # Add friction loss (based on the velocity of the current substep state)
            thdot = _x.qd[0]
            activation = jnp.sign(thdot)
            friction = self.friction_loss * activation / sys.actuator.gear[0]
            _u_friction = _u - friction
            # Step with the pipeline that matches the selected backend
            next_x = self.pipeline.step(sys, _x, jnp.array(_u_friction)[None])
            # Clip velocity
            next_x = next_x.replace(qd=jnp.clip(next_x.qd, -self.max_speed, self.max_speed))
            return next_x, next_x

        x_final, x_substeps = jax.lax.scan(_scan_fn, x, us, length=substeps)
        return x_final, x_substeps


@struct.dataclass
class BraxState(base.Base):
    """Pendulum state definition"""

    loss_task: Union[float, jax.typing.ArrayLike]
    pipeline_state: Pipelines

    @property
    def th(self):
        return self.pipeline_state.q[..., 0]

    @property
    def thdot(self):
        return self.pipeline_state.qd[..., 0]


@struct.dataclass
class BraxOutput(base.Base):
    """World output definition"""

    th: Union[float, jax.typing.ArrayLike]
    thdot: Union[float, jax.typing.ArrayLike]


class BraxWorld(BaseWorld):  # We inherit from BaseWorld for convenience, but you can inherit from BaseNode if you want
    def __init__(self, *args, backend: str = "generalized", **kwargs):
        super().__init__(*args, **kwargs)
        self.backend = backend

    def init_params(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxParams:
        """Default params of the node."""
        return BraxParams(
            # Realistic parameters for the disk pendulum
            max_speed=40.0,
            damping=0.00015877,
            armature=6.4940527e-06,
            gear=0.00428677,
            mass_weight=0.05076142,
            radius_weight=0.05121992,
            offset=0.04161447,
            friction_loss=0.00097525,
            # Backend parameters
            dt=1 / self.rate,
            backend=self.backend,
        )

    def init_state(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxState:
        """Default state of the node."""
        graph_state = graph_state or GraphState()

        # Try to grab state from graph_state
        state = graph_state.state.get("agent", None)
        init_th = state.init_th if state is not None else jnp.pi
        init_thdot = state.init_thdot if state is not None else 0.0

        # Set the initial state of the disk pendulum
        params = graph_state.params.get(self.name, self.init_params(rng, graph_state))
        sys = params.sys
        q = sys.init_q.at[0].set(init_th)
        qd = jnp.array([init_thdot])
        pipeline_state = params.pipeline.init(sys, q, qd)
        return BraxState(pipeline_state=pipeline_state, loss_task=0.0)

    def init_output(self, rng: jax.Array = None, graph_state: GraphState = None) -> BraxOutput:
        """Default output of the node."""
        graph_state = graph_state or GraphState()
        # Grab output from state
        state = graph_state.state.get(self.name, self.init_state(rng, graph_state))
        return BraxOutput(th=state.pipeline_state.q[0], thdot=state.pipeline_state.qd[0])

    def step(self, step_state: StepState) -> Tuple[StepState, BraxOutput]:
        """Step the node."""

        # Unpack StepState
        _, state, params, inputs = step_state.rng, step_state.state, step_state.params, step_state.inputs

        # Apply dynamics
        u = inputs["actuator"].data.action[-1][0]  # [-1] to get the latest action, [0] reduces the dimension to scalar
        us = jnp.array([u] * params.substeps)
        x = state.pipeline_state
        next_x = params.step(params.substeps, params.dt_substeps, x, us)[0]
        new_state = state.replace(pipeline_state=next_x)
        next_th, next_thdot = new_state.th, new_state.thdot
        output = BraxOutput(th=next_th, thdot=next_thdot)  # Prepare output

        # Calculate cost (penalize angle error, angular velocity and input voltage)
        norm_next_th = next_th - 2 * jnp.pi * jnp.floor((next_th + jnp.pi) / (2 * jnp.pi))
        loss_task = state.loss_task + norm_next_th**2 + 0.1 * (next_thdot / (1 + 10 * abs(norm_next_th))) ** 2 + 0.01 * u**2

        # Update state
        new_state = new_state.replace(loss_task=loss_task)
        new_step_state = step_state.replace(state=new_state)
        return new_step_state, output


DISK_PENDULUM_XML = """
<mujoco model="disk_pendulum">
    <compiler inertiafromgeom="auto" angle="radian" coordinate="local" eulerseq="xyz" autolimits="true"/>
    <option gravity="0 0 -9.81" timestep="0.01" iterations="10"/>
    <custom>
        <numeric data="10" name="constraint_ang_damping"/> <!-- positional & spring -->
        <numeric data="1" name="spring_inertia_scale"/>  <!-- positional & spring -->
        <numeric data="0" name="ang_damping"/>  <!-- positional & spring -->
        <numeric data="0" name="spring_mass_scale"/>  <!-- positional & spring -->
        <numeric data="0.5" name="joint_scale_pos"/> <!-- positional -->
        <numeric data="0.1" name="joint_scale_ang"/> <!-- positional -->
        <numeric data="3000" name="constraint_stiffness"/>  <!-- spring -->
        <numeric data="10000" name="constraint_limit_stiffness"/>  <!-- spring -->
        <numeric data="50" name="constraint_vel_damping"/>  <!-- spring -->
        <numeric data="10" name="solver_maxls"/>  <!-- generalized -->
    </custom>

    <asset>
        <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
        <material name="geom" texture="texgeom" texuniform="true"/>
    </asset>

    <default>
        <geom contype="0" friction="1 0.1 0.1" material="geom"/>
    </default>

    <worldbody>
        <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
        <geom name="table" type="plane" pos="0 0.0 -0.1" size="1 1 0.1" contype="8" conaffinity="11" condim="3"/>
        <body name="disk" pos="0.0 0.0 0.0" euler="1.5708 0.0 0.0">
            <joint name="hinge_joint" type="hinge" axis="0 0 1" range="-180 180" armature="0.00022993" damping="0.0001" limited="false"/>
            <geom name="disk_geom" type="cylinder" size="0.06 0.001" contype="0" conaffinity="0" condim="3" mass="0.0"/>
            <geom name="mass_geom" type="cylinder" size="0.02 0.005" contype="0" conaffinity="0"  condim="3" rgba="0.04 0.04 0.04 1"
                  pos="0.0 0.04 0." mass="0.05085817"/>
        </body>
    </worldbody>

    <actuator>
        <motor joint="hinge_joint" ctrllimited="false" ctrlrange="-3.0 3.0"  gear="0.01"/>
    </actuator>
</mujoco>
"""
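The MJCF string above does not specify joint friction; `BraxParams.step` instead subtracts Coulomb friction from the control command. Assuming the motor maps control to joint torque as `gear * ctrl` (MuJoCo's motor convention with unit gain), dividing the friction torque by the gear ratio converts it to control units, so the net joint torque becomes `gear * u - friction_loss * sign(thdot)`. The sketch below checks that arithmetic with plain JAX, together with the `0.5 * m * r**2` formula that `BraxParams.sys` writes into the link inertia. The helper names are hypothetical.

```python
import jax.numpy as jnp

# Values copied from BraxWorld.init_params
GEAR = 0.00428677
FRICTION_LOSS = 0.00097525
MASS_WEIGHT, RADIUS_WEIGHT = 0.05076142, 0.05121992


def control_with_friction(u, thdot, friction_loss=FRICTION_LOSS, gear=GEAR):
    """Subtract Coulomb friction, converted from torque to control units, from the command u."""
    return u - friction_loss * jnp.sign(thdot) / gear


def joint_torque(ctrl, gear=GEAR):
    """Assumed motor model: joint torque = gear * ctrl (unit gain, no bias)."""
    return gear * ctrl


def cylinder_axial_inertia(mass, radius):
    """Moment of inertia of a solid cylinder about its symmetry axis."""
    return 0.5 * mass * radius**2


# Net torque while rotating in the positive direction: gear * u - friction_loss
tau = joint_torque(control_with_friction(1.0, 2.0))
print(float(tau), GEAR * 1.0 - FRICTION_LOSS)
print(cylinder_axial_inertia(MASS_WEIGHT, RADIUS_WEIGHT))
```

Expressing friction in control units keeps `friction_loss` a physical torque parameter that can be identified independently of the gear ratio.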