Marcos Treviso
Having a way to turn on accumulate-gradients/update-freq would be amazing for reproducibility on GPUs. What is the best approach for doing this in JAX?
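For readers landing on this question, here is a minimal sketch of one common way to accumulate gradients in plain JAX. It is not the code discussed in this thread: `loss_fn`, `accumulated_grads`, and the linear model are hypothetical names chosen for illustration, and the sketch assumes the batch size is divisible by `accum_steps` and that the loss is a mean over examples.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical mean-squared-error loss for a linear model (illustration only).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def accumulated_grads(params, x, y, accum_steps):
    # Assumption: x.shape[0] is divisible by accum_steps, so every
    # micro-batch has the same size. accum_steps must be a Python int.
    micro = x.shape[0] // accum_steps
    xs = x.reshape(accum_steps, micro, *x.shape[1:])
    ys = y.reshape(accum_steps, micro, *y.shape[1:])

    def body(i, acc):
        # Gradient of the loss on the i-th micro-batch only.
        g = jax.grad(loss_fn)(params, xs[i], ys[i])
        # Add it to the running sum, leaf by leaf.
        return jax.tree_util.tree_map(jnp.add, acc, g)

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total = jax.lax.fori_loop(0, accum_steps, body, zeros)
    # Averaging the per-micro-batch gradients matches the full-batch
    # gradient when the loss is a mean and micro-batches are equally sized.
    return jax.tree_util.tree_map(lambda g: g / accum_steps, total)
```

Only one micro-batch is differentiated at a time, so activation memory scales with the micro-batch rather than the full batch. Setting `accum_steps` equal to the batch size gives micro-batches of a single example. If the function is wrapped in `jax.jit`, `accum_steps` would need to be marked static (for example via `static_argnames`), since it determines array shapes.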
Hi, Mostafa! Thank you for the quick response. I was able to adapt your code for text classification and it seems like the gradient accumulation is working fine. Since `jax.lax.fori_loop`...
Hi! I got the following results on the test set by using a single GPU (24GB) and setting `accum_steps=batch_size`. All hyperparameters were kept intact, and the only thing that changed...