
Gradient synchronization in data-parallel trainers

Open · cgarciae opened this issue 1 year ago • 1 comment

Hey, great job with nanodl!

I was just looking through the code and noticed that in Lambda's Trainer the gradients are not averaged across devices here:

https://github.com/HMUNACHI/nanodl/blob/18c7f8e3da3c0bbfe2df3638a5e87857ec84868d/nanodl/__src/models/lamda.py#L564-L565

I'm not sure if this happens elsewhere, but to keep the weights in sync you usually apply jax.lax.pmean over the gradients before passing them to apply_gradients, e.g.

grads = jax.lax.pmean(grads, axis_name='devices')
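
For context, here is a minimal sketch of what a data-parallel train step with that fix looks like. This is not nanodl's actual Trainer code: the toy squared-error loss, the optax.sgd optimizer, and the flax TrainState setup are placeholders chosen for illustration; the relevant part is the jax.lax.pmean call inside the jax.pmap-ed step.

```python
# Illustrative sketch only, not nanodl's Trainer: a pmap'd train step that
# averages gradients with jax.lax.pmean before calling apply_gradients.
import functools

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


def loss_fn(params, batch):
    # Toy squared-error loss; stands in for the model's real loss.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
    # Average gradients across all devices so every replica applies the
    # same update and the replicated weights stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return state.apply_gradients(grads=grads), loss


# Usage: replicate the state across devices and feed per-device batches.
n = jax.local_device_count()
state = train_state.TrainState.create(
    apply_fn=None, params={"w": jnp.ones((4, 1))}, tx=optax.sgd(1e-2)
)
state = jax.device_put_replicated(state, jax.local_devices())
batch = {
    "x": jnp.ones((n, 8, 4)),   # leading axis = device axis
    "y": jnp.zeros((n, 8, 1)),
}
state, loss = train_step(state, batch)
```

Without the pmean, each replica applies its own per-device gradients and the parameters drift apart over time.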

cgarciae · Feb 28 '24