Running different environments in parallel

jbgaya opened this issue (2 comments)

Hi, and thanks for this great simulator! I would like to know whether it is possible to run multiple environments in parallel (for example, Halfcheetah with different gravity coefficients). From what I understand, this is quite difficult because the brax.System() of an environment is fixed and hard to change.

Ideally, the step() function of an environment would take not only the state and action but also a system as input. But I don't think it is feasible to vectorize that object in JAX, right?
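
As a sketch of that idea, assuming the per-environment physical parameters can be expressed as a pytree of arrays: `jax.vmap` maps over the leading axis of every leaf of every pytree argument, so a step function that takes its parameters explicitly can be vectorized. The toy point-mass `step` function and the `gravity`/`dt` keys below are hypothetical and are not Brax's API.

```python
import jax
import jax.numpy as jnp


def step(params, state, action):
    """Hypothetical toy step: a 1-D point mass (unit mass) under gravity.

    `params` plays the role of a per-environment system: a pytree of arrays.
    `state` is (position, velocity); `action` is an applied force.
    """
    pos, vel = state
    acc = action - params["gravity"]
    # Semi-implicit Euler integration with a per-environment timestep.
    vel = vel + params["dt"] * acc
    pos = pos + params["dt"] * vel
    return pos, vel


# Three environments that differ only in their gravity coefficient.
params = {
    "gravity": jnp.array([9.81, 3.71, 1.62]),
    "dt": jnp.full((3,), 0.01),
}
state = (jnp.zeros(3), jnp.zeros(3))
action = jnp.zeros(3)

# vmap maps over axis 0 of every leaf in params, state and action.
batched_step = jax.jit(jax.vmap(step))
new_pos, new_vel = batched_step(params, state, action)
print(new_vel)
```

Whether a real Brax System can be passed this way depends on all of its fields being array leaves, which is the open question here.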

Any ideas?

jbgaya, Jan 03 '22 12:01

Hi @jbgaya. System is a pytree:

https://github.com/google/brax/blob/main/brax/physics/system.py#L33

So, in theory, you could vmap over multiple systems. For example, this code runs:

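```python
import copy

import jax
import jax.numpy as jnp
from brax import envs

sys1 = envs.create('ant').sys
# deepcopy, so that changing sys2's integrator does not also change sys1's
sys2 = copy.deepcopy(sys1)
sys2.integrator.dt = 0.002
# Stack each pair of matching leaves along a new leading (batch) axis.
sys = jax.tree_util.tree_map(lambda x, y: jnp.stack([x, y]), *[sys1, sys2])
```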
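
The stacking pattern above, and how one might then vmap over the result, can be seen on a self-contained toy pytree without Brax. `ToySystem` and `step` are hypothetical stand-ins; the point is that `tree_map` with `jnp.stack` gives every leaf a leading batch axis, and `vmap` with `in_axes=(0, None)` then maps over that axis while sharing the other argument.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySystem(NamedTuple):
    """Hypothetical stand-in for a physics system; NamedTuples are pytrees."""
    gravity: jax.Array  # gravity vector, shape (3,)
    dt: jax.Array       # scalar timestep


def step(sys, vel):
    # One explicit Euler velocity update under gravity.
    return vel + sys.dt * sys.gravity


sys1 = ToySystem(gravity=jnp.array([0.0, 0.0, -9.81]), dt=jnp.array(0.01))
# _replace returns a new tuple, so sys1 is left unchanged (no aliasing).
sys2 = sys1._replace(dt=jnp.array(0.002))

# Stack matching leaves: every field gains a leading batch axis of size 2.
sys = jax.tree_util.tree_map(lambda x, y: jnp.stack([x, y]), sys1, sys2)
print(sys.gravity.shape, sys.dt.shape)  # (2, 3) (2,)

# Map over the system batch while sharing one initial velocity.
vel = jnp.zeros(3)
new_vel = jax.vmap(step, in_axes=(0, None))(sys, vel)
print(new_vel.shape)  # (2, 3)
```

This only covers fields that are array leaves; anything a system computes or allocates outside its leaves is not batched by this pattern.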

But with Brax this won't work out of the box, because default_qp won't produce the right results; we'd have to change the way System allocates its internal fields.

That said, JAX executes asynchronously, so have you tried simply running the step() functions of several different brax.System() instances serially and then stacking the results? I would expect that to work OK, although JIT compilation may be slow if you're using hundreds or thousands of Systems.
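
A minimal sketch of this serial alternative, assuming each environment's constants are baked into its own jitted step function. `make_step` and the point-mass dynamics are hypothetical, not Brax code. Because each closure is traced and compiled separately, compile time grows with the number of systems, which is the JIT-cost caveat mentioned in the comment above.

```python
import jax
import jax.numpy as jnp


def make_step(gravity, dt=0.01):
    """Hypothetical: build a jitted step with the system constants baked in.

    This mirrors having one fixed system object per environment: each
    returned closure is traced and compiled separately.
    """
    @jax.jit
    def step(vel, action):
        return vel + dt * (action - gravity)
    return step


gravities = [9.81, 3.71, 1.62]
steps = [make_step(g) for g in gravities]

vel = jnp.zeros(())
action = jnp.zeros(())
# After compilation, each call is dispatched asynchronously, so the loop can
# enqueue work without waiting for earlier results to finish.
results = [step_fn(vel, action) for step_fn in steps]
stacked = jnp.stack(results)  # one entry per environment, shape (3,)
print(stacked)
```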

erikfrey, Jan 06 '22 17:01

> That said, jax executes asynchronously - so have you tried just running multiple step() functions from multiple different brax.System() instances serially and then stacking the results? I would expect that to work OK, although it may be slow to JIT if you're doing hundreds or thousands of Systems.

Indeed, this is the easiest option since acquisition is fast. Will try. Thanks!

jbgaya, Jan 10 '22 10:01