
Extremely large gradient and vanishing images when using jax.precision=float32

Open · ManfredStoiber opened this issue on Sep 21, 2023 · 3 comments

First of all, thank you very much for this impressive work!

When using the configuration jax.precision=float32 for training with images, I always get an extremely large gradient (model_opt_grad_norm at around 6e+8). I assume that because of this, the openl image predictions become completely white. When training the dmc_walker_walk task with the dmc_vision configuration via the train.py script, the image_loss_mean is at about 7e+7. With other environments, the image_loss_mean starts at about 2000-5000, but the model_opt_grad_norm stays at around 6e+8.
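To be precise about what I mean by the gradient norm: it is the global L2 norm over the world-model gradient pytree. A minimal sketch with plain JAX/optax and placeholder names (loss_fn, params, batch), not the actual dreamerv3 code:

```python
import jax
import optax

def global_grad_norm(loss_fn, params, batch):
    # loss_fn, params, and batch are placeholders for the actual
    # dreamerv3 world-model loss, parameters, and training batch.
    grads = jax.grad(loss_fn)(params, batch)
    # Global L2 norm over the whole gradient pytree; this is the
    # quantity that sits at ~6e+8 in my runs.
    return optax.global_norm(grads)
```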

I'm using float32 because I sporadically get NaNs during training with images when using float16.
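(For anyone trying to reproduce the float16 NaNs: the simplest way I know to catch the first NaN is JAX's built-in debug flag. This is generic JAX, not part of the dreamerv3 codebase, and it slows training down considerably, so it is only useful for debugging.)

```python
import jax

# Make JAX raise an error as soon as any op produces a NaN, so the
# offending computation can be located instead of silently propagating.
jax.config.update("jax_debug_nans", True)
```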

I have already tried changing the learning rate and gradient clipping values, as well as the image loss scale, but without success.
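By clipping values I mean the usual global-norm clipping. A rough sketch of the kind of change I experimented with, written with plain optax rather than the dreamerv3 optimizer wrapper (the threshold and learning rate here are illustrative values only):

```python
import optax

# Example optimizer chain: clip the global gradient norm before Adam.
optimizer = optax.chain(
    optax.clip_by_global_norm(100.0),
    optax.adam(learning_rate=1e-4),
)
```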

Am I perhaps missing any other configuration I need to change when using float32?

Thank you for your help! Best regards

[attached screenshot]

ManfredStoiber · Sep 21, 2023