
[FEATURE] Vsys feature: massively parallel domain randomization

Velythyl opened this issue 1 year ago • 3 comments (status: open)

Hello!

For an unrelated research project, I needed a massively parallel RL environment with domain randomization capabilities. Isaac Sim/Gym/Omniverse fit the bill, but I also needed the simulator to be differentiable w.r.t. each domain randomization parameter.

So I set out to implement DR in brax. This is research code, so it's obviously a little janky and ad-hoc. But I thought maybe the brax community could find this interesting, and perhaps (with a lot of tuning) even merge it into brax main.

Special thanks to this GitHub issue (here), from which I stole some code ;)

Note that this domain randomization method is more powerful than this (linked) approach: with this code, we can re-randomize at every single simulation step if we so wish.

The implementation boils down to one idea: we augment the simulation state to contain sys, which gives every parallel environment access to its own separate sys. This also lets us resample sys according to some rule (for example, "resample every 50 steps").
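A minimal JAX sketch of that idea, assuming toy 1-D point-mass dynamics instead of brax's real System and State: each vectorized env carries its own sys inside its state and resamples it every `resample_every` steps (setting it to 1 re-randomizes at every step). `Sys`, `State`, `sample_sys`, `env_step` and `env_reset` are hypothetical names, not the PR's or brax's API.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class Sys(NamedTuple):
    """Hypothetical stand-in for brax's System: one randomizable parameter."""
    mass: jnp.ndarray  # shape () per env, (num_envs,) once vmapped


class State(NamedTuple):
    """Hypothetical env state that carries its own sys."""
    pos: jnp.ndarray
    vel: jnp.ndarray
    sys: Sys
    step_count: jnp.ndarray
    rng: jnp.ndarray


def sample_sys(rng):
    # Uniformly sample a mass in [0.5, 1.5) (arbitrary illustrative range).
    return Sys(mass=jax.random.uniform(rng, (), minval=0.5, maxval=1.5))


def env_reset(rng):
    rng, sub = jax.random.split(rng)
    return State(jnp.zeros(()), jnp.zeros(()), sample_sys(sub),
                 jnp.zeros((), jnp.int32), rng)


def env_step(state, force, resample_every=50, dt=0.01):
    # Toy dynamics that read the mass from this env's own state.sys.
    vel = state.vel + dt * force / state.sys.mass
    pos = state.pos + dt * vel
    step_count = state.step_count + 1
    # Draw a candidate sys and keep it only on resampling steps.
    rng, sub = jax.random.split(state.rng)
    resample = (step_count % resample_every) == 0
    sys = jax.tree_util.tree_map(
        lambda new, old: jnp.where(resample, new, old), sample_sys(sub), state.sys)
    return State(pos, vel, sys, step_count, rng)


num_envs = 4
states = jax.vmap(env_reset)(jax.random.split(jax.random.PRNGKey(0), num_envs))
step = jax.jit(jax.vmap(env_step))  # every env steps with its own sys
forces = jnp.ones((num_envs,))
for _ in range(50):
    states = step(states, forces)  # sys is resampled on step 50
```

Keeping sys in the state makes it an ordinary traced value, so the same vmapped step works whether sys is fixed per episode or changes mid-rollout.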


Features:

  • The vsys wrapper allows for a vectorized sys variable that can contain different domain randomization values for each vectorized env.
  • Domain randomization is controlled via a simple YAML file format that describes the path to each domain randomization target. Example:

    link:
      inertia:
        mass:
          base: [r, r, r, r, r, r, r]
          min: [-0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5]
          max: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
      constraint_ang_damping:
        min: [-1, -1, -1, 1, 1, 1, 1]
        max: [2, 2, 2, 1.5, 1, 1, 1]

This randomizes over the 7 links of the robot. For the mass, the base is "r", so the value is "read" from the default defined in the URDF file. The min and max values are both relative to the base, so this setup samples each mass from [r-0.5, r+0.5]. For the damping, no base is given, so it defaults to "r". The base can also be set to a float value. Finally, a base of "n" ("none") disables randomization for that index.
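One possible reading of these rules as code, assuming the YAML has already been parsed into Python lists and that an "n" index simply keeps its default value. `resolve_base` and `sample_target` are hypothetical helpers, not the PR's actual parser.

```python
import jax
import jax.numpy as jnp


def resolve_base(base, default):
    """Hypothetical helper: per-index base values plus a randomization mask.

    "r" reads the default value, "n" keeps the default and disables
    randomization, and any number is used as the base directly.
    """
    values, enabled = [], []
    for b, d in zip(base, default):
        if b == "r":
            values.append(d)
            enabled.append(True)
        elif b == "n":
            values.append(d)
            enabled.append(False)
        else:
            values.append(float(b))
            enabled.append(True)
    return jnp.asarray(values, dtype=jnp.float32), jnp.asarray(enabled)


def sample_target(rng, spec, default, num_envs):
    """Sample one randomization target (e.g. link.inertia.mass) for num_envs envs."""
    n = len(default)
    base = spec.get("base", ["r"] * n)  # a missing base defaults to "r"
    base_vals, enabled = resolve_base(base, default)
    lo = jnp.asarray(spec["min"], dtype=jnp.float32)
    hi = jnp.asarray(spec["max"], dtype=jnp.float32)
    # min/max are offsets relative to the base: value in [base + min, base + max).
    offset = jax.random.uniform(rng, (num_envs, n), minval=lo, maxval=hi)
    return jnp.where(enabled, base_vals + offset, base_vals)


default_mass = [1.0, 2.0, 2.0, 1.5, 1.2, 1.0, 1.0]  # made-up URDF defaults
mass_spec = {"base": ["r"] * 7, "min": [-0.5] * 7, "max": [0.5] * 7}
masses = sample_target(jax.random.PRNGKey(0), mass_spec, default_mass, num_envs=1024)
```

Each row of `masses` is one environment's draw for the 7 links.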

  • Domain randomization is differentiable (!)

For example, by running a simple optax optimizer, we can recover the true domain randomization parameters in play at a specific timestep.
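A toy illustration of that recovery, assuming a 1-D point mass rather than a brax robot, and plain gradient descent swapped in for optax to keep the sketch dependency-free: we simulate under an "unknown" mass, then recover it by differentiating the rollout with `jax.grad`. `rollout`, `loss` and all constants are illustrative.

```python
import jax
import jax.numpy as jnp


def rollout(mass, force=1.0, dt=0.01, steps=100):
    """Toy differentiable simulator: positions of a point mass pushed by `force`."""
    mass = jnp.asarray(mass, dtype=jnp.float32)

    def body(carry, _):
        pos, vel = carry
        vel = vel + dt * force / mass  # semi-implicit Euler step
        pos = pos + dt * vel
        return (pos, vel), pos

    init = (jnp.float32(0.0), jnp.float32(0.0))
    _, traj = jax.lax.scan(body, init, None, length=steps)
    return traj


true_mass = 1.3                # the randomized value we pretend not to know
observed = rollout(true_mass)  # trajectory produced under that value


def loss(mass):
    # Squared error between simulated and observed positions.
    return jnp.mean((rollout(mass) - observed) ** 2)


grad_fn = jax.jit(jax.grad(loss))
mass = jnp.float32(1.0)  # initial guess
learning_rate = 10.0
for _ in range(300):
    mass = mass - learning_rate * grad_fn(mass)
```

Because the gradient flows through every simulation step back to `mass`, any optimizer (optax included) can fit the parameter from observed behavior.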

Known issues:

  • Because sys needs to be included in the state, a few of the Python type hints are broken.
  • The YAML definition system is arbitrary and might not follow best practices.
  • This is kind of a huge PR, so thoroughly testing all ~800 lines of code changes is bound to be tough.
  • You can test the changes with the script in the vsys wrapper's if __name__ == "__main__": block, specifically here: https://github.com/Velythyl/brax/blob/b6cab6449ba677108e37739286e0521f7c226a9e/brax/envs/wrappers/vsys.py#L553

Again, I don't expect this to be merged as-is, but the implementation might interest the community, hence this PR.

Velythyl, Feb 15 '24 22:02

Hey @Velythyl, I've been looking forward to this feature for a while now, thanks a lot for sharing it! I'm just curious: why did you close the PR?

lebrice, Feb 16 '24 05:02

@lebrice Hey! Sorry, I realized I had some cleanup to do, and it was way past 5pm so I wanted to go home. I've reopened it now.

Velythyl, Feb 16 '24 16:02

Thanks @Velythyl! The recent comment just made me realize that the maintainers hadn't commented on the PR. There were a few design decisions that went into DomainRandomizationVmapWrapper:

  1. We saw better performance when sys was not added as part of State.
  2. We wanted the user to fully define the randomization strategy rather than follow a schema. At HEAD, this can be done via the randomization_fn.
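A self-contained sketch of the randomization_fn pattern described above, assuming a toy `Sys` NamedTuple instead of brax's System: the user's function returns a batched sys together with an `in_axes` pytree (0 for randomized leaves, None for shared ones), and the step is vmapped with that `in_axes`. The names here are hypothetical and only mirror the shape of the approach; they are not brax's code.

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class Sys(NamedTuple):
    """Hypothetical stand-in for brax's System."""
    mass: jnp.ndarray     # randomized per env
    gravity: jnp.ndarray  # shared by all envs


def randomization_fn(sys, rngs):
    """User-defined strategy: returns (batched sys, matching in_axes)."""
    masses = jax.vmap(
        lambda k: sys.mass * jax.random.uniform(k, (), minval=0.8, maxval=1.2))(rngs)
    batched_sys = sys._replace(mass=masses)
    # in_axes mirrors the sys pytree: 0 = batched leaf, None = broadcast leaf.
    in_axes = Sys(mass=0, gravity=None)
    return batched_sys, in_axes


def step(sys, pos, vel, force, dt=0.01):
    vel = vel + dt * (force / sys.mass + sys.gravity)
    return pos + dt * vel, vel


base_sys = Sys(mass=jnp.float32(1.0), gravity=jnp.float32(-9.81))
num_envs = 8
sys_v, in_axes = randomization_fn(
    base_sys, jax.random.split(jax.random.PRNGKey(0), num_envs))
vstep = jax.jit(jax.vmap(step, in_axes=(in_axes, 0, 0, 0)))
pos, vel = vstep(sys_v, jnp.zeros(num_envs), jnp.zeros(num_envs), jnp.ones(num_envs))
```

Only the randomized leaves are stored per env; everything marked None is broadcast, which keeps the batched sys small.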

The cons of the impl at HEAD are that:

  1. The reset is static and stored in the wrapper, as addressed in this PR.
  2. Simple randomization strategies still require the user to write a randomization_fn

What I think would make sense to merge is a wrapper with the same API as DomainRandomizationVmapWrapper that passes in_axes and the randomized Sys PyTree values in the State, as discussed in this thread: https://github.com/google/brax/issues/446 .
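A rough sketch of what such a wrapper could look like, assuming the same toy `Sys` and randomization_fn shape as a brax-style strategy (returning a batched sys and its in_axes): the randomized values travel in the state, so every reset can draw fresh ones, while `in_axes` stays static in the wrapper. `DRState` and `StatefulDRWrapper` are hypothetical names, not an existing brax class.

```python
import jax
import jax.numpy as jnp
from typing import Any, NamedTuple


class Sys(NamedTuple):
    """Hypothetical stand-in for brax's System."""
    mass: jnp.ndarray
    gravity: jnp.ndarray


class DRState(NamedTuple):
    """Hypothetical batched state that also carries the randomized sys values."""
    pos: jnp.ndarray
    vel: jnp.ndarray
    sys: Any  # batched Sys; leaves with in_axes None stay unbatched


def randomization_fn(sys, rngs):
    masses = jax.vmap(
        lambda k: sys.mass * jax.random.uniform(k, (), minval=0.8, maxval=1.2))(rngs)
    return sys._replace(mass=masses), Sys(mass=0, gravity=None)


def physics_step(sys, pos, vel, force, dt=0.01):
    vel = vel + dt * (force / sys.mass + sys.gravity)
    return pos + dt * vel, vel


class StatefulDRWrapper:
    """Hypothetical wrapper: sys values live in the State, in_axes stays static."""

    def __init__(self, base_sys, randomization_fn, num_envs):
        self._base_sys = base_sys
        self._randomization_fn = randomization_fn
        self._num_envs = num_envs
        # in_axes depends only on the strategy, so compute it once.
        _, self._in_axes = randomization_fn(
            base_sys, jax.random.split(jax.random.PRNGKey(0), num_envs))

    def reset(self, rng):
        # Fresh randomization on every reset instead of a sys fixed at construction.
        sys_v, _ = self._randomization_fn(
            self._base_sys, jax.random.split(rng, self._num_envs))
        zeros = jnp.zeros(self._num_envs)
        return DRState(zeros, zeros, sys_v)

    def step(self, state, force):
        step_v = jax.vmap(physics_step, in_axes=(self._in_axes, 0, 0, 0))
        pos, vel = step_v(state.sys, state.pos, state.vel, force)
        return state._replace(pos=pos, vel=vel)


env = StatefulDRWrapper(Sys(jnp.float32(1.0), jnp.float32(-9.81)), randomization_fn, 8)
s1 = env.reset(jax.random.PRNGKey(1))
s2 = env.reset(jax.random.PRNGKey(2))
s1 = jax.jit(env.step)(s1, jnp.ones(8))
```

Storing only the sys values in the state, rather than the whole wrapper config, keeps the user-defined randomization_fn in charge of the strategy, as in point 2 above.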

btaba, Apr 30 '24 21:04