Brax Halfcheetah Exploding Gradients
Hi,
I'm working on a project that uses differentiable dynamics. However, on the halfcheetah task I'm running into exploding gradients. I have created a repo to reproduce this problem.
The problem can be reproduced with the official analytical policy gradient implementation, apg.py, and the official reward function. The only change I made is to print the gradient norm before clipping.
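For context on why the pre-clipping norm matters, here is a minimal pure-Python sketch of gradient clipping by global norm. It is not Brax's or Optax's code: nested dicts/lists stand in for a JAX pytree, the helper names (`global_norm`, `clip_by_global_norm`) are hypothetical, and it only assumes clipping rescales every entry by `max_norm / max(norm, max_norm)`.

```python
import math


def _leaves(tree):
    """Yield every float in a nested dict/list/tuple (a stand-in for a JAX pytree)."""
    if isinstance(tree, dict):
        for value in tree.values():
            yield from _leaves(value)
    elif isinstance(tree, (list, tuple)):
        for value in tree:
            yield from _leaves(value)
    else:
        yield float(tree)


def _map(fn, tree):
    """Apply fn to every leaf, preserving the nesting structure."""
    if isinstance(tree, dict):
        return {k: _map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_map(fn, v) for v in tree)
    return fn(tree)


def global_norm(grads):
    """L2 norm over all gradient entries; g * g overflows to inf rather than raising."""
    return math.sqrt(sum(g * g for g in _leaves(grads)))


def clip_by_global_norm(grads, max_norm):
    """Return (clipped_grads, pre_clip_norm), rescaled so the norm is at most max_norm."""
    norm = global_norm(grads)
    scale = max_norm / max(norm, max_norm)  # 1.0 when norm <= max_norm
    return _map(lambda g: g * scale, grads), norm


if __name__ == "__main__":
    print(clip_by_global_norm({"w": [3.0, 4.0], "b": [0.0]}, max_norm=1.0))
    # An infinite norm gives scale == 0.0: finite entries become 0.0, inf entries become nan.
    print(clip_by_global_norm({"w": [float("inf"), 1.0]}, max_norm=1.0))
```

With this formula, an `inf` norm drives the scale to zero, so clipping silently zeroes finite entries and turns infinite ones into `nan` instead of repairing the update. That is one reason to inspect the norm before clipping rather than after.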
To reproduce
python apg.py
Environment
python 3.8
brax 0.0.12
jax 0.3.5
jaxlib 0.3.5+cuda11.cudnn82
nvidia-smi
NVIDIA-SMI 510.54 Driver Version: 510.54 CUDA Version: 11.6
Gradient norm for halfcheetah (before clipping)
grad_raw [inf]
grad_raw [inf]
grad_raw [inf]
grad_raw [3.7764926e+18]
grad_raw [inf]
grad_raw [inf]
Gradient norm for ant, for comparison
grad_raw [1.7340995]
grad_raw [2.4045153]
grad_raw [2.8107145]
grad_raw [1.8724597]
grad_raw [3.0794723]
grad_raw [2.4992204]
Could you plot the gradient norm as a function of trajectory length for, say, a random policy? We've seen this kind of thing before, and it usually comes down to chaotic or unstable dynamics in the system. In that case you may need to truncate the number of steps you backpropagate through to get stable gradients.
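To make the suggested experiment concrete, here is a toy sketch in plain Python (no Brax or JAX). It assumes the chaotic logistic map x <- r*x*(1-x) with r = 3.9 as a stand-in for the simulator, which is an illustrative choice, not halfcheetah's dynamics; `rollout_grad` and its parameters are hypothetical. It measures d(final state)/dr for growing horizons, and optionally resets the sensitivity every `truncation` steps, which mimics applying `jax.lax.stop_gradient` to the state at those steps.

```python
def rollout_grad(r, x0, horizon, truncation=None):
    """Toy stand-in for an APG rollout: returns (x_T, dx_T/dr) for the logistic map.

    The dynamics x_{t+1} = r * x_t * (1 - x_t) are chaotic for r = 3.9.
    The sensitivity s_t = dx_t/dr is carried forward with the chain rule
    (forward mode, which gives the same derivative as backprop for one scalar r).
    If `truncation` is set, s is reset to 0 every `truncation` steps, which
    mimics applying stop_gradient to the state at those steps.
    """
    x, s = x0, 0.0
    for t in range(horizon):
        if truncation is not None and t > 0 and t % truncation == 0:
            s = 0.0  # drop the gradient path through steps before t
        # d/dr [r x (1 - x)] = x (1 - x) + r (1 - 2x) * dx/dr, using the old x
        s = x * (1.0 - x) + r * (1.0 - 2.0 * x) * s
        x = r * x * (1.0 - x)
    return x, s


if __name__ == "__main__":
    # Gradient magnitude vs. trajectory length, with and without truncation.
    for horizon in (5, 10, 20, 40, 80):
        _, g_full = rollout_grad(3.9, 0.3, horizon)
        _, g_trunc = rollout_grad(3.9, 0.3, horizon, truncation=10)
        print(f"T={horizon:3d}  |grad|={abs(g_full):.3e}  |grad, truncated at 10|={abs(g_trunc):.3e}")
```

On this toy map the untruncated magnitude should grow roughly exponentially with T, while the truncated one stays below a bound that depends only on the window length, never the horizon. The trade-off is bias: a truncated gradient ignores how a parameter change affects states more than `truncation` steps later.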