Gradient synchronization in data-parallel trainers
Hey, great job with nanodl!
I was just looking through the code and noticed that in LaMDA's Trainer the gradients are not being averaged across devices here:
https://github.com/HMUNACHI/nanodl/blob/18c7f8e3da3c0bbfe2df3638a5e87857ec84868d/nanodl/__src/models/lamda.py#L564-L565
I'm not sure if this happens elsewhere, but to keep the weights in sync you usually apply a jax.lax.pmean over the gradients before passing them to apply_gradients, e.g.
```python
grads = jax.lax.pmean(grads, axis_name='devices')
```
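
For context, here is a minimal sketch of a data-parallel train step that does this (not the actual nanodl code; the loss function, batch structure, and axis name are illustrative assumptions):

```python
# Minimal sketch of a pmapped train step that averages gradients across
# devices before applying them, so every replica stays in sync.
# Names like `batch['inputs']` and the MSE loss are placeholders.
import jax
import jax.numpy as jnp
from flax.training import train_state


def train_step(state: train_state.TrainState, batch):
    """Single data-parallel training step."""

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['inputs'])
        return jnp.mean((logits - batch['targets']) ** 2)  # placeholder loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Average gradients (and the loss, for logging) across all devices so
    # each replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='devices')
    loss = jax.lax.pmean(loss, axis_name='devices')

    state = state.apply_gradients(grads=grads)
    return state, loss


# The axis_name must match the one used inside the step.
p_train_step = jax.pmap(train_step, axis_name='devices')
```

Without the pmean, each device applies its own local gradients and the replicated parameters can drift apart over time.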