An MJX-style JAX FFI for CPU-based MuJoCo
The feature, motivation and pitch
I am training reinforcement learning agents in JAX and currently use MJX as the simulator. The MJX interface is really nice to work with. It lets me reason about a single environment at a time using familiar MuJoCo syntax, while jax.jit and jax.vmap take care of compilation and vectorization. The result is environments that are pleasant to author and, in most cases, performant.
However, while MJX performs well for simpler environments, its speed scales poorly to more complex scenes with many objects and contacts. This limitation shows up not only in performance (e.g., poor collision scaling) but also in usability. JAX/MJX compilation is often slow, which makes debugging, visualization, and testing quite tedious compared to iterating with CPU-based MuJoCo.
I think a promising solution could be a clean, vectorized interface for CPU MuJoCo that mirrors the MJX API. This could be achieved by using the JAX FFI to efficiently manage and step a batch of mujoco.MjData instances. The ideal interface would expose functions like mjx.step and mjx.forward and data structures like mjx.Data and mjx.Model, but with CPU-based MuJoCo as the simulation backend.
I think this feature could provide the best of both MJX and MuJoCo worlds:
- The clean single-program, multiple-data (SPMD) experience of authoring environments that MJX enables via jax.{jit,vmap}.
- The feature-completeness, scalability, and user-friendliness of CPU-based MuJoCo.
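To illustrate the SPMD pattern described above, here is a minimal sketch. It assumes a recent JAX that provides jax.custom_batching.custom_vmap and jax.pure_callback. The environment-facing step is written for a single environment. Under jax.vmap, the custom rule hands the whole batch to the host in one callback, which is roughly where an FFI call into CPU MuJoCo would go. The function host_step_batch is a hypothetical stand-in (a toy point-mass integrator), not MuJoCo.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

DT = 0.01  # hypothetical timestep for the toy "simulator"


def host_step_batch(qpos, qvel, ctrl):
    # Hypothetical stand-in for stepping a batch of CPU MjData instances.
    # A real backend would run mujoco on each row; here: semi-implicit Euler.
    qpos, qvel, ctrl = (np.asarray(a, dtype=np.float32) for a in (qpos, qvel, ctrl))
    qvel_next = qvel + DT * ctrl
    qpos_next = qpos + DT * qvel_next
    return qpos_next, qvel_next


def _call_host(qpos, qvel, ctrl):
    # Declare output shapes/dtypes so JAX can trace through the host call.
    out_spec = (
        jax.ShapeDtypeStruct(qpos.shape, jnp.float32),
        jax.ShapeDtypeStruct(qvel.shape, jnp.float32),
    )
    return jax.pure_callback(host_step_batch, out_spec, qpos, qvel, ctrl)


@custom_vmap
def step(qpos, qvel, ctrl):
    # Unbatched path: code is written against a single environment.
    return _call_host(qpos, qvel, ctrl)


@step.def_vmap
def _step_vmap(axis_size, in_batched, qpos, qvel, ctrl):
    # Batched path: broadcast unbatched arguments, then make ONE host call
    # for the whole batch instead of one call per environment.
    args = [
        a if batched else jnp.broadcast_to(a, (axis_size, *a.shape))
        for a, batched in zip((qpos, qvel, ctrl), in_batched)
    ]
    return _call_host(*args), (True, True)


# Usage: 8 environments, with one control vector shared across the batch.
qpos0 = jnp.zeros((8, 3))
qvel0 = jnp.ones((8, 3))
ctrl0 = jnp.full((3,), 2.0)
qpos1, qvel1 = jax.jit(jax.vmap(step, in_axes=(0, 0, None)))(qpos0, qvel0, ctrl0)
single_qpos, single_qvel = step(jnp.zeros(3), jnp.ones(3), jnp.full((3,), 2.0))
```

The same step function works unbatched and under jit(vmap(...)). Only the batching rule knows that the host receives the entire batch at once.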
Alternatives
EnvPool
EnvPool is quite similar in spirit and achieves impressive throughput on CPU. It provides a JAX FFI interface to the environment, but it requires the environment logic itself to be written in C++. What I'm requesting here is different: I want to define the environment logic in Python/JAX, leveraging the existing MJX/JAX patterns for clean vectorization via jax.{vmap,jit}. EnvPool also appears to be largely unmaintained.
mujoco.rollout
One could, in principle, combine mjx.{Model,Data} with mujoco.rollout to achieve a similar outcome. However, my attempts so far have been a bit clumsy. They incur unnecessary overhead from Python loops and from marshalling data between JAX and NumPy. An optimized FFI-based implementation could probably eliminate these bottlenecks and make the user experience cleaner.
For example, I’ve experimented with something like this:
```python
import mujoco
from mujoco import mjx, rollout

model: mujoco.MjModel = ...
datas = [mujoco.MjData(model) for _ in range(batch_size)]
modelx = mjx.put_model(model)
datax: mjx.Data = ...  # batched data

# This rollout happens on the CPU, outside of JAX.
# state: (batch_size, nstate) initial states
# control: (batch_size, num_substeps, nu) controls
states, _ = rollout.rollout(
    model,
    datas,
    state,
    control,
    nstep=num_substeps,
)

# Data must be manually transferred back to the JAX side
state = states[:, -1, :]
datax = mjx_set_state(modelx, datax, state)  # hypothetical function
datax = mjx.forward(modelx, datax)
...
```
Technically, this lets me write the environments in JAX, but it's still a bit clumsy. A native FFI interface would make this setup a whole lot cleaner.
MjWarp
MjWarp seems promising for improving both collision scaling and compilation/runtime speed. However, it is still in beta, its final performance and scalability are unclear, and it does not yet have full feature parity with vanilla MuJoCo. I think a vectorized CPU backend would be a valuable, complementary tool, offering the stability, feature set, and usability of MuJoCo. Ideally, one could author environments in JAX and then swap the backend between MjWarp and CPU-based MuJoCo with barely any code changes.
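Here is a sketch of the backend swap described above. It assumes every backend can be expressed as a function from (state, control) to the next state. The environment logic is written once and receives the physics step as an argument. Here, euler_backend and damped_backend are hypothetical toy integrators standing in for, say, MJX and a CPU-MuJoCo callback. They are not real MuJoCo or MjWarp APIs.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    qpos: jax.Array
    qvel: jax.Array


# Hypothetical backend interface: (state, control) -> next state.
StepFn = Callable[[State, jax.Array], State]


def make_env_step(physics_step: StepFn, n_substeps: int):
    """Environment logic written once, independent of the physics backend."""

    def env_step(state: State, action: jax.Array):
        def body(s, _):
            return physics_step(s, action), None

        # Hold the action fixed for n_substeps physics steps.
        state, _ = jax.lax.scan(body, state, None, length=n_substeps)
        reward = -jnp.sum(state.qpos ** 2)
        return state, reward

    return env_step


def euler_backend(state: State, ctrl: jax.Array, dt: float = 0.01) -> State:
    # Toy stand-in backend #1: undamped point mass.
    qvel = state.qvel + dt * ctrl
    return State(state.qpos + dt * qvel, qvel)


def damped_backend(state: State, ctrl: jax.Array, dt: float = 0.01, damping: float = 0.1) -> State:
    # Toy stand-in backend #2: same interface, different dynamics.
    qvel = (1.0 - damping) * state.qvel + dt * ctrl
    return State(state.qpos + dt * qvel, qvel)


# Swapping backends changes one argument; the environment code is untouched.
batch_state = State(jnp.zeros((4, 2)), jnp.zeros((4, 2)))
actions = jnp.ones((4, 2))
step_a = jax.jit(jax.vmap(make_env_step(euler_backend, n_substeps=5)))
step_b = jax.jit(jax.vmap(make_env_step(damped_backend, n_substeps=5)))
state_a, reward_a = step_a(batch_state, actions)
state_b, reward_b = step_b(batch_state, actions)
```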
Additional context
I hope this feature request is not completely out of left field! I think it would be a valuable tool for the community and I'd love to hear what you all think. I’d also be interested in contributing to the implementation if others think it would be a good fit for MuJoCo.
Hi @hartikainen,
I actually wrote something like this a few days ago, and so far it's working quite well (but I'm still testing it, so no guarantees :-)). The trick to making it work well is wrapping rollout in JAX's pure_callback and adding a custom vmap rule to stack everything. On the JAX side I rely on mjx.Data, to which the rollout results are applied. rollout only returns the full physics state and sensordata, but I need access to other properties such as xpos, so I copy them from the underlying MjData instances. The code also splits the environments into chunks of nthread and scans over the chunks, which prevents oversubscribing threads.
Here's the code:
```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
import mujoco
import mujoco.mjx as mjx
import numpy as np
from flax.struct import PyTreeNode, field
from jax import Array
from mujoco.rollout import Rollout


def get_full_physics_state(d: mjx.Data) -> Array:
    # Concatenate time, qpos, qvel and act into one state vector for rollout.
    # NOTE: Plugin state doesn't exist in mjx.Data, so models with plugins are unsupported.
    return jnp.concat(
        [
            jnp.atleast_1d(d.time),
            jnp.atleast_1d(d.qpos),
            jnp.atleast_1d(d.qvel),
            jnp.atleast_1d(d.act),
        ],
        axis=-1,
    )


class Updates(NamedTuple):
    fullphysicsstate: Array | np.ndarray
    sensordata: Array | np.ndarray
    xpos: Array | np.ndarray
    xquat: Array | np.ndarray
    xmat: Array | np.ndarray


def apply_updates(d: mjx.Data, updates: Updates) -> mjx.Data:
    # Split the full physics state back into time, qpos, qvel and act.
    sizes = [1, d.qpos.shape[-1], d.qvel.shape[-1]]
    time, qpos, qvel, act = jnp.split(
        updates.fullphysicsstate, np.cumsum(sizes), axis=-1
    )
    # MjData stores xmat as (nbody, 9); mjx.Data expects (nbody, 3, 3).
    xmat = updates.xmat.reshape(*updates.xmat.shape[:-1], 3, 3)
    return d.replace(
        xpos=updates.xpos,
        xquat=updates.xquat,
        xmat=xmat,
        sensordata=updates.sensordata,
        time=time[..., 0],  # drop the length-1 split axis so time keeps its shape
        qpos=qpos,
        qvel=qvel,
        act=act,
    )


def mujoco_rollout_forward(
    threadpool: Rollout,
    m: mujoco.MjModel,
    d: list[mujoco.MjData],
    initial_state: Array,
    control: Array,
) -> Updates:
    nbatch = initial_state.shape[0]

    def _rollout(initial_state, control):
        # Host side: step every environment in the batch on the thread pool.
        final_state, sensordata = threadpool.rollout(
            m, d, np.asarray(initial_state), np.asarray(control)
        )
        # rollout only returns states and sensordata, so copy kinematics out of
        # the MjData instances. Caveat: which MjData holds which environment
        # depends on how rollout schedules work across threads; verify this.
        xpos = np.empty((nbatch, m.nbody, 3), dtype=np.float32)
        xmat = np.empty((nbatch, m.nbody, 9), dtype=np.float32)
        xquat = np.empty((nbatch, m.nbody, 4), dtype=np.float32)
        for idx, d_i in enumerate(d[:nbatch]):
            xpos[idx] = d_i.xpos
            xmat[idx] = d_i.xmat
            xquat[idx] = d_i.xquat
        # Keep the last step only, and return NumPy (not jnp) arrays from the callback.
        return Updates(
            final_state[:, -1].astype(np.float32),
            sensordata[:, -1].astype(np.float32),
            xpos,
            xquat,
            xmat,
        )

    # Shapes must match what _rollout returns for a batch of size nbatch.
    return_spec = Updates(
        jax.ShapeDtypeStruct(initial_state.shape, jnp.float32),
        jax.ShapeDtypeStruct((nbatch, m.nsensordata), jnp.float32),
        jax.ShapeDtypeStruct((nbatch, m.nbody, 3), jnp.float32),
        jax.ShapeDtypeStruct((nbatch, m.nbody, 4), jnp.float32),
        jax.ShapeDtypeStruct((nbatch, m.nbody, 9), jnp.float32),
    )
    return jax.pure_callback(_rollout, return_spec, initial_state, control)


def mujoco_rollout(threadpool: Rollout, spec: mujoco.MjSpec):
    # One compiled model and one MjData per worker thread, shared by all calls.
    m = spec.compile()
    d = [mujoco.MjData(m) for _ in range(threadpool.nthread)]

    @jax.custom_batching.custom_vmap
    def rollout(initial_state: Array, control: Array) -> Updates:
        # Unbatched call: run a batch of one, then strip the batch axis everywhere.
        updates = mujoco_rollout_forward(
            threadpool, m, d, initial_state[None, :], control[None, :]
        )
        return jax.tree.map(lambda x: x[0], updates)

    @rollout.def_vmap
    def rollout_vmap(axis_size, in_batched, initial_state, control):
        # Assumes both arguments are batched along axis 0.
        if axis_size % threadpool.nthread:
            raise ValueError("The number of envs must be a multiple of nthread.")

        def single_batch_rollout(_, args):
            updates = mujoco_rollout_forward(threadpool, m, d, *args)
            return None, updates

        # Split the batch into chunks of nthread envs and scan over the chunks,
        # so each host call uses exactly one MjData per thread.
        n_batches = axis_size // threadpool.nthread
        initial_state = initial_state.reshape(
            n_batches, threadpool.nthread, *initial_state.shape[1:]
        )
        control = control.reshape(n_batches, threadpool.nthread, *control.shape[1:])
        _, updates = jax.lax.scan(
            single_batch_rollout, None, (initial_state, control), length=n_batches
        )
        # Merge the (n_batches, nthread) axes back into one batch axis.
        updates = jax.tree.map(lambda x: x.reshape(-1, *x.shape[2:]), updates)
        return updates, Updates(True, True, True, True, True)

    return rollout


class MuJoCoRollOut(PyTreeNode):
    # Static (non-pytree) field holding the custom_vmap-wrapped rollout function.
    _rollout: Callable[[Array, Array], Updates] = field(pytree_node=False)

    @classmethod
    def init(cls, spec: mujoco.MjSpec, n_threads: int):
        return cls(mujoco_rollout(Rollout(nthread=n_threads), spec))

    def __call__(self, d: mjx.Data, action: Array, n_steps: int) -> mjx.Data:
        # Hold the action fixed for n_steps and write the results back into d.
        initial_state = get_full_physics_state(d)
        control = jnp.repeat(action[None, :], n_steps, axis=0)
        updates = self._rollout(initial_state, control)
        return apply_updates(d, updates)
```
Example of how to use it:
```python
n_threads = 64
n_envs = 4096  # must be a multiple of n_threads
spec = walking_fruitfly_v1()  # my own MjSpec builder, not shown here
m = spec.compile()
roll = MuJoCoRollOut.init(spec, n_threads)
d = jax.vmap(lambda: mjx.make_data(mjx.put_model(m)), axis_size=n_envs)()
forward_fn = jax.jit(jax.vmap(lambda d, a: roll(d, a, n_steps=10)))

# Make a random action (one control vector of size m.nu per env) and run.
action = jnp.asarray(np.random.normal(loc=0.0, scale=0.2, size=(n_envs, m.nu)))
d = forward_fn(d, action)
```
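The reshape in rollout_vmap only works when the number of environments is a multiple of the thread count. One possible workaround, which is not part of the code above, is to pad the batch up to the next multiple, scan over nthread-sized chunks, and then drop the padded rows. In this sketch, scan_in_thread_chunks is a hypothetical helper, and fn is a stand-in for one mujoco_rollout_forward call on a chunk.

```python
import jax
import jax.numpy as jnp


def scan_in_thread_chunks(fn, x, nthread):
    """Apply fn to x in chunks of nthread rows, one chunk per scan step.

    Pads the batch to a multiple of nthread by repeating the last row,
    then drops the padded rows, so len(x) need not divide evenly.
    """
    n = x.shape[0]
    n_chunks = -(-n // nthread)  # ceiling division
    pad = n_chunks * nthread - n
    x_padded = jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)
    chunks = x_padded.reshape(n_chunks, nthread, *x.shape[1:])

    def body(carry, chunk):
        # In the real setup this would be one host call per chunk.
        return carry, fn(chunk)

    _, out = jax.lax.scan(body, None, chunks)
    out = out.reshape(n_chunks * nthread, *out.shape[2:])
    return out[:n]


# 10 "environments" with nthread=4: padded to 12, three scan steps, trimmed to 10.
x = jnp.arange(10.0).reshape(10, 1)
y = scan_in_thread_chunks(lambda chunk: 2 * chunk, x, nthread=4)
y_jit = jax.jit(lambda v: scan_in_thread_chunks(lambda c: 2 * c, v, nthread=4))(x)
```

The cost is at most nthread - 1 wasted environments per call. This is usually small compared with the full batch.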
A timing script shows some overhead, of course, but it's not too bad. On my M3 Mac I measure about 83 ms for 16384 environments in parallel with this code, versus 52 ms with the raw mujoco rollout. However, the raw rollout only returns the full physics state and doesn't do the stacking, and copying the extra fields is likely the bulk of the overhead. For only 512 environments it's about 3 ms vs 2 ms. So if you don't need access to things like xpos and xmat, it would likely be even quicker! I guess something written with JAX's FFI would be faster still, but this seems to do pretty well. Depending on what you need, you should be able to drop it into your code more or less as-is.
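The timing script itself isn't shown, so here is a hypothetical harness for reproducing this kind of comparison. It is not the script used for the numbers above. It assumes the function is already wrapped in jax.jit. Warm-up calls absorb compilation, and jax.block_until_ready waits for JAX's asynchronous dispatch (including any host callbacks) to finish before the clock stops.

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_warmup=2, n_iters=10):
    """Return mean wall-clock seconds per call of an already-jitted fn."""
    for _ in range(n_warmup):
        # First calls trigger compilation; keep them out of the measurement.
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_iters):
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_iters


f = jax.jit(lambda x: jnp.sin(x).sum())
t = time_jitted(f, jnp.ones(1000))
print(f"{t * 1e3:.3f} ms per call")
```

With the example above, the corresponding call would be something like time_jitted(forward_fn, d, action).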
Really neat, thanks @GJBoth! We'll give this a spin on our side soon and will let you know how it goes.
@saran-t @yuvaltassa would you be down to OSS a snapshot of the MJX-C code on some branch so users like @hartikainen and @GJBoth can tinker with it?
Yeah I can look into this when I'm back in London.
I just wanted to gently bump this to keep it from getting lost in the noise. Any updates?