
Best Practice for Passing/Storing Training Progress for Curriculum Learning in Brax

Open hukz18 opened this issue 1 year ago • 2 comments

Hi Brax team,

I’m working on a reinforcement learning project using Brax to train a PPO agent and I’m trying to implement curriculum learning by adjusting the environment's difficulty dynamically based on the training progress (e.g., current_steps or number of episodes). My goal is to pass this information to the environment during training so that I can change certain parameters (like gravity, object mass, etc.) as the agent progresses.

I’ve thought of a solution where I modify the training code to pass the current training progress into the environment’s reset function. Here’s a simplified example of what I have in mind:

reset_fn = jax.jit(jax.vmap(lambda x: env.reset(x, current_step)))

However, this requires modifying reset_fn in the training loop (brax/training/agents/ppo/train.py) to pass the training progress manually. It also requires changing the reset function of every wrapper so that current_step can be passed through.
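For reference, here is a minimal sketch of the "pass the progress into reset" idea in plain JAX. It does not use Brax at all: the `reset` function, the gravity schedule, and the 1,000,000-step ramp are hypothetical stand-ins for illustration. The point it shows is that the progress value can be a real argument broadcast across the batch with `in_axes=(0, None)`, rather than a Python value captured by a lambda, which a jitted function would treat as a constant fixed at trace time.

```python
import jax
import jax.numpy as jnp


def reset(rng, current_step):
    """Hypothetical toy reset: difficulty depends on training progress."""
    # Assumed schedule: gravity ramps linearly from -1.0 to -9.81 over 1e6 steps.
    progress = jnp.clip(current_step / 1_000_000, 0.0, 1.0)
    gravity = -1.0 + progress * (-9.81 - (-1.0))
    obs = jax.random.normal(rng, (3,))
    return {"obs": obs, "gravity": gravity}


# Batch over the RNG keys (axis 0) and broadcast the scalar step (None).
batched_reset = jax.jit(jax.vmap(reset, in_axes=(0, None)))

keys = jax.random.split(jax.random.PRNGKey(0), 4)
early = batched_reset(keys, jnp.asarray(0.0))
late = batched_reset(keys, jnp.asarray(1_000_000.0))
```

Because `current_step` is a traced argument here, calling `batched_reset` with a new step value does not trigger recompilation and the new value takes effect immediately.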

I've also tried simply storing a scalar value in the environment, e.g. self.num_episodes = 0, and incrementing it with self.num_episodes = self.num_episodes + 1 in the reset function. Unfortunately, this value never actually changes despite the reset calls. So I wonder if there's a way to achieve this without changing the training code of Brax itself.
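One likely explanation for the counter behaviour, sketched below in plain JAX: when reset is wrapped in jax.jit, Python-level side effects such as attribute assignment run only while the function is being traced, not on every call. The `CountingEnv` class is a hypothetical toy, not a Brax environment, and this is an assumption about the cause rather than a confirmed diagnosis of the Brax training loop.

```python
import jax


class CountingEnv:
    """Hypothetical toy environment with a Python-side episode counter."""

    def __init__(self):
        self.num_episodes = 0

    def reset(self, rng):
        # Python side effect: executed during tracing only, not per call.
        self.num_episodes = self.num_episodes + 1
        return jax.random.normal(rng, (2,))


env = CountingEnv()
reset_fn = jax.jit(env.reset)

key = jax.random.PRNGKey(0)
for _ in range(3):
    reset_fn(key)  # same input shape and dtype, so the function is traced once

calls_seen = env.num_episodes  # reflects the number of traces, not the number of calls
```

State that has to change on every reset or step therefore needs to be carried in the arrays the function returns, not on `self`.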

Question: Is there a better practice for passing or storing training progress information (like current_steps) in Brax for curriculum learning? Specifically:

- Is modifying the training code the best approach, or can this be handled more elegantly by the environment itself?
- Can we store or retrieve the training progress (e.g., current_steps) in the environment without needing to modify the reset function directly?

I’d appreciate any advice or best practices you can suggest for implementing this kind of feature in Brax.

Thanks for your help!

hukz18 avatar Oct 19 '24 05:10 hukz18

Hi,

Any chance you figured it out? Thanks!

juhorsch avatar Nov 26 '24 10:11 juhorsch

Hi, I ended up modifying the step function of AutoResetWrapper and adding an episode_num item to keep track of the number of episodes for each environment. Inside the step function of AutoResetWrapper, the change looks like this:

    state.info['episode_num'] = jp.where(
        state.done, state.info['episode_num'] + 1, state.info['episode_num']
    )

and add the 'episode_num' key to the state's info dict elsewhere. You can also track the total number of environment steps similarly. However, this workaround can't keep track of the progress inside the environment class, so I'd like to keep the issue open for now.
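To illustrate the same idea outside of Brax, here is a minimal plain-JAX sketch of a per-environment episode counter kept in an info dict and then mapped to a difficulty parameter. The helper names `update_episode_num` and `gravity_for`, the gravity range, and the 100-episode ramp are all hypothetical. In Brax the update would live inside the wrapper's step function as described above.

```python
import jax
import jax.numpy as jnp


def update_episode_num(info, done):
    """Increment the counter only for environments whose episode just ended."""
    new_info = dict(info)
    new_info["episode_num"] = jnp.where(
        done > 0, info["episode_num"] + 1, info["episode_num"]
    )
    return new_info


def gravity_for(episode_num, start=-1.0, end=-9.81, ramp_episodes=100):
    """Assumed linear curriculum: interpolate gravity over ramp_episodes."""
    frac = jnp.clip(episode_num / ramp_episodes, 0.0, 1.0)
    return start + frac * (end - start)


# Four parallel environments; the first and third have just finished an episode.
info = {"episode_num": jnp.zeros(4, dtype=jnp.int32)}
done = jnp.array([1.0, 0.0, 1.0, 0.0])

info = jax.jit(update_episode_num)(info, done)
gravity = gravity_for(info["episode_num"])
```

Since the counter travels with the state rather than living on the environment object, it survives jit and vmap, at the cost of each environment only seeing its own count rather than a global training step.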

Hope it helps!

hukz18 avatar Nov 27 '24 00:11 hukz18

This might be helpful https://github.com/google-deepmind/mujoco_playground/commit/417b030d53ac735da5fb30ad423c8e7bc7bc786b

btaba avatar Sep 05 '25 05:09 btaba