
Why scale the loss during DDP training

Open Yu-Shi opened this issue 2 years ago • 5 comments

Dear Luyu,

Thank you for the great package, which simplifies the cumbersome DR training & evaluation process! I tried using it in DDP mode and read the code. I have a question regarding the loss:

In the DenseModel class, if the model is running in DDP, you scale the loss by world_size (the number of GPUs used). In DenseTrainer you unscale it after the backward pass. Why do you introduce this scaling & unscaling?
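For context, here is a minimal sketch of the pattern I mean. This is not the actual Tevatron source; the function names (`compute_loss`, `training_step`) and the model interface are my own.

```python
# Hypothetical sketch of the scale-before-backward / unscale-for-logging
# pattern described above -- not the actual Tevatron code.
import torch
import torch.distributed as dist


def compute_loss(scores: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Contrastive (cross-entropy) loss, scaled when running under DDP."""
    loss = torch.nn.functional.cross_entropy(scores, target)
    if dist.is_available() and dist.is_initialized():
        # Scale by world_size before backward (undone later for logging).
        loss = loss * dist.get_world_size()
    return loss


def training_step(model, batch, world_size: int) -> torch.Tensor:
    """Backward on the scaled loss; return the unscaled value for logging."""
    scores, target = model(**batch)      # model produces similarity scores
    loss = compute_loss(scores, target)  # already scaled under DDP
    loss.backward()                      # DDP's all_reduce syncs gradients
    return loss.detach() / world_size    # unscale so logged loss is comparable
```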

Best regards, Shi

Yu-Shi avatar Mar 09 '22 13:03 Yu-Shi

To my understanding, the backward pass, gradient clipping & weight updates seem to be based on the scaled loss, and the unscaled one is only used for logging?

Yu-Shi avatar Mar 09 '22 13:03 Yu-Shi

Hi Shi,

Thanks for the interest in Tevatron!

The loss scaling is a PyTorch-specific thing. In PyTorch, parallel collectives like all_gather are not differentiable, so when we gather all query/passage representations, each process effectively computes only a portion of the gradient. We rely on the final all_reduce defined in the DDP module to sync the gradients. The catch is that PyTorch performs a mean reduce instead of a sum reduce, meaning the aggregated gradient gets divided by world_size. (This is defined in the C++ code and there's no easy handle to switch from mean to sum.) We therefore employ this workaround of scaling the loss before backward, thanks to the linearity of differentiation. The loss is then rescaled back for logging purposes.
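A sketch of the algebra behind this, in my own notation (N = world_size, L = the full in-batch loss, g_i = the part of the gradient that process i can compute locally, i.e. the paths through its own embeddings, so that the g_i sum to the full gradient):

```latex
% N = world_size; L = full in-batch loss; g_i = the portion of \nabla L
% that process i computes locally, with \sum_{i=1}^{N} g_i = \nabla L.

% DDP's all_reduce averages the per-process gradients:
g_{\mathrm{DDP}} = \frac{1}{N}\sum_{i=1}^{N} g_i = \frac{1}{N}\,\nabla L

% Multiplying each process's loss by N before backward cancels the 1/N:
\frac{1}{N}\sum_{i=1}^{N} N\,g_i = \sum_{i=1}^{N} g_i = \nabla L
```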

luyug avatar Mar 10 '22 05:03 luyug

Thank you for your reply, that's interesting! I didn't realize that all_gather is not differentiable. I think the mechanism is like what this article describes: https://amsword.medium.com/gradient-backpropagation-with-torch-distributed-all-gather-9f3941a381f8, isn't it?

Yu-Shi avatar Mar 10 '22 07:03 Yu-Shi

Right, this is basically what it describes as "To fix this problem, we can simply calculate the loss for each GPU as Nf rather than f."

luyug avatar Mar 10 '22 12:03 luyug

OK, thanks! It would be nicer if you described this in more detail in the code comments :)

Yu-Shi avatar Mar 10 '22 13:03 Yu-Shi