mujoco
The CG Solver in MJX doesn't support reverse-mode differentiation
I'm trying to differentiate the MJX step function with JAX's autograd function jax.grad(), like:
```python
import jax
import jax.numpy as jnp
from mujoco import mjx

# mjx_model, goal_pos, vel, pos are defined elsewhere

def step(vel, pos):
    mjx_data = mjx.make_data(mjx_model)
    mjx_data = mjx_data.replace(qvel=vel, qpos=pos)
    pos = mjx.step(mjx_model, mjx_data).qpos
    return pos

def loss(vel, pos):
    pos = step(vel, pos)
    return jnp.sum((pos - goal_pos)**2)

grad_loss = jax.jit(jax.grad(loss))
grad = grad_loss(vel, pos)
```
When there is only one rigid body in the scene, everything works. But when collisions need to be solved, for example with a ball and a plane in the scene:
```python
XML = """
<mujoco>
  <asset>
    <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
      rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
    <material name="grid" texture="grid" texrepeat="2 2" texuniform="true"
      reflectance=".2"/>
  </asset>
  <worldbody>
    <geom name="ground" type="plane" pos="0 0 -.5" size="2 2 .1" material="grid" solimp=".99 .99 .01" solref=".001 1"/>
    <body>
      <freejoint/>
      <geom size=".15" mass="1" type="sphere"/>
    </body>
  </worldbody>
</mujoco>
"""
```
this error occurs:

```
File "/path-to-mujoco/mjx/_src/solver.py", line 347, in cg_solve
  ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d))
jax._src.source_info_util.JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
```
It seems the jax.lax.while_loop() used by the CG solver does not support a dynamic condition function. How can I solve this?
I also tried replacing ctx = jax.lax.while_loop(cond, body, _CGContext.create(m, d)) at mjx/_src/solver.py line 347 with a simpler while function:

```python
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

ctx = while_loop(cond, body, _CGContext.create(m, d))
```
This works when the gradient function is not compiled with jax.jit(), but with jax.jit() another error occurs:

```
File "/path-to-mujoco/mjx/_src/solver.py", line 349, in while_loop
  while cond_fun(val):
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function loss at /home/lvjun/Mujoco3/demo_mjx.py:55 for jit. This concrete value was not available in Python because it depends on the values of the arguments vel and pos.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
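For context, the same TracerBoolConversionError can be reproduced with a minimal jitted Python while loop, independent of MJX (a sketch; `count_down` is just an illustrative name):

```python
import jax
import jax.numpy as jnp

@jax.jit
def count_down(x):
    # Under jit, `x > 0` is an abstract tracer, so Python's `while`
    # cannot convert it to a concrete bool and JAX raises an error.
    while x > 0:
        x = x - 1
    return x

try:
    count_down(jnp.float32(3.0))
except jax.errors.TracerBoolConversionError:
    print("raises TracerBoolConversionError, as in the MJX trace above")
```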
Is it because the improvement and gradient in cond() are not static?

```python
def cond(ctx: _CGContext) -> jax.Array:
    improvement = _rescale(m, ctx.prev_cost - ctx.cost)
    gradient = _rescale(m, math.norm(ctx.grad))
    done = ctx.solver_niter >= m.opt.iterations
    done |= improvement < m.opt.tolerance
    done |= gradient < m.opt.tolerance
    return ~done
```
Is there any chance to make it supported for JIT compilation?
Hi @LyuJ1998, this is a known issue with while_loop: "while_loop is not reverse-mode differentiable because XLA computations require static bounds on memory requirements." https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html
You can change the while_loop to a scan: https://github.com/google/jax/discussions/3850
The TracerBoolConversionError occurs because cond_fun(val) is a traced JAX array, but you're using it in a Python while loop, which expects a concrete value. Use a scan or a for loop.
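The while-to-scan rewrite the links above describe can be sketched as follows: run the body for a fixed number of steps, and freeze the state once the condition becomes false. The fixed trip count is what makes reverse-mode differentiation possible. `bounded_while` and `max_iters` are illustrative names, not part of any library:

```python
import jax
import jax.numpy as jnp

def bounded_while(cond_fun, body_fun, init_val, max_iters):
    """While-loop with a static trip count, built on lax.scan.

    Runs body_fun up to max_iters times; once cond_fun is False the
    state is carried through unchanged, so the result matches an early
    exit while remaining reverse-mode differentiable.
    """
    def step(val, _):
        keep_going = cond_fun(val)
        new_val = body_fun(val)
        # Keep the update while the condition holds, else freeze the state.
        val = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_val, val)
        return val, None
    val, _ = jax.lax.scan(step, init_val, None, length=max_iters)
    return val

# Example: Newton iteration for sqrt(2), with an early-exit tolerance.
f = lambda x: 0.5 * (x + 2.0 / x)
cond = lambda x: jnp.abs(x * x - 2.0) > 1e-6
root = bounded_while(cond, f, jnp.float32(2.0), max_iters=10)
```

The tradeoff is that the scan always pays for max_iters steps of computation, even when the condition stops it early, which is why this can hurt forward performance.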
To reduce memory usage, there are some in-place operations in mjx.step. An in-place operation on an intermediate matrix, such as X[0] += Y[0], will break the back-propagation path. So jax.grad(mjx.step) doesn't work in MuJoCo 3.0.
Hi there,
Yes indeed this is by design, but poorly documented. I'll take this as motivation to add to the documentation.
tl;dr: if you would like to experiment with jax.grad(), please update to MuJoCo 3.0.1, which now includes support for the Newton solver in MJX. Newton converges quickly, and for many models a single solver iteration is sufficient. If your XML looks like this:
<option ... solver="Newton" iterations="1" ls_iterations="4">
then we omit the jax.lax.while_loop() and mjx.step is differentiable.
The reason we don't support this for CG is that replacing while with scan harms forward performance in some settings, so we currently accept this tradeoff.
Also please note that we have not investigated whether jax.grad delivers useful gradients in this setting - I would love to hear insights from anyone that tries this.
@sfd158 not quite sure what you mean by inplace operations - jax operations are pure, they do not modify the original. See for example this documentation on jax.numpy.ndarray.at
Not sure if this is the right place to post this - re: whether jax.grad delivers useful gradients:
I have been playing with gradients over mjx.step (Newton solver; 1 iteration) for my Master's Thesis. Please see an implementation of the Short Horizon Actor Critic (SHAC) algorithm here.
SHAC involves learning control policies using analytical policy gradients; it augments the basic Analytical Policy Gradient (APG) algorithm with several features, such as a value function. I use jax.grad to take the gradient of the loss of an environment rollout with respect to the policy parameters, and these gradients are informative enough to make the algorithm work. This works without contact (inverted pendulum) and with contact (basic hopper).
While the gradients appear to be informative in these simple cases, I can't get quadruped control working: the Jacobian of the MJX step, which is a component of the gradient of the loss with respect to the policy parameters, is unstable; more on this in the README of the repo. I wonder if different simulation parameters could help here, since this issue appears to have gotten worse from MJX 3.1.1 to MJX 3.1.3.
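The rollout-differentiation pattern described above can be sketched with a toy system standing in for mjx.step (everything here is illustrative: `rollout_loss`, the linear dynamics, and the scalar feedback policy are not from the SHAC repo):

```python
import jax
import jax.numpy as jnp

def rollout_loss(theta, x0, horizon=10):
    """Loss of a short rollout, differentiated end-to-end through
    the dynamics. A toy linear system stands in for mjx.step."""
    def step(x, _):
        u = -theta * x               # scalar linear feedback policy
        x = x + 0.1 * (x + u)        # stand-in for mjx.step(model, data)
        return x, x ** 2             # carry state, emit per-step cost
    _, costs = jax.lax.scan(step, x0, None, length=horizon)
    return jnp.sum(costs)

# Gradient of the rollout loss w.r.t. the policy parameter.
grad_fn = jax.jit(jax.grad(rollout_loss))
g = grad_fn(jnp.float32(0.5), jnp.float32(1.0))
```

Because the whole rollout is a lax.scan with a static horizon, backpropagation through every step is well defined; instability of the per-step Jacobian compounds multiplicatively over the horizon, which is consistent with the quadruped observation above.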
Hi @erikfrey, does setting iterations="1" impact the simulation accuracy? Thanks, Bugman