Getting Started
Rex is a JAX-powered framework for sim-to-real robotics.
Key features:
- Graph-based design: Model asynchronous systems with nodes for sensing, actuation, and computation.
- Latency-aware modeling: Simulate delay effects for hardware, computation, and communication channels.
- Real-time and parallelized runtimes: Run real-world experiments or accelerated parallelized simulations.
- Seamless integration with JAX: Utilize JAX's autodiff, JIT compilation, and GPU/TPU acceleration.
- System identification tools: Estimate dynamics and delays directly from real-world data.
- Modular and extensible: Compatible with various simulation engines (e.g., Brax, MuJoCo).
- Unified sim2real pipeline: Train delay-aware policies in simulation and deploy them on real-world systems.
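Latency-aware modeling comes down to attaching a delay distribution to nodes and connections. Below is a minimal sketch in plain Python (not Rex's API) of how a message's receive time could follow from a sender-side process delay plus a communication delay. It uses Normal distributions with the same parameters as the simulated quick example below. The helper names `sample_delay` and `receive_times`, and the rule of clipping negative samples to zero, are assumptions of this sketch.

```python
import random

# Hypothetical helpers, not part of Rex.
def sample_delay(rng, mean, std):
    """Draw one delay in seconds from Normal(mean, std), clipped at zero.

    Clipping is an assumption of this sketch: a Normal whose std is comparable
    to its mean (e.g. Normal(0.001, 0.001)) can produce negative samples.
    """
    return max(0.0, rng.gauss(mean, std))

def receive_times(send_times, process=(0.01, 0.001), comm=(0.001, 0.001), seed=0):
    """Timestamp (s) at which each message becomes available to the receiver.

    Total latency = sender's process delay + communication delay of the link.
    """
    rng = random.Random(seed)  # seeded so the simulation is reproducible
    return [t + sample_delay(rng, *process) + sample_delay(rng, *comm) for t in send_times]

send = [k / 50 for k in range(1000)]  # a 50 Hz sender, 20 s of messages
recv = receive_times(send)
delays = [r - s for r, s in zip(recv, send)]
print(f"mean latency: {sum(delays) / len(delays) * 1000:.2f} ms")  # roughly 11 ms
```

Clipping matters here: Normal(0.001, 0.001) puts about 16% of its mass below zero, so after clipping the mean communication delay ends up slightly above 1 ms.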
Sim-to-Real Workflow
- Interface Real Systems: Define nodes for sensors, actuators, and computation to represent real-world systems.
- Build Simulation: Swap real-world nodes with simulated ones (e.g., physics engines, motor dynamics).
- System Identification: Estimate system dynamics and delays from real-world data.
- Policy Training: Train delay-aware policies in simulation, accounting for realistic dynamics and delays.
- Evaluation: Evaluate trained policies on the real-world system, and iterate on the design.
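The simplest possible delay estimator for the system-identification step fits a Normal distribution to delays observed in recorded data. The send and receive timestamps (`ts_send`, `ts_recv`) carried by each received message are one such source. The sketch below uses moment matching. It assumes synchronized clocks and directly observable delays, and the function name and timestamps are hypothetical. Rex's own identification tools are more general than this.

```python
import statistics

# Hypothetical estimator, not part of Rex.
def fit_normal_delay(ts_send, ts_recv):
    """Moment-matching fit of Normal(mean, std) to observed message delays.

    Assumes sender and receiver timestamps come from the same (synchronized)
    clock, so ts_recv - ts_send is the delay of each message.
    """
    delays = [r - s for s, r in zip(ts_send, ts_recv)]
    return statistics.mean(delays), statistics.stdev(delays)

# Hypothetical recorded timestamps (seconds) for five messages on one connection
ts_send = [0.00, 0.02, 0.04, 0.06, 0.08]
ts_recv = [0.010, 0.032, 0.051, 0.069, 0.093]
mu, sigma = fit_normal_delay(ts_send, ts_recv)
print(f"delay ~ Normal({mu:.4f}, {sigma:.4f})")  # delay ~ Normal(0.0110, 0.0016)
```

The fitted (mean, std) pair has the same form as the Normal(...) distributions passed as delay_dist in the simulated quick example.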
Installation
```bash
pip install rex-lib
```
Requires Python 3.9+ and JAX 0.4.30+.
Quick example
Here's a simple example of a pendulum system. The real-world system is defined with nodes that interface with the hardware for sensing and actuation, and an agent node that computes the control commands:
```python
from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import Actuator, Agent, Sensor

sensor = Sensor(rate=50)  # 50 Hz sampling rate
agent = Agent(rate=30)  # 30 Hz policy execution rate
actuator = Actuator(rate=50)  # 50 Hz control rate
nodes = dict(sensor=sensor, agent=agent, actuator=actuator)

agent.connect(sensor)  # Agent receives sensor data
actuator.connect(agent)  # Actuator receives agent commands
graph = AsyncGraph(nodes, agent)  # Graph for real-world execution

graph_state = graph.init()  # Initial states of all nodes
graph.warmup(graph_state)  # Jit-compiles the graph (only once)
for _ in range(100):  # Run the graph for 100 steps
    graph_state = graph.run(graph_state)  # Run for one step
graph.stop()  # Stop asynchronous nodes
data = graph.get_record()  # Get recorded data from the graph
```
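The nodes run at different rates (50 Hz sensor, 30 Hz agent), so the number of new sensor messages per agent step is not constant. The sketch below counts them under two simplifying assumptions: zero delay, and both nodes starting at t = 0. It is a back-of-the-envelope model with hypothetical helpers (`tick_times`, `new_messages_per_tick`), not Rex's scheduler.

```python
from fractions import Fraction

# Hypothetical helpers, not part of Rex.
def tick_times(rate_hz, n_ticks):
    """Exact tick timestamps (s) of a node running at rate_hz, starting at t=0."""
    return [Fraction(k, rate_hz) for k in range(n_ticks)]

def new_messages_per_tick(sender_ts, receiver_ts):
    """For each receiver tick, count sender messages sent since the previous tick.

    Assumes zero delay: a message sent at time s is available at any tick t >= s.
    """
    counts, prev = [], None
    for t in receiver_ts:
        counts.append(sum(1 for s in sender_ts if (prev is None or s > prev) and s <= t))
        prev = t
    return counts

sensor_ts = tick_times(50, 50)  # one second of 50 Hz sensor messages
agent_ts = tick_times(30, 30)   # one second of 30 Hz agent steps
counts = new_messages_per_tick(sensor_ts, agent_ts)
print(counts[:7])  # [1, 1, 2, 2, 1, 2, 2]
```

Fractions keep the timestamps exact, so a sensor message and an agent tick that coincide (both at t = 0.1 s, for instance) are not split apart by rounding error. Every three agent steps see five sensor messages in total, matching the 5:3 rate ratio.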
To simulate the same system, the hardware nodes are swapped for simulated ones and a physics node (world) is added. Delay distributions on the nodes and connections let the graph simulate process, computational, and communication delays:
```python
from distrax import Normal
from rex.constants import Clock, RealTimeFactor
from rex.asynchronous import AsyncGraph
from rex.examples.pendulum import SimActuator, Agent, SimSensor, BraxWorld

sensor = SimSensor(rate=50, delay_dist=Normal(0.01, 0.001))  # Process delay
agent = Agent(rate=30, delay_dist=Normal(0.02, 0.005))  # Computational delay
actuator = SimActuator(rate=50, delay_dist=Normal(0.01, 0.001))  # Process delay
world = BraxWorld(rate=100)  # 100 Hz physics simulation
nodes = dict(sensor=sensor, agent=agent, actuator=actuator, world=world)

sensor.connect(world, delay_dist=Normal(0.001, 0.001))  # Sensor delay
agent.connect(sensor, delay_dist=Normal(0.001, 0.001))  # Communication delay
actuator.connect(agent, delay_dist=Normal(0.001, 0.001))  # Communication delay
world.connect(actuator, delay_dist=Normal(0.001, 0.001),  # Actuator delay
              skip=True)  # Breaks algebraic loop in the graph
graph = AsyncGraph(nodes, agent,
                   clock=Clock.SIMULATED,  # Simulates based on delay_dist
                   real_time_factor=RealTimeFactor.FAST_AS_POSSIBLE)

graph_state = graph.init()  # Initial states of all nodes
graph.warmup(graph_state)  # Jit-compiles the graph
for _ in range(100):  # Run the graph for 100 steps
    graph_state = graph.run(graph_state)  # Run for one step
graph.stop()  # Stop asynchronous nodes
data = graph.get_record()  # Get recorded data from the graph
```
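The skip=True on the actuator-to-world connection matters because the four simulated nodes form a cycle (world, sensor, agent, actuator, back to world). One illustrative way to see why is to treat each connection as a dependency and try to order the node updates. This sketch uses Python's standard `graphlib` and is a conceptual model, not Rex's internal scheduler. The `connections` list and the `update_order` helper are hypothetical.

```python
from graphlib import CycleError, TopologicalSorter

# (receiver, sender, skip) for each connection in the simulated pendulum graph
connections = [
    ("sensor", "world", False),
    ("agent", "sensor", False),
    ("actuator", "agent", False),
    ("world", "actuator", True),  # skip=True: conceptually, use the previous message instead of waiting
]

def update_order(connections, honor_skip=True):
    """Order node updates so that every sender runs before its receivers.

    Skipped connections are not treated as dependencies when honor_skip is True.
    """
    deps = {}
    for receiver, sender, skip in connections:
        deps.setdefault(receiver, set())
        deps.setdefault(sender, set())
        if not (skip and honor_skip):
            deps[receiver].add(sender)
    return list(TopologicalSorter(deps).static_order())

print(update_order(connections))  # ['world', 'sensor', 'agent', 'actuator']

try:
    update_order(connections, honor_skip=False)
except CycleError:
    print("Without skip=True the dependencies form a cycle (algebraic loop)")
```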
Each node is defined by subclassing BaseNode. It initializes its parameters, state, and output, and implements a step method that is called at the node's rate:
```python
from rex.node import BaseNode

# SomePyTree is a placeholder for any JAX pytree (your own params/state/output types)
class Agent(BaseNode):
    def init_params(self, rng=None, graph_state=None):
        return SomePyTree(a=..., b=...)

    def init_state(self, rng=None, graph_state=None):
        return SomePyTree(x1=..., x2=...)

    def init_output(self, rng=None, graph_state=None):
        return SomePyTree(y1=..., y2=...)

    # Jit-compiled via graph.warmup for faster execution
    def step(self, step_state):  # Called at the node's rate
        ss = step_state  # Shorten name
        # Read params and current state
        params, state = ss.params, ss.state
        # Current episode, sequence number, timestamp
        eps, seq, ts = ss.eps, ss.seq, ss.ts
        # Grab the received messages and their I/O timestamps
        cam = ss.inputs["sensor"]
        data, ts_send, ts_recv = cam.data, cam.ts_send, cam.ts_recv
        ...  # Some computation for new_state, output
        new_state = SomePyTree(x1=..., x2=...)
        output = SomePyTree(y1=..., y2=...)
        # Update step_state for the next step call
        new_ss = ss.replace(state=new_state)
        return new_ss, output  # Sends output
```
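The step method returns a new step state instead of mutating the old one. JAX's `jit` traces pure functions, so everything a step depends on comes in as an argument and everything it changes goes out as a return value. The sketch below shows the same pattern in plain Python, with frozen dataclasses standing in for Rex's pytrees and `dataclasses.replace` playing a role analogous to `ss.replace(...)`. All class names, fields, and the proportional-control law are hypothetical.

```python
from dataclasses import dataclass, replace

# Hypothetical stand-ins for Rex's step-state pytrees (illustration only)
@dataclass(frozen=True)
class Msg:
    data: float     # e.g. a measured pendulum angle (rad)
    ts_send: float  # when the sender published the message (s)
    ts_recv: float  # when the receiver got it (s)

@dataclass(frozen=True)
class AgentState:
    num_steps: int  # how many times step() has run

@dataclass(frozen=True)
class StepState:
    params: float       # hypothetical proportional gain
    state: AgentState
    ts: float
    inputs: dict        # input name -> latest Msg

def step(ss):
    """Pure step: read everything from ss, return a new ss plus the output."""
    msg = ss.inputs["sensor"]
    action = -ss.params * msg.data  # toy proportional control law
    new_state = AgentState(num_steps=ss.state.num_steps + 1)
    return replace(ss, state=new_state), action  # ss itself is left untouched

ss0 = StepState(params=2.0, state=AgentState(num_steps=0), ts=0.0,
                inputs={"sensor": Msg(data=0.5, ts_send=0.0, ts_recv=0.011)})
ss1, action = step(ss0)
print(action, ss1.state.num_steps, ss0.state.num_steps)  # -1.0 1 0
```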
Next steps
If this quick start got you interested, have a look at the sim2real.ipynb notebook for an example of a sim-to-real workflow using Rex.
Citation
If you find this library useful in academic work, please cite this paper:
```bibtex
@article{heijden2024rex,
  title={{REX: GPU-Accelerated Sim2Real Framework with Delay and Dynamics Estimation}},
  author={van der Heijden, Bas and Kober, Jens and Babuska, Robert and Ferranti, Laura},
  journal={Transactions on Machine Learning Research (TMLR)},
  year={2025}
}
```
(Also consider starring the project on GitHub.)