Example differentiability-leveraging algorithm parameters for legged systems?
Has anyone had any success with differentiability-leveraging learning algorithms (e.g. the built-in APG) for legged systems like HalfCheetah or Ant? I see that there's an example piece of code for the contact-less reacher, but that's it. Does anyone have example hyperparameters for those two systems? When I try running APG on them, I often get NaNs/infinities in the weights. Any help would be appreciated!
I've had some success training Ant policies with APG, though it is quite slow (an hour on a TPU) and requires pretty careful hyperparameter tuning. This code block should work if you plop it into our training notebook (be sure to import apg too!):
import functools
from datetime import datetime

import matplotlib.pyplot as plt
from IPython.display import clear_output

from brax.training import apg  # brax v1-style import; adjust the path for your brax version

# env_fn is assumed to be defined earlier in the notebook,
# e.g. env_fn = envs.create_fn(env_name='ant')
train_fn = functools.partial(
    apg.train,
    episode_length=400,
    action_repeat=1,
    num_envs=1024,
    num_eval_envs=128,
    learning_rate=3e-3,
    truncation_length=10,
    log_frequency=1000,
)

max_y = 1000
min_y = -100
xdata = []
ydata = []
times = [datetime.now()]

def progress(num_steps, metrics):
  # Record wall-clock time and eval reward, then redraw the learning curve.
  times.append(datetime.now())
  xdata.append(num_steps)
  ydata.append(metrics['eval/episode_reward'])
  clear_output(wait=True)
  plt.xlim([0, 1000])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.plot(xdata, ydata)
  plt.show()

inference_fn, params, _ = train_fn(environment_fn=env_fn, progress_fn=progress)

print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')
We also recently detailed some of the struggles with gradient-based optimization in this preprint: https://arxiv.org/abs/2111.05803v1
Let me know if you have any further questions! I'd be happy to look at any repros you have for producing NaNs---those are actually usually pretty easy to fix.
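If it helps with debugging, a couple of standard JAX tricks (not Brax-specific, just what I usually reach for) will localize NaNs quickly; the report_nans helper below is only an illustrative sketch:
import jax
import jax.numpy as jnp

# Make JAX raise an error as soon as any op produces a NaN (slower, but pinpoints the op).
jax.config.update('jax_debug_nans', True)

# After (or during) training, check which parameter arrays went bad.
def report_nans(params):
  bad = jax.tree_util.tree_map(
      lambda x: bool(jnp.isnan(x).any() | jnp.isinf(x).any()), params)
  print(bad)  # pytree of booleans, True wherever a leaf contains NaN/inf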
@cdfreeman-google Thanks for the response -- this seems to have worked for training Ant for me as well!
Follow-up: from what I understand, APG runs a bunch of environments in parallel, each with a long episode (in this case, 400 steps). The algorithm simulates forward truncation_length=10 steps in all environments in parallel, sampling from the distribution defined by the policy network 10 times sequentially, then computes the gradient of the negative average final reward over all environments and uses that gradient to update the policy parameters. This is then repeated, simulating forward another 10 steps from where the previous ten left off.
Is that all correct? Is there a reason the loss is just the negative average final reward (as opposed to, say, a discounted short-horizon reward)?
Really appreciate all the help!
Yep, that's right. (It collects data in 400-step rollouts, but those are effectively a batch of disconnected 10-step chunks as far as the gradient is concerned.)
Re discounted reward: no good reason -- it was simpler to try, and it worked :D. Truncating effectively introduces a short-horizon bias anyway, though I agree that explicitly calculating the reward with discounting is a bit more principled.
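To make that concrete, here's a rough sketch of the truncated-unroll idea. It's a toy stand-in (a hand-rolled point-mass "environment" and a tiny linear policy), not the actual apg.train implementation, but the gradient structure is the same: unroll 10 steps across a batch of environments, take the gradient of the negative mean final reward, update the policy, and continue from where the chunk left off.
import jax
import jax.numpy as jnp

def env_step(state, action):
  # Toy differentiable dynamics: a 2-D point mass nudged by the action.
  pos, vel = state
  vel = 0.9 * vel + 0.1 * action
  pos = pos + 0.1 * vel
  reward = -jnp.sum(pos ** 2)  # reward for staying near the origin
  return (pos, vel), reward

def policy(params, state):
  pos, vel = state
  return jnp.tanh(params['w'] @ jnp.concatenate([pos, vel]) + params['b'])

def chunk_loss(params, states, truncation_length=10):
  # Unroll one truncated chunk per environment; loss = -mean final reward.
  def unroll_one(state):
    def body(carry, _):
      action = policy(params, carry)
      carry, reward = env_step(carry, action)
      return carry, reward
    final_state, rewards = jax.lax.scan(body, state, None, length=truncation_length)
    return final_state, rewards[-1]
  final_states, final_rewards = jax.vmap(unroll_one)(states)
  return -jnp.mean(final_rewards), final_states

@jax.jit
def update(params, states, lr=3e-3):
  (loss, final_states), grads = jax.value_and_grad(chunk_loss, has_aux=True)(params, states)
  params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  # The next chunk starts from these states but doesn't backprop into this one
  # (the stop_gradient is only illustrative here, since each update call is a fresh trace).
  states = jax.tree_util.tree_map(jax.lax.stop_gradient, final_states)
  return params, states, loss

num_envs, episode_length, truncation_length = 1024, 400, 10
key = jax.random.PRNGKey(0)
params = {'w': 0.01 * jax.random.normal(key, (2, 4)), 'b': jnp.zeros(2)}
states = (jax.random.normal(key, (num_envs, 2)), jnp.zeros((num_envs, 2)))

# A 400-step "episode" is 40 disconnected 10-step chunks from the gradient's point of view.
# (Real training also handles resets, evaluation, logging, etc.)
for _ in range(episode_length // truncation_length):
  params, states, loss = update(params, states)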
Additional question I just had: is there a reason the reward plot is scaled so differently from PPO? Using PPO I seem to be able to get rewards of 5000+, while this gets sub-1000. Is it just a combination of 1) the progress function plotting total reward per episode, with episodes being longer for PPO (I'm using 1000 steps vs. 400 for APG), and 2) APG just not reaching rewards as high as PPO's?
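(For reference, a crude way to control for 1) is to compare per-step reward, using the ballpark numbers above:)
# Crude per-step comparison, using the ballpark figures mentioned above.
ppo_reward_per_step = 5000 / 1000  # ~5.0 with 1000-step episodes
apg_reward_per_step = 1000 / 400   # ~2.5 with 400-step episodes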