Subsume part of `System` inside `State`; EDIT: Or add Options to reset
For domain randomization it is not particularly easy to vmap over different `System` values, for example the `gravity` or `elasticity` values. Preferably you should be able to do this in `env.reset`, but right now this is not possible because `self.sys` is a global variable in the `Env` namespace.
Right now my hacky workaround is to mock the Brax environment with my custom PyTree-like dataclass, so that I can modify the `env.sys` values in a functionally pure way inside the `reset` function.
It would be nice if brax could expose part of the `sys` dict/namespace as a pure argument to `env.reset` and `env.step` (e.g., as part of the state).
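To make it concrete, here is roughly the kind of thing I want to write, versus the attribute-mutation hack it currently requires (a sketch: `reset_with_sys` is a made-up helper, and `gravity` is just an example field):

```python
import jax
from brax import envs

env = envs.get_environment('ant')

# Hypothetical helper (NOT part of brax): a pure reset that takes sys as an
# explicit argument. Today the closest approximation mutates env.unwrapped.sys.
def reset_with_sys(sys, rng):
    env.unwrapped.sys = sys
    return env.reset(rng)

# A batch of 8 systems that differ only in gravity.
keys = jax.random.split(jax.random.key(0), 8)
sys_batch = env.sys.replace(
    gravity=jax.random.uniform(jax.random.key(1), (8, 3), minval=-10.0, maxval=0.0))
# in_axes mask over the sys pytree: map only the batched field.
sys_axes = jax.tree.map(lambda _: None, env.sys).replace(gravity=0)
states = jax.vmap(reset_with_sys, in_axes=(sys_axes, 0))(sys_batch, keys)
```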
Wanted to add an example of another workaround: https://github.com/automl/CARL. In this meta-RL library, instead of batching environments on the GPU (which Brax should support), the CARL-brax environments create VectorizedWrappers from Gymnasium in order to run multiple `System` variations simultaneously, which kind of defeats the purpose of GPU parallelization...
Hi @joeryjoery, I believe we considered passing around `sys` as part of the env state, but IIRC we managed to squeeze out better performance using the current implementation.

https://github.com/google/brax/blob/a89322496dcb07ac5a7e002c2e1d287c8c64b7dd/brax/envs/wrappers/training.py#L199

Feel free to implement a version of the base env class and wrapper which passes the `sys` in a functional way (e.g. as part of the `state.info`). If you manage to get the same training performance out of it, please send it our way!
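A rough sketch of the wrapper half (untested; `SysInfoWrapper` is a made-up name, and your env code would still need to read `state.info['sys']` instead of `self.sys` for this to be fully functional):

```python
import jax
from brax.envs import State, Wrapper

# Sketch: thread sys through reset/step as data in state.info, so it is a
# pure input under jit/vmap rather than a static attribute.
class SysInfoWrapper(Wrapper):
    def reset(self, rng: jax.Array) -> State:
        state = self.env.reset(rng)
        return state.replace(info=state.info | {'sys': self.env.unwrapped.sys})

    def step(self, state: State, action: jax.Array) -> State:
        # The wrapped env would read state.info['sys'] here instead of
        # self.sys; this assignment just bridges the two impurely for now.
        self.env.unwrapped.sys = state.info['sys']
        return self.env.step(state, action)
```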
Hey, thanks for the reply. A big obstacle right now in trying to implement something like this is that the `pipeline.init` and `pipeline.step` functions are quite rigid: they only receive `self, q, qd, _debug` as arguments.

So I'm trying to work around this by doing dependency injection for `self`, converting it into a PyTree such that I can apply jax transforms to `pipeline.init` etc. But mocking this object is causing quite a few problems, since I keep running into unforeseen dependencies. For this reason I think this approach is not great, as it will definitely lead to problems later on.
@btaba Could the `pipeline.init` and `pipeline.apply` functions perhaps be extended to receive an optional `options` dictionary? This would require the API to propagate the options in `reset` and `step` from wrappers to base (i.e., like the Gymnasium implementation).

In principle, if these are `None` then the performance stays the same, and if I want to provide options then I can wrap the `pipeline` module with my custom function that modifies `self.sys`.

What do you think?
Hi @joeryjoery,

I'm not quite following why you want to add extra args to `pipeline.init` and `pipeline.step`. Does something like this not work: `jax.vmap(pipeline.init, in_axes=[custom_in_axes, None, None])(sys, q, qd)`?
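For example, `custom_in_axes` can be built from the `sys` pytree itself. A rough sketch randomizing only `viscosity` (the backend and field choice are just for illustration):

```python
import jax
import jax.numpy as jnp
from brax import envs
from brax.generalized import pipeline

env = envs.get_environment('ant', backend='generalized')
sys = env.sys

# in_axes mask with the same structure as sys: None (broadcast) everywhere
# except the fields that actually vary across the batch.
custom_in_axes = jax.tree.map(lambda _: None, sys).replace(viscosity=0)

# A batch of 8 systems that differ only in viscosity.
sys_batch = sys.replace(viscosity=jax.random.uniform(jax.random.key(0), (8,)))

q, qd = sys.init_q, jnp.zeros(sys.qd_size())
states = jax.vmap(pipeline.init, in_axes=(custom_in_axes, None, None))(sys_batch, q, qd)
```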
Hey, yes this works, but it's not the problem. The issue is that I have no easy way to propagate `sys` variations to that point (at least not in a way that is jittable). For example, the `Ant` environment has a `reset` which looks something like this:
```python
def reset(self, rng: jax.Array) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)
    ...
    pipeline_state = self.pipeline_init(q, qd)
    obs = self._get_obs(pipeline_state)
    ...
```
Now suppose I want to wrap `Ant`: I do not have direct access to the `self.pipeline_init` call, so I cannot modularly `jax.vmap(pipeline.init, ...)`. A way to solve this is to allow options, for example:
```python
def reset(self, rng: jax.Array, *, options: dict | None = None) -> State:
    """Resets the environment to an initial state."""
    rng, rng1, rng2 = jax.random.split(rng, 3)
    ...
    pipeline_state = self.pipeline_init(q, qd, options=options)  # Pass along here
    obs = self._get_obs(pipeline_state)
    ...
```
In this way, I can wrap `env._pipeline` with a function like:

```python
import types

def my_wrapped_init(self, q, qd, *, options: dict | None = None):
    sys = self.sys
    if options is not None:
        variations = some_sampling_function(options)  # returns dict
        sys = self.sys.replace(**variations)
        return jax.vmap(self._pipeline.init, in_axes=(0, None, None, None))(
            sys, q, qd, self._debug)
    return self._pipeline.init(self.sys, q, qd, self._debug)

# Bind to the instance so `self` is passed correctly.
my_env.pipeline_init = types.MethodType(my_wrapped_init, my_env)
```
Comments and questions on the proposed changes:

[1] Subsume part of `System` inside `State`: You can do this already by adding `System` to `state.info`, and re-writing your env code to use `state.info['sys']` instead of `self.sys`. How performant is that implementation for RL workloads? Then we can discuss a potential API change.

[2] Add Options to reset: Strong preference here to add your logic to a wrapper, and to split out the vmap case from the non-vmap case into distinct wrappers. It looks like your proposal is similar to the `DomainRandomizationVmapWrapper`, except you want to do the `sys.replace` at `pipeline.init`/`pipeline.step` time? Does this mean that the `env.reset` and `env.step` logic won't be accessing the same randomized version of `sys`?
Hey, thanks a lot for continuing the discussion.

TL;DR: I was overthinking this, and the easy solution is indeed a slight modification of `DomainRandomizationVmapWrapper`.

- The problem with the current `DomainRandomizationVmapWrapper` is that the randomization is done in the `__init__` and not in the `reset`. If I want to resample variations at every call to `reset`, I instead have to reinstantiate the class, which would mean recompiling `reset` and `step`, which is costly.
- What I did now is make `randomization_fn` dependent on a random key and call it inside `reset`; the sampled variations are then replaced inside `System` and stored inside `State.info`. These only contain the varied fields, so that we don't redundantly pass around data.

In my implementation I also do not include `vmap`, as I think it is much easier to just `vmap` over the `DomainRandomization` wrapper. I have not tested performance, but the code is much more readable.
This is what I propose:

```python
from typing import Any, Callable

import jax
from brax import envs
from brax.base import System
from brax.envs import Env, State


class DomainRandomization(envs.Wrapper):
    """Wrapper for Procedural Domain Randomization."""

    def __init__(
        self,
        env: Env,
        randomization_fn: Callable[[System, jax.Array], dict[str, Any]],
    ):
        super().__init__(env)
        self.randomization_fn = randomization_fn
        # Keep a handle on the unrandomized System (see usage below).
        self.default_sys = env.unwrapped.sys

    def env_fn(self, sys: System) -> Env:
        env = self.env
        env.unwrapped.sys = sys
        return env

    def reset(self, rng: jax.Array) -> State:
        key_reset, key_var = jax.random.split(rng)
        sys = self.env.unwrapped.sys
        variations = self.randomization_fn(sys, key_var)
        new_sys = sys.replace(**variations)
        new_env = self.env_fn(new_sys)
        state = new_env.reset(key_reset)
        state = state.replace(info=state.info | {'sys_var': variations})
        return state

    def step(self, state: State, action: jax.Array) -> State:
        variations = state.info['sys_var']
        sys = self.env.unwrapped.sys
        new_sys = sys.replace(**variations)
        new_env = self.env_fn(new_sys)
        state = new_env.step(state, action)
        state = state.replace(info=state.info | {'sys_var': variations})
        return state
```

(Note that `randomization_fn` returns a dict of field variations, not a full `System`, so the annotation reflects that.)
Example usage:

```python
def viscosity_randomizer(system: System, key: jax.Array) -> dict[str, Any]:
    return {'viscosity': jax.random.uniform(key, system.viscosity.shape)}


env = envs.create(
    env_name='ant',
    episode_length=1000,
    action_repeat=1,
    auto_reset=True,
    batch_size=None,
)
wrap = DomainRandomization(env, viscosity_randomizer)

s0 = jax.jit(wrap.reset)(jax.random.key(0))
s1 = jax.jit(wrap.reset)(jax.random.key(321))

print(s0.info['sys_var'], s1.info['sys_var'])
# >> {'viscosity': Array(0.10536897, dtype=float32)} {'viscosity': Array(0.3906865, dtype=float32)}

print(wrap.unwrapped.sys.viscosity)
# >> Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

print(wrap.default_sys.viscosity)
# >> 0.0
```
Or composing with the `VmapWrapper`:

```python
import brax.envs.wrappers.training

sbatch = jax.jit(brax.envs.wrappers.training.VmapWrapper(wrap).reset)(
    jax.random.split(jax.random.key(0), 5)
)
print(sbatch.info['sys_var'])
# >> {'viscosity': Array([0.6306313 , 0.5778805 , 0.64515114, 0.95315635, 0.24741197], dtype=float32)}
```
It's not really easy to show here that this implementation works, but if you visualize the results using the code shown in the Colab, you can see that it indeed randomizes the `System` variables per random key:

https://colab.research.google.com/github/google/brax/blob/main/notebooks/training.ipynb#scrollTo=4hHuDp53e4VJ

I also haven't tested performance for RL training. But if your goal is to randomize at every `reset` call, it should be faster than the current `DomainRandomizationVmapWrapper`, whose non-pure `randomization_fn` forces re-instantiation (and recompilation).
Hi @joeryjoery,

I think we tried a version of this implementation. A few comments:

[1] Can you update your impl to make it work for nested fields in `sys`? You can probably use `tree_replace`.

[2] IIRC, passing these extra vars in the `info` was costly for an RL workload. Can you compare performance of your current version vs. the version at HEAD to see where we're at, and randomize a few more parameters (esp. ones that scale with `nv`, `nq`, `ngeom`)? Maybe try this on humanoid. So you'd potentially be passing `(batch_size, ngeom)` parameters in the `state.info`.

FWIW, the impl at HEAD, despite creating a static batch of `sys`, is enough for sim2real transfer on a quadruped. You can also do multiple resets in training like here (if you're concerned about the static part):

https://github.com/google/brax/blob/e91772bd70ed310476f91b84c4bb8a970e77b3e3/brax/training/agents/ppo/train.py#L418-L431
Hey!

For [1], I was working on something like this but didn't quite finish today; I will update it later. What do you mean with `tree_replace`, is it a private brax API? I was more thinking along the lines of mocking `System` with a nested dictionary.

For [2], I don't think there is a way around this: we are passing around more data. If the variations are small (like just the viscosity or gravity), then I'd imagine this is really negligible, but yes, this can grow for something like Humanoid with `mass` or `geoms` variations. Though there are some optimizations here, I'd imagine.

I'm not suggesting that the other `DomainRandomizationVmapWrapper` is wrong; if it works well for sim2real, that's amazing. However, I'm specifically looking to fulfil my research assumptions as well as I can. These assume random environments for every sampled trajectory, which also makes learning a good policy severely more difficult. Also, in my experiments I've found that data collection is rarely the bottleneck, moreso the learner (at least for my very specific use-case: PPO with a recurrent network architecture that also does internal matrix inversions).

If I find the time, I'll try running the default agent with the current domain randomization and the one I posted.
Hi @joeryjoery, `tree_replace` can be found here:

https://github.com/google/brax/blob/f9a4d73181d699db0fa38b07c5a651f5dc8ee231/brax/base.py#L114
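For example (the field paths and values here are just illustrative):

```python
from brax import envs

sys = envs.get_environment('ant').sys

# tree_replace takes dotted attribute paths into the nested System pytree.
new_sys = sys.tree_replace({
    'viscosity': 0.1,
    'link.inertia.mass': sys.link.inertia.mass * 1.1,
})
```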
Thanks for the context on [2]. I recommend using your own wrapper (to ensure a new system is sampled for every trajectory); looks like you're pretty close to a more general version with the implementation above! Let us know if you have any trouble, and please feel free to share any findings (or open a PR).