
Defining Graphs and Environments in rex (Robotic Environments with jaX)


This notebook offers an introductory tutorial for rex (Robotic Environments with jaX), a JAX-based framework for creating graph-based environments tailored for sim2real robotics.

In this tutorial, we will guide you through the process of defining graphs and environments. Specifically, we will demonstrate how to define the nodes and the training environment used in the sim2real.ipynb notebook.

# @title Install Necessary Libraries
# @markdown This cell installs the required libraries for the project.
# @markdown If you are running this notebook in Google Colab, most libraries should already be installed.

try:
    import rex

    print("Rex already installed")
except ImportError:
    print(
        "Installing rex via `pip install rex-lib[examples]`. "
        "If you are running this in a Colab notebook, you can ignore this message."
    )
    !pip install rex-lib[examples]
    import rex

# Check if we have a GPU
import itertools

import jax


try:
    gpu = jax.devices("gpu")
    gpu = gpu[0] if len(gpu) > 0 else None
    print("GPU found!")
except RuntimeError:
    print("Warning: No GPU found, falling back to CPU. Speedups will be less pronounced.")
    print(
        "Hint: if you are using Google Colab, try to change the runtime to GPU: "
        "Runtime -> Change runtime type -> Hardware accelerator -> GPU."
    )
    gpu = None

# Check the number of available CPU cores
print(f"CPU cores available: {len(jax.devices('cpu'))}")
cpus = itertools.cycle(jax.devices("cpu"))

# Set plot settings
import seaborn as sns


sns.set()
Installing rex via `pip install rex-lib[examples]`. If you are running this in a Colab notebook, you can ignore this message.
Successfully installed brax-0.11.0 dill-0.3.9 distrax-0.1.5 dm-env-1.6 dotmap-1.3.30 equinox-0.11.7 evosax-0.1.6 farama-notifications-0.0.4 flask-cors-5.0.0 glfw-2.7.0 gymnasium-1.0.0 jaxopt-0.8.3 jaxtyping-0.2.34 ml-collections-0.1.1 mujoco-3.2.3 mujoco-mjx-3.2.3 pytinyrenderer-0.0.14 rex-lib-0.0.5 seaborn-0.13.2 supergraph-0.0.8 tensorboardX-2.6.2.2 trimesh-4.4.9 typeguard-2.13.3
GPU found!
CPU cores available: 1

Introduction to Graphs and Environments in Rex¤

In Rex, a graph represents the interconnected structure of nodes, defining how data flows and computations are organized within a system. By assembling nodes into a graph, you can model complex systems that reflect real-world interactions or simulations. This section introduces how to define a graph using a set of nodes, interact with it using various APIs, understand the role of the supervisor node, and specify environments that interact with the graph.

# @title Defining a Graph from Nodes
# @markdown First, you need to define the nodes that will make up your graph.
# @markdown These nodes represent different components of a system, such as sensors, agents, actuators, and the world.
# @markdown **Note**: The `delay_dist` parameter is used to simulate computation delays, which is useful when modeling real-world systems.

# Import necessary modules and node classes
from distrax import Normal

import rex.examples.pendulum as pdm


# Instantiate nodes with their respective parameters
sensor = pdm.SimSensor(name="sensor", rate=50, color="pink", order=1, delay_dist=Normal(loc=0.0075, scale=0.003))
agent = pdm.Agent(
    name="agent", rate=50, color="teal", order=3, delay_dist=Normal(loc=0.01, scale=0.003)
)  # Computation delay of the agent
actuator = pdm.SimActuator(
    name="actuator", rate=50, color="orange", order=2, delay_dist=Normal(loc=0.0075, scale=0.003)
)  # Computation delay of the actuator
world = pdm.OdeWorld(name="world", rate=50, color="grape", order=0)  # Brax world that simulates the pendulum
nodes = dict(world=world, sensor=sensor, agent=agent, actuator=actuator)
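The `delay_dist` arguments above are distributions over delays in seconds. As a rough, library-free illustration (using Python's `random` module rather than distrax), `Normal(loc=0.0075, scale=0.003)` models delays averaging 7.5 ms with a 3 ms standard deviation:

```python
import random

# Library-free stand-in for Normal(loc=0.0075, scale=0.003): delays in seconds.
random.seed(0)  # for reproducibility
delays = [random.gauss(mu=0.0075, sigma=0.003) for _ in range(10_000)]
mean_delay = sum(delays) / len(delays)
print(f"mean delay of {mean_delay * 1e3:.1f} ms")
```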
# @title Connecting Nodes
# @markdown Now, we establish connections between the nodes using the `connect` method.
# @markdown - **`window`**: Determines how many past messages are stored and accessible in the input buffer.
# @markdown - **`blocking`**: If `True`, the receiving node waits for the input before proceeding.
# @markdown - **`skip`**: Used to resolve cyclic dependencies by skipping the connection when messages arrive simultaneously.

# Agent receives data from the sensor
agent.connect(
    output_node=sensor,
    window=3,  # Use the last three sensor messages
    name="sensor",  # Input name in the agent
    blocking=True,  # Wait for the sensor data before proceeding
    delay_dist=Normal(loc=0.002, scale=0.002),
)

# Actuator receives commands from the agent
actuator.connect(
    output_node=agent,
    window=1,  # Use the most recent action
    name="agent",
    blocking=True,
    delay_dist=Normal(loc=0.002, scale=0.002),
)

# World receives actions from the actuator
world.connect(
    output_node=actuator,
    window=1,
    name="actuator",
    # Resolve cyclic dependency world->sensor->agent->actuator->world
    skip=True,
    blocking=False,  # Non-blocking connection (i.e. world does not wait for actuator)
    delay_dist=Normal(loc=0.01, scale=0.002),
)

# Sensor receives state updates from the world
sensor.connect(
    output_node=world,
    window=1,
    name="world",
    blocking=False,  # Non-blocking connection (i.e. sensor does not wait for world)
    delay_dist=Normal(loc=0.01, scale=0.002),
)
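The `window` semantics can be illustrated with a plain Python deque (an analogy, not rex code): with `window=3`, only the three most recent messages are retained, and index `-1` is the newest.

```python
from collections import deque

# Analogy for an input buffer with window=3: only the three newest messages survive.
buffer = deque(maxlen=3)
for msg in ["m0", "m1", "m2", "m3"]:  # "m0" is evicted when "m3" arrives
    buffer.append(msg)

print(buffer[-1])    # most recent message -> "m3"
print(buffer[-2])    # second most recent  -> "m2"
print(list(buffer))  # ["m1", "m2", "m3"]
```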
# @title Visualizing the System
# @markdown You can visualize the system to understand the structure of your graph.

import matplotlib.pyplot as plt

from rex.utils import plot_system


# Collect node information for visualization
node_infos = {node.name: node.info for node in [sensor, agent, actuator, world]}

# Plot the system
fig, ax = plt.subplots(figsize=(8, 3))
plot_system(node_infos, ax=ax)
ax.legend()
ax.set_title("System Structure")
plt.show()

Graphs in Rex (SIMULATED and WALL_CLOCK runtimes)¤

In Rex, a graph is created by connecting nodes to define the flow of data and execution between them. A graph serves as the backbone for modeling systems that involve multiple interacting components, such as sensors, actuators, and agents.

Key Components of a Graph¤

  • nodes: The nodes that form the building blocks of the graph, each performing specific tasks like sensing, acting, or controlling.
  • supervisor: A designated node that determines the step-by-step progression of the graph (more details in the next section).
  • clock: Determines how time is managed in the graph. Choices include Clock.SIMULATED for virtual simulations and Clock.WALL_CLOCK for real-time applications.
  • real_time_factor: Sets the speed of the simulation. It can simulate as fast as possible (RealTimeFactor.FAST_AS_POSSIBLE), in real-time (RealTimeFactor.REAL_TIME), or at any custom speed relative to real-time.

Real-Time and Simulated Clocks¤

Rex provides flexible control over the simulation's timing through two main clock types and the ability to adjust the real-time factor:

  1. Clock.SIMULATED: The simulation advances based on the specified delays between nodes. This mode is ideal for running simulations in a controlled environment.
  2. Clock.WALL_CLOCK: The graph progresses based on real-world time. This mode is essential for real-time systems and deployments.

Controlling Simulation Speed with real_time_factor¤

The real_time_factor modifies the simulation speed:

  • RealTimeFactor.FAST_AS_POSSIBLE: Simulates as quickly as the system allows, constrained only by computational limits.
  • RealTimeFactor.REAL_TIME: Simulates in real-time, matching the speed of real-world processes. Combine with Clock.WALL_CLOCK for real-time applications.
  • Custom Speed: Any positive float value allows for custom speeds relative to real-time.
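As a back-of-envelope check (assuming the 50 Hz rate and 300-step episode used later in this tutorial), the real-time factor directly scales the wall-clock duration of an episode:

```python
# Wall-clock duration of an episode under different real-time factors.
# Assumes a 50 Hz supervisor rate and a 300-step episode, as in this tutorial.
rate_hz = 50
num_steps = 300
sim_seconds = num_steps / rate_hz  # 6.0 seconds of simulated time

def wall_clock_seconds(real_time_factor: float) -> float:
    """Simulated seconds divided by the factor gives the wall-clock duration."""
    return sim_seconds / real_time_factor

print(wall_clock_seconds(1.0))  # RealTimeFactor.REAL_TIME -> 6.0 s
print(wall_clock_seconds(2.0))  # custom 2x real-time      -> 3.0 s
```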

The Role of the Supervisor Node¤

A critical aspect of graph design in Rex is selecting a supervisor node, which dictates the execution flow. The supervisor node plays a pivotal role in controlling the step-by-step progression of the graph and can alter the perspective from which the system is viewed.

As a mental model, it helps to think of the graph as dividing the nodes into two groups:

  1. Supervisor Node: The designated node that controls the graph's execution flow.
  2. All Other Nodes: These nodes form the environment the supervisor interacts with.

This partitioning of nodes essentially creates an agent-environment interface, where the supervisor node acts as the agent and the remaining nodes represent the environment. The graph provides gym-like .reset and .step methods that mirror reinforcement learning interfaces:

  • .reset: Initializes the system and returns the initial observation as seen by the supervisor node.
  • .step: Advances the simulation by one step (i.e., steps all nodes except the supervisor) and returns the next observation.
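The agent-environment partition can be sketched with a toy stand-in (plain Python, not the rex API): everything except the supervisor is hidden behind gym-like reset/step calls, and the supervisor only sees observations.

```python
# Toy stand-in for the agent-environment interface (not the rex API):
# reset() returns the initial observation; step(action) advances everything
# except the supervisor and returns the next observation.
class ToyGraph:
    def reset(self):
        self.obs = 0.0
        return self.obs

    def step(self, action):
        self.obs += action  # all non-supervisor nodes advance here
        return self.obs

toy_graph = ToyGraph()
obs = toy_graph.reset()
for _ in range(3):
    obs = toy_graph.step(0.5)
print(obs)  # 1.5 after three steps of +0.5
```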

The beauty of this design lies in its flexibility. By selecting different supervisor nodes, you can create learning environments from varying perspectives:

  • Agent as Supervisor: Forms a traditional reinforcement learning environment.
  • Sensor as Supervisor: Creates an interface where the .reset and .step methods return the sensor's inputs, simulating the I/O process from the sensor's viewpoint.

# @title Creating the Graph
# @markdown With the nodes defined and connected, we can create a graph.

from rex.asynchronous import AsyncGraph
from rex.constants import Clock, RealTimeFactor


# Create the graph by specifying the nodes and the supervisor node
graph = AsyncGraph(
    nodes=nodes,
    supervisor=agent,
    # Settings for simulating as fast as possible according to the specified delays
    clock=Clock.SIMULATED,
    real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE,
    # Settings for simulating at real-time speed according to specified delays
    # clock=Clock.SIMULATED, real_time_factor=RealTimeFactor.REAL_TIME,
    # Settings for real-world deployment
    # clock=Clock.WALL_CLOCK, real_time_factor=RealTimeFactor.REAL_TIME,
)

Interacting with the Graph¤

After creating the graph, we can interact with it using the provided APIs to initialize, reset, and step through the graph.

# @title Initializing the Graph
# @markdown Before starting an episode, we initialize the graph's state.
# @markdown This prepares the graph for execution by initializing the parameters, states, and inputs of all nodes.
# @markdown If nodes must be initialized in a specific order, we can specify that order explicitly.
# @markdown We also compile the step functions of all nodes ahead of time to speed up execution, optionally specifying the device for each node.

# Import JAX random number generator
import jax


# Initialize the graph state
rng = jax.random.PRNGKey(0)  # Optional random number generator for reproducibility

# Start initialization with the agent node. This is important as the world's state
# depends on the initial theta and thdot sampled in agent.init_state(...)
initial_graph_state = graph.init(rng=rng, order=["agent"])

# Specify what we want to record (params, state, output) for each node,
graph.set_record_settings(params=True, inputs=False, state=True, output=True)

# Ahead-of-time compilation of all step functions
# Compile the step functions of all nodes to speed up execution.
# Specify the devices for each node, placing them on the CPU or GPU.
from rex.constants import LogLevel
from rex.utils import set_log_level


# Set each node's log level to DEBUG
for n in nodes.values():
    set_log_level(LogLevel.DEBUG, n)

# Place all nodes on the CPU, except the agent, which is placed on the GPU (if available)
devices_step = {k: next(cpus) if k != "agent" or gpu is None else gpu for k in nodes}
graph.warmup(initial_graph_state, devices_step, jit_step=True, profile=True)  # Profile=True for profiling the step function
# @title Graph interaction (Gym-Like API)
# @markdown We use the graph state obtained with .init() and perform step-by-step execution with .reset() and .step().
# @markdown Finally, we get the recorded episode data for analysis.
import jax.numpy as jnp
import tqdm  # Used for progress bars


# Starts the graph with the initial state and returns the supervisor's initial step state.
# If nodes have specified `.startup` methods, they will be called here as well.
graph_state, initial_step_state = graph.reset(initial_graph_state)
step_state = initial_step_state  # The supervisor's step state
for i in tqdm.tqdm(range(300), desc="gather data"):
    # Access the last sensor message of the input buffer
    # -1 is the most recent message, -2 the second most recent, etc. up until the window size
    sensor_msg = step_state.inputs["sensor"][-1].data  # .data grabs the pytree message object
    action = jnp.array([0.5])  # Replace with actual action
    output = step_state.params.to_output(action)  # Convert the action to an output message
    # Step the graph (i.e., executes the next time step by sending the output message to the actuator node)
    graph_state, step_state = graph.step(graph_state, step_state, output)  # Step the graph with the agent's output
graph.stop()  # Stops all nodes that were running asynchronously in the background

# Get the episode data (params, delays, outputs, etc.)
record = graph.get_record()  # Gets the records of all nodes
gather data: 100%|██████████| 300/300 [00:08<00:00, 36.60it/s]

# @title Visualizing the Dataflow of an Episode
# @markdown We can visualize the dataflow of the episode to understand the interactions between nodes.
# @markdown The top plot shows how long each node takes to process data and forward it to the next node.
# @markdown The bottom plot provides a graph representation that will form the basis for the computational graph used for compilation.
# @markdown - Each vertex represents a step call of a node, and each edge represents message transmission between two nodes.
# @markdown - Edges between consecutive steps of the same node represent the transmission of the internal state of the node.
# @markdown - Nodes start processing after an initial phase-shift, which can be controlled in the node definition.

import supergraph  # Used for visualizing the graph

import rex.utils as rutils


# Convert the episode data to a data flow graph
df = record.to_graph()
timing_mode = "arrival"  # "arrival" or "usage"
G = rutils.to_networkx_graph(df, nodes=nodes)
fig, axes = plt.subplots(2, 1, figsize=(12, 6))
rutils.plot_graph(
    G,
    max_x=0.5,
    ax=axes[0],
    message_arrow_timing_mode=timing_mode,
    edge_linewidth=1.4,
    arrowsize=10,
    show_labels=True,
    height=0.6,
    label_loc="center",
)
supergraph.plot_graph(G, max_x=0.5, ax=axes[1])
fig.suptitle("Data flow of one episode")
axes[-1].set_xlabel("Time [s]");
# @title Creating a Compiled Graph
# @markdown Next, we create a compiled graph to speed up execution by pre-compiling the dataflow graph.
# @markdown This approach requires a recording of the dataflow graph during a simulation episode.
# @markdown By simulating an episode according to the exact same dataflow graph, we include the asynchronous effects of delays.

# Initialize a graph that can be compiled and enables parallelized execution
cgraph = rex.graph.Graph(nodes, nodes["agent"], df)
Growing supergraph: 100%|██████████| 301/301 [00:00<00:00, 487.16it/s, 1/1 graphs, 1210/1210 matched (67.00% efficiency, 6 nodes (pre-filtered: 6 nodes))]

# @title Simulating a Compiled Graph
# @markdown A compiled graph has the same API as a regular graph (init, reset, step, run).
# @markdown However, we can also simulate entire rollouts in an optimized manner.
# @markdown Here, we simulate multiple rollouts in parallel to speed up the simulation process.
num_rollouts = 10_000


# Define a function for rolling out the graph that can be compiled and executed in parallel
def rollout_fn(rng):
    # Initialize graph state
    gs = cgraph.init(rng, order=("agent",))
    # Make sure to record the states
    gs = cgraph.init_record(gs, params=False, state=True, output=False)
    # Run the graph for a fixed number of steps
    gs_final = cgraph.rollout(gs)
    # This returns a record that may only be partially filled.
    record = gs_final.aux["record"]
    is_filled = record.nodes["world"].steps.seq >= 0  # Unfilled steps are marked with -1
    return is_filled, record.nodes["world"].steps.state


# Prepare timers
timer_jit = rutils.timer(f"Vectorized evaluation of {num_rollouts} rollouts | compile", log_level=100)
timer_run = rutils.timer(f"Vectorized evaluation of {num_rollouts} rollouts | rollouts", log_level=100)

# Run the rollouts in parallel
rng, rng_rollout = jax.random.split(rng)
rngs_rollout = jax.random.split(rng_rollout, num=num_rollouts)
with timer_jit:
    rollout_fn_jv = jax.jit(jax.vmap(rollout_fn))
    rollout_fn_jv = rollout_fn_jv.lower(rngs_rollout)
    rollout_fn_jv = rollout_fn_jv.compile()
with timer_run:
    is_filled, final_states = rollout_fn_jv(rngs_rollout)
    final_states.th.block_until_ready()

# Only keep the filled rollouts (we did not run the full duration of the computation graph)
final_states = final_states[is_filled]
print(
    f"sim. eval | fps: {(num_rollouts * cgraph.max_steps) / timer_run.duration / 1e6:.0f} Million steps/s | compile: {timer_jit.duration:.2f} s | run: {timer_run.duration:.2f} s"
)
[434  ][MainThread               ][tracer              ][Vectorized evaluation of 10000 rollouts | compile] Elapsed: 4.8439 sec
[434  ][MainThread               ][tracer              ][Vectorized evaluation of 10000 rollouts | rollouts] Elapsed: 0.0998 sec
sim. eval | fps: 30 Million steps/s | compile: 4.84 s | run: 0.10 s
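The lower/compile pattern above separates compilation from execution, which is why the compile time and the per-run time can be reported separately. Below is a minimal, self-contained sketch of the same ahead-of-time pattern using a toy scan-based rollout in place of a rex graph (toy_rollout and its dynamics are illustrative, not part of rex):

```python
import jax
import jax.numpy as jnp


def toy_rollout(rng):
    """Roll out toy placeholder dynamics for a fixed number of steps."""
    x0 = jax.random.uniform(rng, (2,))

    def step(x, _):
        x = 0.99 * x + 0.01  # placeholder dynamics, not a rex node
        return x, x

    x_final, _ = jax.lax.scan(step, x0, None, length=100)
    return x_final


rngs = jax.random.split(jax.random.PRNGKey(0), num=8)
# Same ahead-of-time pattern as above: jit(vmap(...)), then lower + compile
rollout_aot = jax.jit(jax.vmap(toy_rollout)).lower(rngs).compile()
finals = rollout_aot(rngs)  # executes the precompiled function
print(finals.shape)  # (8, 2): one final state per rollout
```

Calling the compiled object skips tracing and compilation entirely, so timing it measures pure execution, exactly as in the `timer_run` block above.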

Defining an Environment¤

To integrate your graph within a reinforcement learning environment or other systems, define an environment class that interacts with the graph. RL algorithms, such as the one defined in rex.ppo, require an environment that implements the following methods and properties:

Implementing the Environment Class¤

  • observation_space: Describes the observation space of the environment.
  • action_space: Describes the action space of the environment.
  • max_steps: The maximum number of steps the environment can run (i.e., the episode length). When using a compiled Graph, this is constrained by the length of the recorded episode.
  • reset: Prepares the environment for a new episode by initializing and resetting the graph.
  • step: Advances the environment by one timestep, applying the provided action and returning the new observation and reward.
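The same contract can be mirrored by any minimal environment. The toy sketch below illustrates the reset/step interface without depending on rex; all names and the dynamics here are illustrative (rex's actual BaseEnv subclass follows in the next cell):

```python
import jax
import jax.numpy as jnp


class ToyEnv:
    """Toy environment mirroring the reset/step contract described above."""

    @property
    def max_steps(self) -> int:
        return 100  # fixed episode length

    def reset(self, rng):
        # Initialize a 2D state and return (state, observation, info)
        state = jax.random.uniform(rng, (2,))
        return state, state, {}

    def step(self, state, action):
        # Apply the action with placeholder dynamics
        next_state = state + 0.1 * action
        reward = -jnp.sum(next_state**2)  # quadratic cost as reward
        terminated, truncated = False, False
        return next_state, next_state, reward, terminated, truncated, {}


env = ToyEnv()
state, obs, info = env.reset(jax.random.PRNGKey(0))
state, obs, reward, term, trunc, info = env.step(state, jnp.array([0.5]))
```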
# @title Example: Pendulum swing-up environment

from typing import Any, Dict, Union

import jax
import jax.numpy as jnp

from rex import base
from rex.examples.pendulum.agent import AgentParams
from rex.graph import Graph
from rex.rl import BaseEnv, Box, ResetReturn, StepReturn


class SwingUpEnv(BaseEnv):
    def __init__(self, graph: Graph):
        super().__init__(graph=graph)
        self._init_params = {}

    @property
    def max_steps(self) -> Union[int, jax.typing.ArrayLike]:
        """Maximum number of steps in an evaluation episode"""
        return int(3.5 * self.graph.nodes["agent"].rate)

    def set_params(self, params: Dict[str, Any]):
        """Pre-set parameters for the environment"""
        self._init_params.update(params)

    def observation_space(self, graph_state: base.GraphState) -> Box:
        cdata = self.get_observation(graph_state)
        low = jnp.full(cdata.shape, -1e6)
        high = jnp.full(cdata.shape, 1e6)
        return Box(low, high, shape=cdata.shape, dtype=cdata.dtype)

    def action_space(self, graph_state: base.GraphState) -> Box:
        params: AgentParams = graph_state.params["agent"]
        high = jnp.array([params.max_torque], dtype=jnp.float32)
        return Box(-high, high, shape=high.shape, dtype=high.dtype)

    def get_observation(self, graph_state: base.GraphState) -> jax.Array:
        # Flatten all inputs and state of the supervisor as the observation
        ss = graph_state.step_state["agent"]
        params: AgentParams = ss.params
        obs = params.get_observation(ss)
        return obs

    def reset(self, rng: jax.Array = None) -> ResetReturn:
        # Initialize the graph state
        init_gs = self.graph.init(rng=rng, params=self._init_params, order=("agent",))
        # Run the graph until the agent node
        gs, _ = self.graph.reset(init_gs)
        # Get observation
        obs = self.get_observation(gs)
        info = {}  # No info to return
        return gs, obs, info

    def step(self, graph_state: base.GraphState, action: jax.Array) -> StepReturn:
        params: AgentParams = graph_state.params["agent"]
        # Update the agent's state (i.e. action and observation history)
        new_agent = params.update_state(graph_state.step_state["agent"], action)
        # The loss_task (i.e. reward) is accumulated in the World node's step function
        # Hence, we read out the loss_task from the world node and set it to 0 before stepping
        # This is to ensure that the loss_task is only counted once
        # Note that this is not obligatory, but it is a good way to keep the reward
        # consistent in the face of simulated asynchronous effects.
        new_world = graph_state.state["world"].replace(loss_task=0.0)
        # Update the states in the graph state
        gs = graph_state.replace(state=graph_state.state.copy({"agent": new_agent, "world": new_world}))
        # Convert action to output (i.e. the one that the Agent node outputs)
        ss = gs.step_state["agent"]
        output = params.to_output(action)
        # Step the graph (i.e. all nodes except the Agent node)
        next_gs, next_ss = self.graph.step(gs, ss, output)
        # Get observation
        obs = self.get_observation(next_gs)
        info = {}
        # Read out the loss_task from the world node's state
        reward = -graph_state.state["world"].loss_task
        # Determine whether the episode is terminated or truncated
        terminated = False  # Infinite-horizon task, so never terminated
        truncated = params.tmax <= next_ss.ts  # Truncate when the time limit is reached
        # Mitigate truncation of infinite horizon tasks by adding a final reward
        # Add the steady-state solution as if the agent had stayed in the same state for the rest of the episode
        gamma = params.gamma
        reward_final = truncated * (1 / (1 - gamma)) * reward  # Assumes that the reward is constant after truncation
        reward = reward + reward_final
        return next_gs, obs, reward, terminated, truncated, info
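The final-reward correction in `step` rests on the geometric series: a constant per-step reward r discounted by gamma sums to r / (1 - gamma). A quick numeric check (the values of gamma and r here are illustrative, not taken from the trained agent):

```python
gamma = 0.99
r = -1.5  # per-step reward, assumed constant after truncation

# Closed form of the discounted return if the agent stayed in this state forever
analytic = r / (1 - gamma)
# Truncated numeric sum of the same series
numeric = sum(gamma**k * r for k in range(10_000))
print(analytic, numeric)  # both close to -150.0
```

Scaling the last reward by 1 / (1 - gamma) thus approximates the discounted return the agent would have collected had the episode continued in steady state, which mitigates the bias introduced by truncating an infinite-horizon task.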
# @title Example: Training a PPO agent
# @markdown We can now train a PPO agent on the defined environment.
# @markdown In fact, we train 5 policies in parallel to make efficient use of the hardware.

import functools

import rex.ppo as ppo


# Create the environment
env = SwingUpEnv(cgraph)

# Configure the PPO agent
config = ppo.Config(
    LR=0.0003261962464827655,
    NUM_ENVS=128,
    NUM_STEPS=32,
    TOTAL_TIMESTEPS=5e6,
    UPDATE_EPOCHS=8,
    NUM_MINIBATCHES=16,
    GAMMA=0.9939508937435216,
    GAE_LAMBDA=0.9712149137900143,
    CLIP_EPS=0.16413213812946092,
    ENT_COEF=0.01,
    VF_COEF=0.8015258840683805,
    MAX_GRAD_NORM=0.9630061315073456,
    NUM_HIDDEN_LAYERS=2,
    NUM_HIDDEN_UNITS=64,
    KERNEL_INIT_TYPE="xavier_uniform",
    HIDDEN_ACTIVATION="tanh",
    STATE_INDEPENDENT_STD=True,
    SQUASH=True,
    ANNEAL_LR=False,
    NORMALIZE_ENV=True,
    FIXED_INIT=True,
    OFFSET_STEP=False,
    NUM_EVAL_ENVS=20,
    EVAL_FREQ=20,
    VERBOSE=True,
    DEBUG=False,
)
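These hyperparameters imply the usual batched-PPO bookkeeping (assuming rex.ppo follows the standard layout): each update collects NUM_ENVS * NUM_STEPS transitions, splits them into NUM_MINIBATCHES minibatches, and iterates over them for UPDATE_EPOCHS epochs:

```python
NUM_ENVS, NUM_STEPS = 128, 32
NUM_MINIBATCHES, UPDATE_EPOCHS = 16, 8
TOTAL_TIMESTEPS = 5e6

batch_size = NUM_ENVS * NUM_STEPS              # transitions collected per update
minibatch_size = batch_size // NUM_MINIBATCHES  # transitions per gradient step
num_updates = int(TOTAL_TIMESTEPS // batch_size)

print(batch_size, minibatch_size, num_updates)  # 4096 256 1220
```

This explains the train_steps spacing in the log below: evaluations occur every ~250k steps, i.e., roughly every TOTAL_TIMESTEPS / EVAL_FREQ transitions.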

# Train 5 policies in parallel
rng, rng_train = jax.random.split(rng)
rngs_train = jax.random.split(rng_train, num=5)  # Train 5 policies in parallel
train = functools.partial(ppo.train, env)
with rutils.timer("ppo | compile"):
    train_v = jax.vmap(train, in_axes=(None, 0))
    train_vjit = jax.jit(train_v)
    train_vjit = train_vjit.lower(config, rngs_train).compile()
with rutils.timer("ppo | train"):
    res = train_vjit(config, rngs_train)
[434  ][MainThread               ][tracer              ][ppo | compile       ] Elapsed: 40.6049 sec
train_steps=249856 | eval_eps=20 | return=-776.5+-0.0 | length=147+-0.0 | approxkl=0.0033
train_steps=249856 | eval_eps=20 | return=-725.1+-0.0 | length=147+-0.0 | approxkl=0.0038
train_steps=249856 | eval_eps=20 | return=-722.7+-0.0 | length=147+-0.0 | approxkl=0.0038
train_steps=249856 | eval_eps=20 | return=-1002.3+-0.0 | length=147+-0.0 | approxkl=0.0039
train_steps=249856 | eval_eps=20 | return=-1490.4+-0.0 | length=147+-0.0 | approxkl=0.0037
train_steps=499712 | eval_eps=20 | return=-696.3+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=499712 | eval_eps=20 | return=-702.2+-0.0 | length=147+-0.0 | approxkl=0.0031
train_steps=499712 | eval_eps=20 | return=-627.2+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=499712 | eval_eps=20 | return=-613.8+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=499712 | eval_eps=20 | return=-686.7+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=749568 | eval_eps=20 | return=-630.4+-0.0 | length=147+-0.0 | approxkl=0.0031
train_steps=749568 | eval_eps=20 | return=-648.7+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=749568 | eval_eps=20 | return=-685.1+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=749568 | eval_eps=20 | return=-619.9+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=749568 | eval_eps=20 | return=-1561.8+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=999424 | eval_eps=20 | return=-761.9+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=999424 | eval_eps=20 | return=-637.3+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=999424 | eval_eps=20 | return=-681.2+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=999424 | eval_eps=20 | return=-674.9+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=999424 | eval_eps=20 | return=-1585.7+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=1249280 | eval_eps=20 | return=-752.2+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=1249280 | eval_eps=20 | return=-699.8+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=1249280 | eval_eps=20 | return=-1091.1+-0.0 | length=147+-0.0 | approxkl=0.0031
train_steps=1249280 | eval_eps=20 | return=-650.8+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=1249280 | eval_eps=20 | return=-1087.9+-0.0 | length=147+-0.0 | approxkl=0.0024
train_steps=1499136 | eval_eps=20 | return=-548.7+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=1499136 | eval_eps=20 | return=-844.7+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=1499136 | eval_eps=20 | return=-657.3+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=1499136 | eval_eps=20 | return=-591.9+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=1499136 | eval_eps=20 | return=-645.2+-0.0 | length=147+-0.0 | approxkl=0.0024
train_steps=1748992 | eval_eps=20 | return=-590.2+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=1748992 | eval_eps=20 | return=-851.7+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=1748992 | eval_eps=20 | return=-639.6+-0.0 | length=147+-0.0 | approxkl=0.0031
train_steps=1748992 | eval_eps=20 | return=-554.2+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=1748992 | eval_eps=20 | return=-664.0+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=1998848 | eval_eps=20 | return=-638.6+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=1998848 | eval_eps=20 | return=-662.4+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=1998848 | eval_eps=20 | return=-690.1+-0.0 | length=147+-0.0 | approxkl=0.0032
train_steps=1998848 | eval_eps=20 | return=-1450.2+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=1998848 | eval_eps=20 | return=-961.3+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=2248704 | eval_eps=20 | return=-687.5+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=2248704 | eval_eps=20 | return=-561.8+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=2248704 | eval_eps=20 | return=-508.9+-0.0 | length=147+-0.0 | approxkl=0.0034
train_steps=2248704 | eval_eps=20 | return=-559.8+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=2248704 | eval_eps=20 | return=-604.6+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=2498560 | eval_eps=20 | return=-1182.2+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=2498560 | eval_eps=20 | return=-630.9+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=2498560 | eval_eps=20 | return=-720.0+-0.0 | length=147+-0.0 | approxkl=0.0031
train_steps=2498560 | eval_eps=20 | return=-569.4+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=2498560 | eval_eps=20 | return=-419.6+-0.0 | length=147+-0.0 | approxkl=0.0034
train_steps=2748416 | eval_eps=20 | return=-567.8+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=2748416 | eval_eps=20 | return=-552.8+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=2748416 | eval_eps=20 | return=-626.8+-0.0 | length=147+-0.0 | approxkl=0.0032
train_steps=2748416 | eval_eps=20 | return=-563.6+-0.0 | length=147+-0.0 | approxkl=0.0026
train_steps=2748416 | eval_eps=20 | return=-385.8+-0.0 | length=147+-0.0 | approxkl=0.0038
train_steps=2998272 | eval_eps=20 | return=-1553.7+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=2998272 | eval_eps=20 | return=-765.8+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=2998272 | eval_eps=20 | return=-623.3+-0.0 | length=147+-0.0 | approxkl=0.0033
train_steps=2998272 | eval_eps=20 | return=-688.6+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=2998272 | eval_eps=20 | return=-387.5+-0.0 | length=147+-0.0 | approxkl=0.0042
train_steps=3248128 | eval_eps=20 | return=-614.8+-0.0 | length=147+-0.0 | approxkl=0.0025
train_steps=3248128 | eval_eps=20 | return=-495.8+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=3248128 | eval_eps=20 | return=-569.3+-0.0 | length=147+-0.0 | approxkl=0.0034
train_steps=3248128 | eval_eps=20 | return=-607.6+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=3248128 | eval_eps=20 | return=-395.1+-0.0 | length=147+-0.0 | approxkl=0.0044
train_steps=3497984 | eval_eps=20 | return=-559.6+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=3497984 | eval_eps=20 | return=-861.6+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=3497984 | eval_eps=20 | return=-578.0+-0.0 | length=147+-0.0 | approxkl=0.0033
train_steps=3497984 | eval_eps=20 | return=-613.4+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=3497984 | eval_eps=20 | return=-382.8+-0.0 | length=147+-0.0 | approxkl=0.0044
train_steps=3747840 | eval_eps=20 | return=-575.2+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=3747840 | eval_eps=20 | return=-518.3+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=3747840 | eval_eps=20 | return=-565.4+-0.0 | length=147+-0.0 | approxkl=0.0032
train_steps=3747840 | eval_eps=20 | return=-679.5+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=3747840 | eval_eps=20 | return=-379.2+-0.0 | length=147+-0.0 | approxkl=0.0047
train_steps=3997696 | eval_eps=20 | return=-544.8+-0.0 | length=147+-0.0 | approxkl=0.0027
train_steps=3997696 | eval_eps=20 | return=-556.3+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=3997696 | eval_eps=20 | return=-547.6+-0.0 | length=147+-0.0 | approxkl=0.0034
train_steps=3997696 | eval_eps=20 | return=-955.2+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=3997696 | eval_eps=20 | return=-376.4+-0.0 | length=147+-0.0 | approxkl=0.0050
train_steps=4247552 | eval_eps=20 | return=-595.6+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=4247552 | eval_eps=20 | return=-552.1+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=4247552 | eval_eps=20 | return=-1352.2+-0.0 | length=147+-0.0 | approxkl=0.0034
train_steps=4247552 | eval_eps=20 | return=-552.7+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=4247552 | eval_eps=20 | return=-376.0+-0.0 | length=147+-0.0 | approxkl=0.0052
train_steps=4497408 | eval_eps=20 | return=-549.1+-0.0 | length=147+-0.0 | approxkl=0.0029
train_steps=4497408 | eval_eps=20 | return=-568.6+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=4497408 | eval_eps=20 | return=-453.7+-0.0 | length=147+-0.0 | approxkl=0.0035
train_steps=4497408 | eval_eps=20 | return=-611.0+-0.0 | length=147+-0.0 | approxkl=0.0028
train_steps=4497408 | eval_eps=20 | return=-374.9+-0.0 | length=147+-0.0 | approxkl=0.0057
train_steps=4747264 | eval_eps=20 | return=-561.1+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=4747264 | eval_eps=20 | return=-480.1+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=4747264 | eval_eps=20 | return=-1391.0+-0.0 | length=147+-0.0 | approxkl=0.0036
train_steps=4747264 | eval_eps=20 | return=-645.8+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=4747264 | eval_eps=20 | return=-372.7+-0.0 | length=147+-0.0 | approxkl=0.0068
train_steps=4997120 | eval_eps=20 | return=-538.1+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=4997120 | eval_eps=20 | return=-650.3+-0.0 | length=147+-0.0 | approxkl=0.0030
train_steps=4997120 | eval_eps=20 | return=-775.1+-0.0 | length=147+-0.0 | approxkl=0.0035
train_steps=4997120 | eval_eps=20 | return=-564.2+-0.0 | length=147+-0.0 | approxkl=0.0033
train_steps=4997120 | eval_eps=20 | return=-370.8+-0.0 | length=147+-0.0 | approxkl=0.0083
[434  ][MainThread               ][tracer              ][ppo | train         ] Elapsed: 40.3077 sec

# @title Visualize PPO Training Progress
# @markdown The plots below show the training progress of the PPO algorithm in terms of returns and policy KL divergence.
# @markdown Note that we are not solving the swing-up task here, but rather demonstrating the training process.
# @markdown See `sim2real.ipynb` for a complete example of solving the swing-up task using PPO.

fig_ppo, axes_ppo = plt.subplots(1, 2, figsize=(8, 3))
total_steps = res.metrics["train/total_steps"].transpose()
mean, std = res.metrics["eval/mean_returns"].transpose(), res.metrics["eval/std_returns"].transpose()
axes_ppo[0].plot(total_steps, mean, label="mean")
axes_ppo[0].set_title("Returns")
axes_ppo[0].set_xlabel("Total steps")
axes_ppo[0].set_ylabel("Cum. return")
mean, std = res.metrics["train/mean_approxkl"].transpose(), res.metrics["train/std_approxkl"].transpose()
axes_ppo[1].plot(total_steps, mean, label="mean")
axes_ppo[1].set_title("Policy KL")
axes_ppo[1].set_xlabel("Total steps")
axes_ppo[1].set_ylabel("Approx. kl");