
Extremely large gradient and vanishing images when using jax.precision=float32

Open ManfredStoiber opened this issue 1 year ago • 3 comments

First of all, thank you very much for this impressive work!

When using the configuration jax.precision=float32 for training on images, I consistently get an extremely large gradient (model_opt_grad_norm at ~6e+8). I assume that this is why the openl image predictions become completely white. When training the dmc_walker_walk task with the dmc_vision config via the train.py script, the image_loss_mean is at about 7e+7. With other environments, the image_loss_mean starts at about 2000-5000, but the model_opt_grad_norm stays at ~6e+8.

I'm using float32 because I sporadically get NaNs during training on images when using float16.
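As a workaround for the sporadic NaNs, one common pattern (a minimal sketch, not code from the dreamerv3 repo; `all_finite` and the pytrees here are illustrative) is to check every gradient leaf for non-finite values and skip the optimizer step when any appear:

```python
import jax
import jax.numpy as jnp

def all_finite(tree):
    """True iff every array leaf in the pytree is free of NaN/Inf."""
    leaves = jax.tree_util.tree_leaves(tree)
    return bool(jnp.all(jnp.array([jnp.all(jnp.isfinite(x)) for x in leaves])))

# A gradient that blows up: d/dw sqrt(w) at w = 0 is infinite.
bad_grad = {"w": jax.grad(jnp.sqrt)(jnp.float16(0.0))}
good_grad = {"w": jnp.float16(1.0)}
```

For locating *where* the first NaN appears, JAX also offers `jax.config.update("jax_debug_nans", True)`, which re-runs the offending op un-jitted and raises at the source.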

I already tried changing the learning rate and gradient-clipping values, as well as the image loss scale, but without success.
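For reference, the clipping value being tuned above controls global-norm clipping, which can be sketched in plain JAX as follows (function names here are illustrative, not the dreamerv3 API; optax's `clip_by_global_norm` does the same thing):

```python
import jax
import jax.numpy as jnp

def global_norm(grads):
    """L2 norm over all leaves, the quantity logged as *_opt_grad_norm."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients so their global norm is at most max_norm."""
    scale = jnp.minimum(1.0, max_norm / global_norm(grads))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

grads = {"w": jnp.array([3.0, 4.0])}      # global norm 5.0
clipped = clip_by_global_norm(grads, 1.0)  # rescaled to norm 1.0
```

Note that clipping only bounds the applied update; a grad norm stuck at ~6e+8 still suggests the loss itself is diverging upstream of the clip.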

Am I perhaps missing other configuration values that need to change when using float32?

Thank you for your help! Best regards

[screenshot attached]

ManfredStoiber avatar Sep 21 '23 08:09 ManfredStoiber

May I ask if you have successfully trained the dreamerv3 agent? I'm curious what the final loss of each component looks like, such as image, reward, or cont. When the reconstructed images are very similar, I find that the reward prediction is not as good as it should be. I'm not sure if that also affects subsequent policy training. Thanks for sharing your thoughts.

return-sleep avatar Nov 24 '23 04:11 return-sleep

Unfortunately not, at least not when training on images in the walker environment.

ManfredStoiber avatar Dec 15 '23 01:12 ManfredStoiber

Walker always worked for me from images, regardless of precision. I've just updated the paper and code, which has a better optimizer now. Curious if this is still an issue on your end.

danijar avatar Apr 19 '24 22:04 danijar