mujoco
Error when running training_apg.ipynb: ValueError: safe_zip() argument 2 is longer than argument 1
Hi,
I'm a student trying out MJX for some projects. I tried running training_apg.ipynb on my computer and got this error:
File ~/github/differential-imitation/brax/brax/training/agents/apg/train.py:175, in train.<locals>.loss(policy_params, normalizer_params, env_state, key)
    171 def loss(policy_params, normalizer_params, env_state, key):
    172   f = functools.partial(
    173       env_step, policy=make_policy((normalizer_params, policy_params)))
    174   (state_h, _), (rewards,
--> 175                  obs) = jax.lax.scan(f, (env_state, key),
    176                                      (jnp.arange(horizon_length // action_repeat)))
    178   return -jnp.mean(rewards), (obs, state_h)
[... skipping hidden 65 frame]
File ~/miniconda3/envs/rl/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:1115, in _partial_eval_jaxpr_nounits(jaxpr, in_unknowns, instantiate)
   1113     return [*known_vals_out, *residuals]
   1114   print(f"jax zip: {len(jaxpr.in_avals)}, {len(in_unknowns)}")
-> 1115   known_avals = [a for a, uk in zip(jaxpr.in_avals, in_unknowns) if not uk]
   1116   jaxpr_known, _, consts_known, () = trace_to_jaxpr_dynamic(lu.wrap_init(fun), known_avals)
   1117   (out_unknowns, jaxpr_unknown, res_avals), = cell  # pytype: disable=bad-unpacking
ValueError: safe_zip() argument 2 is longer than argument 1
Line 1114 is a print statement I added to check the lengths of the two zipped sequences. It printed:
jax zip: 654, 654
jax zip: 2, 2
jax zip: 4, 4
jax zip: 5, 5
jax zip: 2, 2
jax zip: 4, 4
jax zip: 5, 5
jax zip: 2, 2
jax zip: 632, 632
...
jax zip: 2, 2
jax zip: 2, 2
jax zip: 3, 3
jax zip: 21, 22
I am using jax/jaxlib v0.4.25 with a clone of the brax repo, and I reproduced this bug on both MuJoCo 3.1.4 and 3.1.5.
I'm not sure how to debug this from here; any advice would be appreciated.
@Andrew-Luo1, any thoughts?
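For context, the frame that fails in the first traceback is the APG loss: it unrolls the environment with jax.lax.scan and differentiates the negative mean reward through the whole rollout. A toy sketch of that pattern in isolation, assuming a hypothetical 1D point mass in place of the MJX environment and a hypothetical linear policy (this is not brax's implementation):

```python
import functools

import jax
import jax.numpy as jnp


def env_step(carry, unused_t, policy):
    """One step of a hypothetical 1D point mass (stand-in for an MJX env)."""
    state, key = carry
    key, noise_key = jax.random.split(key)
    pos, vel = state
    # Policy action plus a little exploration noise.
    action = policy(state) + 0.01 * jax.random.normal(noise_key)
    dt = 0.05
    vel = vel + dt * action
    pos = pos + dt * vel
    reward = -(pos - 1.0) ** 2  # reward for reaching pos = 1.0
    new_state = jnp.stack([pos, vel])
    return (new_state, key), (reward, new_state)


def loss(policy_params, state, key, horizon=20):
    # Hypothetical linear policy: action = params . [pos, vel]
    f = functools.partial(env_step, policy=lambda s: jnp.dot(policy_params, s))
    (final_state, _), (rewards, states) = jax.lax.scan(
        f, (state, key), jnp.arange(horizon))
    # Like the APG loss: minimize the negative mean reward over the rollout.
    return -jnp.mean(rewards), final_state


params = jnp.array([-1.0, -0.5])
state0 = jnp.array([0.0, 0.0])
key = jax.random.PRNGKey(0)

# Gradients flow back through every scan step into the policy parameters.
(value, final_state), grads = jax.value_and_grad(loss, has_aux=True)(
    params, state0, key)
```

The error in this issue comes from JAX's partial evaluation of such a scan when computing the gradient, which is why it only appears in the differentiated APG path.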
Update: I found that I had modified the environment's __init__ function to change the default solver settings. Training ran successfully after I removed these changes:
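mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG  # conjugate-gradient solver
mj_model.opt.iterations = 6  # max constraint-solver iterations
mj_model.opt.ls_iterations = 3  # max line-search iterations
mj_model.opt.jacobian = 0  # 0 = mjJAC_DENSE (dense Jacobian)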
I had previously found these settings to work well with the model I plan to train with APG later.
However, after some more testing I found that increasing mj_model.opt.iterations causes this issue: starting from a clean notebook, adding mj_model.opt.iterations = 2 produces the error above. I'm not sure whether this is the only case that triggers it, though.
I haven't tried differentiating through mjx.step with settings other than those from 1182. Do these settings work for you? I found that with the eulerdamp flag disabled, I could stably simulate anything from a pendulum to a quadruped with high PD gains.
Thanks for the input, @Andrew-Luo1! After disabling eulerdamp I'm no longer getting the zip() error. Unfortunately, I'm now getting a NaN during self._generate_eval_unroll():
File ~/miniconda3/envs/apg/lib/python3.11/site-packages/brax/training/acting.py:125, in Evaluator.run_evaluation(self, policy_params, training_metrics, aggregate_episodes)
    122 self._key, unroll_key = jax.random.split(self._key)
    124 t = time.time()
--> 125 eval_state = self._generate_eval_unroll(policy_params, unroll_key)
    126 eval_metrics = eval_state.info['eval_metrics']
    127 eval_metrics.active_episodes.block_until_ready()
[... skipping hidden 24 frame]
File ~/miniconda3/envs/apg/lib/python3.11/site-packages/jax/_src/pjit.py:1486, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, *args)
   1469 # If control reaches this line, we got a NaN on the output of `compiled`
   1470 # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
   1471 msg = (f"{str(e)}. Because "
   1472        "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
   1473        "de-optimized function (i.e., the function as if the `jit` "
   (...)
   1484        "If you see this error, consider opening a bug report at "
   1485        "https://github.com/google/jax.")
-> 1486 raise FloatingPointError(msg)
FloatingPointError: invalid value (nan) encountered in jit(scan). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`.
This environment uses a slightly modified dm_control rodent model (to make it MJX-compatible). Is this new issue due to simulation instability? I had previously gotten NaNs with this model when using the Newton solver for basic PPO training (hence the original solver changes), but the original zip() error comes back if I change the solver settings. Any advice on what to try would be appreciated!
In my experience, occasional NaNs are to be expected from simulators. I also ran into NaNs with PPO training. Here are three thoughts:
- Since it's a floating-point error, you can try 64-bit computation as in training_apg.ipynb (see the flags near the imports), at the cost of roughly doubling memory use and wall-clock time.
- You can disable the NaN debug flag and add NaN-catching logic instead, for example setting NaN rewards and observations to 0 and/or resetting an environment if anything in its state is NaN. I found that adding the following code to the stepping function keeps PPO training with the Newton solver free of NaNs:
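from jax import numpy as jp
from jax.flatten_util import ravel_pytree
# Sum the reward terms and replace NaN rewards and observations with 0.
reward = jp.nan_to_num(sum(reward_tuple.values()))
obs = jp.nan_to_num(obs)
# End the episode if any field of the simulation state is NaN.
flattened_vals, _ = ravel_pytree(data)
num_nans = jp.sum(jp.isnan(flattened_vals))
done = jp.where(num_nans > 0, 1.0, done)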
- The error you're seeing comes from the evaluator, which is separate from training. In the worst case, you can adjust the num_evals parameter or comment out the evaluator code.
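The first two suggestions can be combined into one self-contained sketch. It assumes JAX's standard jax_enable_x64 flag is the 64-bit setting meant (the notebook's exact flags are not reproduced here), and wraps the NaN-catching snippet in a hypothetical guard_nans helper whose argument names are illustrative, not brax's API.

```python
import jax
import jax.numpy as jp
from jax.flatten_util import ravel_pytree

# Suggestion 1: 64-bit floats. Must run before any arrays are created.
jax.config.update("jax_enable_x64", True)


def guard_nans(reward_terms, obs, data, done):
    """Hypothetical helper: sanitize rewards/observations, end NaN episodes.

    jp.nan_to_num replaces NaN with 0 and +/-inf with the largest finite
    values of the dtype. If any leaf of the `data` pytree contains a NaN,
    the episode is flagged as done.
    """
    reward = jp.nan_to_num(sum(reward_terms.values()))
    obs = jp.nan_to_num(obs)
    flat, _ = ravel_pytree(data)
    has_nan = jp.any(jp.isnan(flat))
    done = jp.where(has_nan, 1.0, done)
    return reward, obs, done


# Usage with a hypothetical step that produced NaNs.
reward_terms = {"forward": jp.array(1.0), "ctrl_cost": jp.array(jp.nan)}
obs = jp.array([0.5, jp.nan, -1.0])
data = {"qpos": jp.array([0.0, jp.nan]), "qvel": jp.zeros(2)}
reward, clean_obs, done = guard_nans(reward_terms, obs, data, jp.array(0.0))

# A NaN-free state leaves `done` unchanged.
_, _, done_clean = guard_nans({"r": jp.array(2.0)}, jp.ones(3),
                              {"qpos": jp.zeros(2)}, jp.array(0.0))
```

Flagging the episode as done lets the training wrapper's auto-reset replace the broken state instead of propagating NaNs into the gradients.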
Hopefully these help.