Check that your jaxlib version matches your CUDA version. I installed it with this command: `pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`. If you get a CUDA OOM error, then use this code in front of...
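To see which jax and jaxlib builds are actually installed and whether JAX can see the GPU, you could run something like the following Python sketch. The helper names `installed_version` and `jax_env_report` are hypothetical and not part of JAX. It reads package versions through the standard library's `importlib.metadata` and asks JAX for its default backend and devices. If `backend` comes back as `cpu` on a CUDA machine, the installed jaxlib build probably does not match your CUDA setup.

```python
import importlib.metadata


def installed_version(dist: str):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_env_report() -> dict:
    """Collect jax/jaxlib versions plus the backend and devices JAX sees (hypothetical helper)."""
    report = {
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
        "backend": None,
        "devices": [],
    }
    try:
        import jax  # imported lazily so version lookup still works without jax
    except ImportError:
        return report
    # "gpu" when a CUDA-enabled jaxlib found a device, "cpu" otherwise
    report["backend"] = jax.default_backend()
    report["devices"] = [str(d) for d in jax.devices()]
    return report


if __name__ == "__main__":
    for key, value in jax_env_report().items():
        print(f"{key}: {value}")
```

Compare the printed jaxlib version against the CUDA version on your machine before digging into other errors.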
Check your grad_norm value; if it's NaN or Inf, turn it off. Change your LR to 1e-4, 2e-4, or 1e-5. This worked for me.
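Here is a minimal sketch of the grad_norm check. It assumes a plain JAX training loop with a hand-written SGD update. The names `global_grad_norm` and `sgd_step` are hypothetical and do not come from any library. The sketch computes the global L2 norm over every gradient array and skips the update when that norm is NaN or Inf. The learning rate is a parameter, so you can try the values suggested above (1e-4, 2e-4, 1e-5).

```python
import jax
import jax.numpy as jnp


def global_grad_norm(grads):
    """L2 norm over every leaf of a gradient pytree (hypothetical helper)."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))


def sgd_step(params, grads, lr=1e-4):
    """Plain SGD update that leaves params unchanged when grad_norm is NaN or Inf."""
    grad_norm = global_grad_norm(grads)
    finite = jnp.isfinite(grad_norm)
    # jnp.where keeps the old parameter whenever the norm is non-finite
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(finite, p - lr * g, p), params, grads
    )
    return new_params, grad_norm


if __name__ == "__main__":
    params = {"w": jnp.array([1.0, 2.0])}
    grads = {"w": jnp.array([3.0, 4.0])}
    for lr in (1e-4, 2e-4, 1e-5):
        new_params, norm = sgd_step(params, grads, lr=lr)
        print(lr, float(norm), new_params["w"])
```

Logging `grad_norm` every step this way shows you when it first becomes NaN or Inf, which is usually the point to lower the learning rate.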