
Autoreset behavior

Open DavidSlayback opened this issue 3 years ago • 14 comments

I've been digging into Brax as a potential alternative to some modified dm_control environments I've been using and am really loving the speedup! That said, I feel like I've run into a major issue using the environments in RL and was looking for some guidance.

Basically, my environments are all partially-observable domains built off of "ant". A lot of the conditions are randomized per-episode (e.g., ant/target starting positions). I've been using the "create_gym_env" feature to work with my PyTorch agents, but I noticed a big potential issue.

At first glance, the AutoResetWrapper seemed to do what standard gym VectorEnvs do, but in reality, it's not really "resetting" the environments (with a new seed) but instead just setting them back to a cached first state. So the randomization of start conditions I do only applies across the whole batch of environments, and then for the entire training process, each individual environment is the same as it was before.
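Roughly, the pattern I mean is something like this (a simplified sketch of the caching logic, not the exact Brax wrapper code):

# Simplified sketch: the state from the initial reset is cached, and any
# environment that reports done is snapped back to that same cached state
# instead of getting a freshly randomized one.
import jax
import jax.numpy as jnp

def autoreset_step(env, state, action, first_qp, first_obs):
    state = env.step(state, action)

    def where_done(cached, current):
        done = jnp.reshape(state.done,
                           state.done.shape + (1,) * (current.ndim - state.done.ndim))
        return jnp.where(done, cached, current)

    qp = jax.tree_map(where_done, first_qp, state.qp)   # same first_qp forever
    obs = where_done(first_obs, state.obs)
    return state.replace(qp=qp, obs=obs)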

Is there a way to actually reset individual environments within a batch?

DavidSlayback avatar Mar 09 '22 16:03 DavidSlayback

We haven't actually encountered environments that needed more than the initial randomness you can cache into, say, 2048 environments, but your situation seems like a good test case. (see discussion here: https://github.com/google/brax/issues/167)

One sensible way of doing this is wrapping the AutoResetWrapper with another wrapper that refreshes the first_qp state every X episodes. But you might want to verify that you really need this extra randomness by manually refreshing that state with a call to reset every step.

i.e., around line 105 of the AutoResetWrapper's step function, you could instead do:

maybe_reset = self.reset(rng)
qp = jp.tree_map(where_done, maybe_reset.qp, state.qp)

This will be slower (by a lot, probably, because every step will be calling reset), but it should reveal whether the extra randomness actually matters.
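If the extra randomness does turn out to matter, here is a rough sketch of the periodic-refresh wrapper idea (the class and attribute names are made up for illustration, and the host-side counter assumes it sits outside jit, e.g. around the gym-style env):

# Hypothetical wrapper that re-randomizes the cached first state every
# `refresh_every` steps; everything else behaves like the normal autoreset.
import jax

class PeriodicRefreshWrapper:
    def __init__(self, env, refresh_every=10_000):
        self.env = env  # an env already wrapped with AutoResetWrapper
        self.refresh_every = refresh_every
        self._steps = 0

    def reset(self, rng):
        self._rng = rng
        return self.env.reset(rng)

    def step(self, state, action):
        self._steps += 1
        if self._steps % self.refresh_every == 0:
            # Draw a fresh batch of initial states and overwrite the cached
            # first_qp / first_obs that the autoreset copies in on done.
            self._rng, rng = jax.random.split(self._rng)
            fresh = self.env.reset(rng)
            state = state.replace(info={**state.info,
                                        'first_qp': fresh.qp,
                                        'first_obs': fresh.obs})
        return self.env.step(state, action)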

cdfreeman-google avatar Mar 16 '22 17:03 cdfreeman-google

Yeah, 2048 environments have a lot of randomness built in, but if I'm trying to solve a reasonably-sized generalized task (like a procedurally generated maze or foraging task) instead of a specific subset of that (based on seed), I feel like individual resets make more sense. Definitely understand the speed concerns, though.

Your solution seems like a reasonable approach. Another thing I was considering was building the "reset" call into the step function of the environment itself. I lose a bit of flexibility, but then I don't have to worry about external resets. With non-brax environments, I've often implemented a batched version with a "_reset_some(mask)" function.
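Something like this minimal sketch of the masked-reset idea (plain JAX; reset_fn and the done mask are illustrative, not Brax API):

# Re-sample fresh initial states only for the environments whose done flag is
# set, leaving the rest of the batch untouched.
import jax
import jax.numpy as jnp

def reset_some(rng, state, done, reset_fn):
    rngs = jax.random.split(rng, done.shape[0])
    fresh = jax.vmap(reset_fn)(rngs)  # a brand-new batch of initial states

    def where_done(new, old):
        d = jnp.reshape(done, done.shape + (1,) * (old.ndim - done.ndim))
        return jnp.where(d, new, old)

    return jax.tree_map(where_done, fresh, state)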

I appreciate the advice! One last question I had was about best practices for Jax RNG. I'll use my "ant tag" environment as an example.

def reset(self, rng: jp.ndarray) -> env.State:
    rng, rng1, rng2 = jp.random_split(rng, 3)
    qpos = self.sys.default_angle() + jp.random_uniform(
        rng1, (self.sys.num_joint_dof,), -.1, .1)
    qvel = jp.random_uniform(rng2, (self.sys.num_joint_dof,), -.1, .1)
    ant_pos = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)
    qp = self.sys.default_qp(joint_angle=qpos, joint_velocity=qvel)
    pos = index_add(qp.pos, self.ant_mg, ant_pos[...,None])
    rng, tgt = self._random_target(rng, ant_pos)
    pos = jp.index_update(pos, self.target_idx, tgt)
    qp = qp.replace(pos=pos)
    info = self.sys.info(qp)
    obs = self._get_obs(qp, info)
    reward, done, zero = jp.zeros(3)
    metrics = {
        'hits': zero,
    }
    info = {'rng': rng}
    return env.State(qp, obs, reward, done, metrics, info)

  • I grab 2 new rngs, 1 for position, 1 for velocity. I draw the initial qpos and then the ant xy with rng1. cage_xy is [4.5, 4.5].

Should I re-use rng1? I've also noticed anecdotally that the two numbers drawn by random_uniform for the ant are typically quite close; they don't seem like 2 independently sampled numbers.

  • I draw a target position by repeatedly splitting the rng until I get one at least the minimum distance away from the ant.

Can I do this more efficiently?

DavidSlayback avatar Mar 17 '22 16:03 DavidSlayback

I probably wouldn't reuse rng1. It really isn't that expensive to just generate another rng seed:

rng, rng1, rng2, rng3 = jp.random_split(rng, 4)

and then use rng3 for your ant_pos randomness. The random number generator is deterministic and stateless, so two calls to jp.random_uniform(rng1, blah blah) will give you the same "random" results.
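For example, in plain jax.random:

# Same key in, same numbers out: the PRNG is a pure function of the key, so
# reusing a key gives identical draws rather than independent ones.
import jax

key = jax.random.PRNGKey(0)
a = jax.random.uniform(key, (2,), minval=-4.5, maxval=4.5)
b = jax.random.uniform(key, (2,), minval=-4.5, maxval=4.5)
assert (a == b).all()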

I'm not sure I understand your last question. Is it that you want to generate a random location that's guaranteed to be some distance away from the ant, but in some range from there? I'd probably just parameterize that location by distance and theta, i.e.:

random_dist = jp.random_uniform(rng3, (), min_dist, max_dist)
random_angle = jp.random_uniform(rng4, (), 0., 2. * jp.pi)  # rng4 from another split
target_pos = ant_pos + jp.array([random_dist * jp.cos(random_angle),
                                 random_dist * jp.sin(random_angle)])

Is that what you had in mind?

cdfreeman-google avatar Mar 17 '22 17:03 cdfreeman-google

Got it, thanks again!

For the first question, there's no weirdness with using one key to draw a vector of 2 random numbers from the same range (each is -4.5 to 4.5)? I could just be seeing weird things that would iron out if I were looking at more seeds.

For the second question, I just want to make sure that the target position is (a) within the arena boundaries and (b) at least "min_distance" from the ant. So the max distance would vary based on the ant's start position and the angle, and some angles could be entirely impossible. That's where I'm sort of stuck on resampling; ideally, I'd sample a random point once from the space of possible points after ant placement.

The actual random target sampling code is below (I added some more jumpy functions on my end, so the "while_loop" is a JIT-compatible Jax one):

def _random_target(self, rng: jp.ndarray, ant_xy: jp.ndarray) -> Tuple[jp.ndarray, jp.ndarray]:
    """Returns a target location at least min_spawn_location away from ant"""
    rng, rng1 = jp.random_split(rng, 2)
    xy = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)
    minus_ant = lambda xy: xy - ant_xy
    def resample(rngxy: Tuple[jp.ndarray, jp.ndarray]) -> Tuple[jp.ndarray, jp.ndarray]:
        rng, xy = rngxy
        _, rng1 = jp.random_split(rng, 2)
        xy = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)
        return rng1, xy

    _, xy = while_loop(
        lambda rngxy: jp.norm(minus_ant(rngxy[1])) <= self.min_spawn_distance,
        resample,
        (rng1, xy))
    target_z = 0.5
    target = jp.array([*xy, target_z]).transpose()
    return rng, target

DavidSlayback avatar Mar 17 '22 19:03 DavidSlayback

Ahhhh, I see--totally misunderstood what you were doing.

First question: To be probably overly specific, there's no problem with this:

ant_pos = jp.random_uniform(rng1, (2,), -self.cage_xy, self.cage_xy)

Whereas

ant_pos_x = jp.random_uniform(rng1, (1,), -self.cage_xy, self.cage_xy)
ant_pos_y = jp.random_uniform(rng1, (1,), -self.cage_xy, self.cage_xy)

this will always have ant_pos_x=ant_pos_y.

For question 2: Yeah okay this is tricky. The resampling trick is probably the "simplest" in terms of "lines of code per unit how-hard-do-I-have-to-think-about-this".

Another option: sample the point (r, theta) randomly. One of (r, theta) and (r, theta+180) gives a valid direction, and might just need to have its r projected back to the arena boundary. This would slightly oversample boundary points (relative to a uniform sampling of "valid" points), but it's at least deterministic.

Another option: This does have an analytic solution, but it has a bunch of irritating edge cases (like which edges the minimum distance from the ant is in contact with). I'd probably just do the previous option, unless for some reason uniform sampling is super duper important.

Another option: Construct a grid of candidate points, compute a mask of valid points, and sample one of these. It has lower resolution than the other options, but it samples uniformly and converges to the right thing in the limit of lots of grid points.
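A quick sketch of that grid-based option in plain JAX (the function and variable names are just illustrative):

# Lay a grid over the arena, mask out points closer than min_spawn_distance to
# the ant, and pick one of the remaining cells uniformly at random.
import jax
import jax.numpy as jnp

def sample_target_grid(rng, ant_xy, cage_xy, min_spawn_distance, n=64):
    xs = jnp.linspace(-cage_xy[0], cage_xy[0], n)
    ys = jnp.linspace(-cage_xy[1], cage_xy[1], n)
    gx, gy = jnp.meshgrid(xs, ys)
    candidates = jnp.stack([gx.ravel(), gy.ravel()], axis=-1)  # (n*n, 2)
    valid = jnp.linalg.norm(candidates - ant_xy, axis=-1) > min_spawn_distance
    probs = valid / valid.sum()                                # uniform over valid cells
    idx = jax.random.choice(rng, candidates.shape[0], p=probs)
    return candidates[idx]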

cdfreeman-google avatar Mar 17 '22 20:03 cdfreeman-google

Thank you again! I realize you've got plenty of other stuff to work on even just within Brax, so I appreciate the thorough answers. I'll probably go with your second option.

As to the original topic of the post, would it be useful for me to profile the different reset options so that you get an idea of whether individual resets are prohibitive in the future? Also, is there any use for extra jumpy functions? I saw that you're trying to add them to JKTerry's Farama repo, but I wasn't sure where the best place to contribute would be.

DavidSlayback avatar Mar 17 '22 20:03 DavidSlayback

Haha of course! Happy to help.

Yes, we'd love to have some numbers on reset efficiency!

Let me check with Erik about the fate of jumpy--I'll get back to you!

cdfreeman-google avatar Mar 17 '22 20:03 cdfreeman-google

Update: Feel free to open PRs adding jumpy functions here if you're using them in Brax!

cdfreeman-google avatar Mar 18 '22 17:03 cdfreeman-google

Update on the reset numbers:

https://colab.research.google.com/gist/DavidSlayback/bf5038ec024bb6e47568af2e2ba99c16/autoreset.ipynb#scrollTo=gazgx0KXWJfw

So I implemented a couple of basic strategies using your built-in "fetch" environment just to keep the notebook simple:

  1. Original AutoReset (same "first state" for all environments that are done)
  2. "Naive" AutoReset (calls reset every timestep, replaces where done)
  3. "On Terminal" AutoReset (calls reset only on timesteps where at least one environment is done)
  4. "Cached" AutoReset (refresh "first state" every N steps, behaves like original otherwise)

I feel like I may not be timing these properly, though? I'm not seeing much difference in times beyond the time it takes to JIT.
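One thing I still need to rule out is JAX's asynchronous dispatch; a rough timing pattern like this (jit_reset / jit_step / action are placeholders for the jitted functions and inputs in the notebook) forces the work to finish before reading the clock:

import time

state = jit_reset(rng)
state.obs.block_until_ready()        # make sure the reset has actually run
start = time.time()
for _ in range(1000):
    state = jit_step(state, action)
state.obs.block_until_ready()        # force all queued steps to complete
print('seconds for 1000 steps:', time.time() - start)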

DavidSlayback avatar Mar 21 '22 22:03 DavidSlayback

I've started using Brax, and I am really enjoying it!

I just wanted to note that I've been using Brax on tasks that involve significant domain randomization and/or curriculum development, and the auto-reset behavior tripped me up too.

btnorman avatar Mar 24 '22 21:03 btnorman

So when using the AutoResetWrapper, there is no randomization except at the start? Doesn't the training overfit on those initial random seeds?

erwincoumans avatar May 21 '22 21:05 erwincoumans

Are there any new ideas? Following the design of gym wrappers (e.g. AtariPreprocessing.noop_max), maybe we can manually apply several steps of random actions after the fake reset in AutoResetWrapper?
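Something like this rough sketch, perhaps (the names are illustrative, not an existing Brax API):

# noop_max-style idea: after a (fake) reset, take a few uniformly random
# actions so episodes don't all restart from exactly the same cached state.
import jax

def randomized_start(env, rng, n_random_steps=10):
    rng, reset_rng = jax.random.split(rng)
    state = env.reset(reset_rng)
    for _ in range(n_random_steps):
        rng, act_rng = jax.random.split(rng)
        action = jax.random.uniform(act_rng, (env.action_size,),
                                    minval=-1.0, maxval=1.0)
        state = env.step(state, action)
    return state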

ZaberKo avatar Mar 13 '24 04:03 ZaberKo

Brax PPO now supports a param num_resets_per_eval if you want to randomize your init states multiple times during training:

https://github.com/google/brax/blob/main/brax/training/agents/ppo/train.py#L90

We generally don't use this as there's no overfitting when num_envs is large - but perhaps you'll find that helpful.
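A rough usage sketch (the other hyperparameter values here are placeholders, not recommendations):

from brax import envs
from brax.training.agents.ppo import train as ppo

env = envs.get_environment('ant')
make_inference_fn, params, metrics = ppo.train(
    environment=env,
    num_timesteps=50_000_000,
    episode_length=1000,
    num_envs=2048,
    num_resets_per_eval=10,  # re-draw init states 10 times over training
)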

erikfrey avatar Mar 13 '24 05:03 erikfrey

As pointed out by @erwincoumans, it's essential to have some init randomness across envs during the training stage. It seems that num_resets_per_eval only controls the randomness at the evaluation stage.

ZaberKo avatar Mar 18 '24 01:03 ZaberKo