[BUG] - UnexpectedTracerError in Brax PPO with Domain Randomization (mjx backend)
Hi, I've encountered an UnexpectedTracerError when using Brax's PPO implementation with the mjx backend under specific conditions:
Conditions:
- No additional `eval_env` provided.
- Using a `randomization_fn` that depends on the original system parameters (e.g., `sys.tree_replace({"density": sys.density * 0.9})`).
- Utilizing the default environment wrappers.
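For context, `randomization_fn` takes the unbatched system and returns a batched system plus an `in_axes` tree that tells `jax.vmap` which leaves vary per environment. The sketch below mimics that contract with a plain dict standing in for Brax's `System` (an assumption for illustration; the real function operates on a `System` and typically uses `tree_replace`), with a hypothetical batch size `NUM_ENVS`. Note that it reads `sys["density"]`, which is the dependence on original parameters described in the second condition.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 4  # hypothetical batch size

# Stand-in for brax's System: a dict pytree of physical parameters (assumption).
base_sys = {"density": jnp.float32(1000.0), "gravity": jnp.float32(-9.81)}


def randomization_fn(sys, rng):
    """Scales density per environment; gravity stays shared across envs."""
    scale = jax.random.uniform(rng, (NUM_ENVS,), minval=0.8, maxval=1.2)
    # Reads the original value, so this breaks if `sys` is a leaked tracer.
    sys_v = {"density": sys["density"] * scale, "gravity": sys["gravity"]}
    # 0 = batched along the leading axis, None = broadcast to every env.
    in_axes = {"density": 0, "gravity": None}
    return sys_v, in_axes


sys_v, in_axes = randomization_fn(base_sys, jax.random.PRNGKey(0))


def weight(sys):
    # Toy per-environment quantity that uses both parameters.
    return sys["density"] * sys["gravity"]


weights = jax.vmap(weight, in_axes=(in_axes,))(sys_v)
print(weights.shape)  # (4,)
```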
Issue Explanation:
The issue arises in `DomainRandomizationVmapWrapper`. Its `_env_fn` method assigns `sys`, which is a tracer because the method is called inside `jax.vmap`, to the `env.unwrapped.sys` attribute of the shared underlying environment. Because that assignment outlives the trace, the tracer leaks: once the initial `reset` has been traced, the environment's persistent state holds a tracer instead of a concrete system:
```python
from typing import Callable, Tuple

import jax
from brax.base import System
from brax.envs.base import Env, State, Wrapper


class DomainRandomizationVmapWrapper(Wrapper):
  """Wrapper for domain randomization."""

  def __init__(
      self,
      env: Env,
      randomization_fn: Callable[[System], Tuple[System, System]],
  ):
    super().__init__(env)
    self._sys_v, self._in_axes = randomization_fn(self.sys)

  def _env_fn(self, sys: System) -> Env:
    env = self.env
    # Inside jax.vmap, `sys` is a tracer; this assignment mutates the shared
    # underlying env and persists after tracing finishes.
    env.unwrapped.sys = sys
    return env

  def reset(self, rng: jax.Array) -> State:
    def reset(sys, rng):
      env = self._env_fn(sys=sys)
      return env.reset(rng)

    state = jax.vmap(reset, in_axes=[self._in_axes, 0])(self._sys_v, rng)
    return state
```
After the initial trace, `env.sys`, and consequently the system of the user's original environment, is a tracer, which causes subsequent tracing to fail (e.g., when the evaluation environment is created from that same environment).
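The leak itself can be reproduced without Brax. In this sketch, `ToyEnv` and `leaky_batched_reset` are hypothetical stand-ins for the wrapped environment and the original `_env_fn`/`reset` pair: after `jax.vmap` returns, the shared object still holds the tracer that was assigned inside the traced function. Whether and where a later use raises `UnexpectedTracerError` depends on the JAX version and on what gets traced next, so the sketch only inspects the attribute's type.

```python
import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in for a Brax env whose `sys` is a single density value."""

    def __init__(self, density):
        self.sys = density

    def reset(self, rng):
        # The "observation" depends on the current system parameter.
        return self.sys * jax.random.uniform(rng)


def leaky_batched_reset(env, sys_batch, rngs):
    """Mirrors the original `_env_fn`: assigns the traced sys onto the shared env."""

    def reset_one(sys, rng):
        env.sys = sys  # `sys` is a vmap tracer here; the assignment outlives the trace
        return env.reset(rng)

    return jax.vmap(reset_one)(sys_batch, rngs)


env = ToyEnv(jnp.float32(1.0))
sys_batch = env.sys * jnp.array([0.9, 1.0, 1.1])
rngs = jax.random.split(jax.random.PRNGKey(0), 3)
obs = leaky_batched_reset(env, sys_batch, rngs)
print(type(env.sys).__name__)  # a tracer type, no longer a concrete array
```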
Potential Fix: I've verified a simple workaround: during each `reset`/`step` call, temporarily swap in the per-environment system and restore the original afterwards, so tracing never leaves a persistent modification behind.
```python
from contextlib import contextmanager
from typing import Callable, Tuple

import jax
from brax.base import System
from brax.envs.base import Env, State, Wrapper


class DomainRandomizationVmapWrapper(Wrapper):
  """Wrapper for domain randomization."""

  def __init__(
      self,
      env: Env,
      randomization_fn: Callable[[System], Tuple[System, System]],
  ):
    super().__init__(env)
    self._sys_v, self._in_axes = randomization_fn(self.sys)

  @contextmanager
  def _swap_sys(self, new_sys):
    """Temporarily installs `new_sys` on the unwrapped env, then restores it."""
    env = self.env.unwrapped
    old_sys = env.sys
    env.sys = new_sys
    try:
      yield
    finally:
      # Runs even if tracing raises, so no tracer is left on the env.
      env.sys = old_sys

  def reset(self, rng: jax.Array) -> State:
    def reset(sys, rng):
      with self._swap_sys(sys):
        out = self.env.reset(rng)
      return out

    state = jax.vmap(reset, in_axes=[self._in_axes, 0])(self._sys_v, rng)
    return state

  def step(self, state: State, action: jax.Array) -> State:
    def step(sys, s, a):
      with self._swap_sys(sys):
        out = self.env.step(s, a)
      return out

    res = jax.vmap(step, in_axes=[self._in_axes, 0, 0])(
        self._sys_v, state, action
    )
    return res
```
This ensures the original environment remains unmodified, preventing tracer leakage.
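The same swap-and-restore idea in isolation: a minimal sketch using a hypothetical `ToyEnv` and a free-standing `swap_sys` helper (illustrative names, not Brax APIs). After the vmapped reset finishes, the shared object holds its original concrete value again.

```python
from contextlib import contextmanager

import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in for a Brax env whose `sys` is a single density value."""

    def __init__(self, density):
        self.sys = density

    def reset(self, rng):
        return self.sys * jax.random.uniform(rng)


@contextmanager
def swap_sys(env, new_sys):
    """Temporarily installs `new_sys`, restoring the original even on error."""
    old_sys = env.sys
    env.sys = new_sys
    try:
        yield env
    finally:
        env.sys = old_sys


def batched_reset(env, sys_batch, rngs):
    def reset_one(sys, rng):
        with swap_sys(env, sys):
            return env.reset(rng)

    return jax.vmap(reset_one)(sys_batch, rngs)


env = ToyEnv(jnp.float32(1.0))
sys_batch = env.sys * jnp.array([0.9, 1.0, 1.1])
rngs = jax.random.split(jax.random.PRNGKey(0), 3)
obs = batched_reset(env, sys_batch, rngs)
print(float(env.sys))  # 1.0, the original concrete value
```

The `finally` clause is what makes this robust: the concrete system is restored even if tracing the wrapped `reset` raises partway through.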
Additional Context: This issue does not surface in the provided tutorial notebooks because they explicitly pass both a training and an evaluation environment (separate copies). However, since the evaluation environment is optional, the wrapper should avoid modifying the environment's state like this to prevent hidden side effects.
Let me know if you need additional details. I'm new to Brax and JAX, so apologies in advance for any oversight.
Thank you!
Hi @AndruGomes13, this is expected, and I haven't had a chance to find a workaround besides the one you have (and the one in the tutorials). Lmk if you have time to take a look.
Hi @btaba, I'd be happy to take a look at this. Could you clarify what you mean by “this is expected”? Do you mean that this behavior is intentional, or just that it’s a known issue? Also, I’d appreciate your thoughts on the solution I proposed: do you think it’s viable, or is there something else I should consider?
Yes, I mean that it's a known issue. Sorry, I missed the proposed solution in your description; it's very clever! Have you checked that this indeed leads to proper vmapping over the `sys` (it seems like it would work, but it's good to double-check and test)? If so, I think you should make a PR!
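One possible way to run the check suggested here is to compare the vmapped reset against an unbatched Python loop over the same systems and keys. This is a sketch on a hypothetical `ToyEnv` and `swap_sys` helper (repeated so the block runs on its own), not an existing Brax test; with real Brax, the analogous comparison would be between observations from the wrapper's `reset` and per-system resets.

```python
from contextlib import contextmanager

import jax
import jax.numpy as jnp


class ToyEnv:
    """Hypothetical stand-in for a Brax env whose `sys` is a single density value."""

    def __init__(self, density):
        self.sys = density

    def reset(self, rng):
        return self.sys * jax.random.uniform(rng)


@contextmanager
def swap_sys(env, new_sys):
    old_sys = env.sys
    env.sys = new_sys
    try:
        yield env
    finally:
        env.sys = old_sys


def batched_reset(env, sys_batch, rngs):
    def reset_one(sys, rng):
        with swap_sys(env, sys):
            return env.reset(rng)

    return jax.vmap(reset_one)(sys_batch, rngs)


def check_per_env_sys(env, sys_batch, rngs):
    """Compares the vmapped reset with an unbatched loop, one sys at a time."""
    batched = batched_reset(env, sys_batch, rngs)
    looped = []
    for sys, rng in zip(sys_batch, rngs):
        with swap_sys(env, sys):
            looped.append(env.reset(rng))
    return bool(jnp.allclose(batched, jnp.stack(looped)))


env = ToyEnv(jnp.float32(1.0))
sys_batch = env.sys * jnp.array([0.5, 1.0, 2.0])
rngs = jax.random.split(jax.random.PRNGKey(0), 3)
print(check_per_env_sys(env, sys_batch, rngs))  # True
```

Using deliberately distinct values such as 0.5, 1.0 and 2.0 makes the comparison meaningful: if every batch entry silently used the same system, the batched and looped results would disagree.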
Thanks for your PR, it actually solved my randomization UnexpectedTracerError.