Implement Gradient Noise Scale monitoring

Open ibeltagy opened this issue 2 years ago • 9 comments

Follow Appendix A.1 of https://arxiv.org/pdf/1812.06162.pdf to implement monitoring of the gradient noise scale and add it to the tensorboard log.
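
For reference, Appendix A.1 estimates the noise scale from the gradient norm measured at two batch sizes: a small batch B_small (e.g. one data-parallel worker's batch) and a large batch B_big (the full global batch). Writing the estimators out here for convenience:

```latex
|\mathcal{G}|^2 = \frac{1}{B_{\text{big}} - B_{\text{small}}}
  \left( B_{\text{big}}\,|G_{B_{\text{big}}}|^2 - B_{\text{small}}\,|G_{B_{\text{small}}}|^2 \right),
\qquad
\mathcal{S} = \frac{1}{1/B_{\text{small}} - 1/B_{\text{big}}}
  \left( |G_{B_{\text{small}}}|^2 - |G_{B_{\text{big}}}|^2 \right),
\qquad
B_{\text{simple}} \approx \frac{\mathcal{S}}{|\mathcal{G}|^2}
```

The paper suggests smoothing \mathcal{S} and |\mathcal{G}|^2 separately (e.g. with an exponential moving average) before taking the ratio that gets logged.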

ibeltagy avatar Jul 30 '21 19:07 ibeltagy

I can take a look. I've already read the appendix.

lintangsutawika avatar Aug 03 '21 08:08 lintangsutawika

I think I have found where the parameters from each GPU are collected. https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/7b9988146881f6eee33f69c28a92ae03e2678e42/megatron/optimizer/clip_grads.py#L48-L67

The snippet below is where the gradient norm is reduced. This would refer to |G_{B_big}|^2, yes? https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/7b9988146881f6eee33f69c28a92ae03e2678e42/megatron/optimizer/clip_grads.py#L74-L82

I suppose that in order to calculate |G_{B_small}|^2 I will have to call total_norm = max(grad.abs().max() for grad in grads_for_norm) for each GPU and then calculate the noise scale. Just want to verify whether my hunch is correct.
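
One note: that max(...) line looks like the inf-norm branch of gradient clipping, whereas the noise-scale estimator needs the squared L2 norm. A rough per-GPU sketch of the measurement (a hypothetical helper, not the existing clip_grads.py code; the data-parallel group handling here is an assumption):

```python
import torch
import torch.distributed as dist

def local_and_mean_grad_sq_norm(grads_for_norm, data_parallel_group=None):
    # Squared L2 norm of this GPU's gradients; its expectation over
    # data-parallel ranks is the |G_{B_small}|^2 term of the estimator.
    local_sq = sum((g.float() ** 2).sum() for g in grads_for_norm)

    # Averaging the per-rank squared norms over the data-parallel group gives
    # a lower-variance estimate of E[|G_{B_small}|^2].
    mean_sq = local_sq.detach().clone()
    dist.all_reduce(mean_sq, group=data_parallel_group)
    mean_sq /= dist.get_world_size(group=data_parallel_group)
    return local_sq, mean_sq
```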

@ibeltagy @slippylolo

lintangsutawika avatar Aug 04 '21 16:08 lintangsutawika

Hmm, I could be wrong, but this all_reduce only acts across model parallelism. My understanding of Appendix A.1 is that this should be done across data parallelism instead (so actually at a higher level than here).

Let me dig through the code a bit to see if I can find a place where this can be accessed.

slippylolo avatar Aug 04 '21 19:08 slippylolo

https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/7b9988146881f6eee33f69c28a92ae03e2678e42/megatron/model/distributed.py#L188-L218

This is where the data-parallel all_reduce occurs in the "simplified" DDP implemented by Megatron (--DDP-impl local). If the PyTorch DDP is used instead, then we would have to modify PyTorch code directly, or find another, less dirty way to hook ourselves into it.

I feel like it's kind of dirty to hook ourselves into the low-level DDP framework for such a measurement anyway, so we should find a better solution.
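
For what it's worth, the measurement itself would be tiny if we did hook in there: take the squared norm of this rank's gradient just before the data-parallel all_reduce (that is G_{B_small}) and again just after the averaging (that is G_{B_big}). A rough sketch with made-up names, not the actual distributed.py code:

```python
import torch

def grad_sq_norms_around_dp_allreduce(model, run_dp_allreduce):
    # run_dp_allreduce is whatever currently averages the gradients across the
    # data-parallel group (e.g. the local-DDP all_reduce); purely hypothetical.
    grads = [p.grad for p in model.parameters() if p.grad is not None]

    # Before the all_reduce: this rank's gradient, i.e. G_{B_small}.
    small_sq = sum((g.float() ** 2).sum() for g in grads)

    run_dp_allreduce()  # averages the gradients across data-parallel ranks in place

    # After the all_reduce: the gradient over the full global batch, i.e. G_{B_big}.
    big_sq = sum((g.float() ** 2).sum() for g in grads)
    return small_sq, big_sq
```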

slippylolo avatar Aug 04 '21 19:08 slippylolo

I think we currently don't use this but rather ZeRO-DP from DeepSpeed (stage 1). I haven't verified which specific code paths it takes; it might help to step through with the debugger.

We only use TP from Megatron, and PP and DP from DeepSpeed.

I should probably update the examples to show how to include DeepSpeed ZeRO-1. I will do that shortly.

stas00 avatar Aug 04 '21 20:08 stas00

Here are the instructions added: https://github.com/bigscience-workshop/Megatron-DeepSpeed#deepspeed-pp-and-zero-dp
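
For completeness, the ZeRO-DP stage 1 portion of the DeepSpeed config is quite small; roughly something like the following (shown as a Python dict that would be serialized into the JSON file passed via --deepspeed_config; the batch-size values are placeholders, see the linked instructions for the actual setup):

```python
# Placeholder values; the real config lives in the file referenced by --deepspeed_config.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 1,
    },
}
```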

stas00 avatar Aug 04 '21 21:08 stas00

So by using the option --DDP-impl torch we can acquire the gradients from Megatron-DeepSpeed/megatron/model/distributed.py, and both the micro-batch gradient norm and the global-batch gradient norm can be calculated each iteration.

I feel like even if we use pipeline parallelism, the difference would just be that the global-batch gradient norm can only be calculated once backpropagation has run over all the accumulated gradients.

What do you think @slippylolo ?

lintangsutawika avatar Aug 05 '21 13:08 lintangsutawika

I don't think this is an acceptable solution, as ZeRO-DP comes with some pretty nice savings (plus we wouldn't want to maintain two different data-parallelism schemes just for the gradient noise scale measurement).

I haven't had time to look into it yet, but there must be an easy way to acquire the local and global batch gradients in ZeRO-DP.

slippylolo avatar Aug 06 '21 07:08 slippylolo

Yes, I agree.

In that case, I think we can extract the gradients around here. G_{B_small} would be the gradients obtained from each micro-batch and G_{B_big} would be the gradient after averaging all accumulations. https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/5e3963db4a5db36e1ebdb152e2760382bdc7ef04/megatron/training.py#L441-L451

The DeepSpeed documentation explains that the gradient accumulation method used is slightly different from what we might want. Instead of reducing gradients per micro-batch, it averages them locally at each step and only averages across all GPUs at the end of the accumulation sequence. https://www.deepspeed.ai/features/#smart-gradient-accumulation
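
If that is indeed the behaviour, it is actually convenient for this measurement: the locally accumulated gradient (before the final cross-GPU average) plays the role of G_{B_small} with B_small = micro-batch size × accumulation steps, and the gradient after the cross-GPU average is G_{B_big} with B_big = B_small × data-parallel world size. A rough sketch of the per-iteration bookkeeping, assuming the two squared norms can be read at those two points (names are hypothetical; writer is the usual TensorBoard SummaryWriter):

```python
def log_noise_scale(writer, iteration, local_sq, global_sq, b_small, b_big,
                    state, alpha=0.99):
    # Unbiased estimators from appendix A.1 of https://arxiv.org/abs/1812.06162:
    # |G|^2 (true squared gradient norm) and S (per-example gradient noise).
    g_sq = (b_big * global_sq - b_small * local_sq) / (b_big - b_small)
    s = (local_sq - global_sq) / (1.0 / b_small - 1.0 / b_big)

    # Smooth |G|^2 and S separately with an EMA, then log their ratio B_simple.
    state["g_sq"] = alpha * state.get("g_sq", g_sq) + (1 - alpha) * g_sq
    state["s"] = alpha * state.get("s", s) + (1 - alpha) * s
    writer.add_scalar("gradient-noise-scale", state["s"] / state["g_sq"], iteration)
```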

lintangsutawika avatar Aug 09 '21 14:08 lintangsutawika