dreamerv3
Extremely large gradient and vanishing images when using jax.precision=float32
First of all, thank you very much for this impressive work!
When training with images using the configuration jax.precision=float32, I consistently get an extremely large gradient (model_opt_grad_norm at ~6e+8). Presumably as a result, the openl image predictions become completely white. When training the dmc_walker_walk task with the dmc_vision configuration via the train.py script, image_loss_mean sits at about 7e+7. With other environments, image_loss_mean starts at about 2000-5000, but model_opt_grad_norm stays at ~6e+8.
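For context, this is the kind of invocation I mean (a sketch; the exact flag style follows the repo's `--key value` config overrides, and the logdir path is just an example):

```shell
python dreamerv3/train.py \
  --configs dmc_vision \
  --task dmc_walker_walk \
  --jax.precision float32 \
  --logdir ~/logdir/dmc_walker_walk
```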
I'm using float32 because with float16 I sporadically get NaNs during image training.
I already tried changing the learning rate and gradient-clipping values, as well as the image loss scale, but without success.
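To illustrate why tuning the clip value alone may not help: clipping by global norm only rescales the update, so a ~6e+8 gradient norm still points at an upstream numerical problem. A minimal plain-Python sketch of global-norm clipping (the threshold here is illustrative, not the repo's default):

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale a flat list of gradient values so their global L2 norm
    does not exceed max_norm (no-op if already within bounds)."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return grads
    scale = max_norm / global_norm
    return [g * scale for g in grads]

# A huge gradient like the ~6e+8 norm described above is rescaled so
# its global norm equals the threshold, but its direction is unchanged:
clipped = clip_by_global_norm([6e8, 3e8], max_norm=100.0)
print(math.sqrt(sum(g * g for g in clipped)))  # prints a value close to 100.0
```

So even with aggressive clipping, the loss that produces such gradients can still diverge; the clip only bounds the step size.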
Am I missing any other configuration changes that are required when using float32?
Thank you for your help! Best regards