DeepSpeed

[BUG] overflow warning needs to be different for fp16 and non-fp16

Open stas00 opened this issue 2 years ago • 11 comments

Describe the bug

This code is misleading when it runs in a non-fp16 regime.

https://github.com/microsoft/DeepSpeed/blob/da84e60d98d2e90f6f2094a219c98c8b41582eb9/deepspeed/runtime/zero/stage3.py#L1837-L1842

There is no loss scaler under bf16/fp32, so this warning is alarming to see: we rushed to check whether the config was somehow broken, but it wasn't.

It should only print the "Attempted loss scale: ..." part under fp16.

Most likely the same applies to its counterpart in stages 1 and 2.
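A minimal sketch of the requested behavior. The helper name `overflow_message` and its wording are hypothetical, not DeepSpeed's actual code or log text; the point is only that the loss scale is mentioned when fp16 dynamic loss scaling is active, and nowhere else.

```python
from typing import Optional


def overflow_message(fp16_enabled: bool, loss_scale: Optional[float] = None) -> str:
    """Hypothetical overflow warning text; not DeepSpeed's actual implementation."""
    base = "Overflow (Inf/NaN) detected in gradients, skipping step."
    if fp16_enabled:
        if loss_scale is None:
            raise ValueError("fp16 runs have a loss scaler, so pass its current scale")
        # Only fp16 uses dynamic loss scaling, so only here is the scale meaningful.
        return f"{base} Attempted loss scale: {loss_scale}"
    # bf16/fp32: there is no scaler, so don't hint at a scaling problem.
    return f"{base} (no loss scaling in bf16/fp32; inspect the gradients for Inf/NaN)"


print(overflow_message(True, 65536.0))
print(overflow_message(False))
```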

Also, do you think it would be helpful to tell the user specifically whether it's Inf or NaN? A NaN isn't really an overflow, or is it? Perhaps one of you with a more rigorous math background knows better. I think overflow is one of many causes of NaN, so a NaN doesn't always mean an overflow. Please correct me if I'm wrong.

The reason I'm asking is to help the user know what to look for: NaNs, Infinity, or something else.
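On the Inf vs. NaN question: under IEEE 754 arithmetic, an overflow (a result too large for the format) rounds to +/-Inf, not NaN. Further arithmetic on infinities (Inf - Inf, 0 * Inf) then yields NaN, so an overflow can turn into NaNs downstream, while NaN can also arise with no overflow at all (e.g. 0/0). A small pure-Python illustration, with plain floats standing in for tensor elements:

```python
import math


def kind(x: float) -> str:
    """Label a float as 'finite', 'inf' or 'nan'."""
    if math.isnan(x):
        return "nan"
    if math.isinf(x):
        return "inf"
    return "finite"


# 1) Overflow: a result too large for the format becomes inf, not NaN.
overflow = 1e308 * 10
# 2) Arithmetic on infinities can then turn that overflow into NaN.
inf_minus_inf = overflow - overflow
zero_times_inf = 0.0 * overflow
# 3) NaN can also appear without any overflow, e.g. 0/0 or sqrt(-1) in IEEE 754
#    arithmetic (tensor libraries return NaN there; plain Python raises instead).

for name, value in [("1e308 * 10", overflow), ("inf - inf", inf_minus_inf),
                    ("0 * inf", zero_times_inf)]:
    print(f"{name:>10} -> {kind(value)}")
```

So seeing both Inf and NaN in the gradients is consistent with an overflow that was later propagated through further arithmetic, but a NaN on its own does not prove an overflow happened.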

@tjruwase

stas00, Feb 28 '23

@stas00, I am curious what happened to the training in this case, since there is no loss scaling for non-fp16. Did the subsequent iterations continue to report overflows?

I am equally fuzzy about the relationship between overflows, NaNs, and Infinity, so I will ask the team.

tjruwase, Feb 28 '23

> @stas00, I am curious what happened to the training in this case, since there is no loss scaling for non-fp16. Did the subsequent iterations continue to report overflows?

We don't know yet, as the training also segfaults shortly after. When we looked at the grads (Yay!!!), they were actually a mix of Infinity and NaNs. I'd be happy to share more once I have more info, hopefully in a few days.

So the main part of this request is to separate the scaler overflow report (fp16) from all other overflow reports.

Also, perhaps the warning in both cases could state explicitly that the NaN was in the grads and not in the loss. A NaN loss would be a disaster.

> I am equally fuzzy about the relationship between overflows, NaNs, and Infinity, so I will ask the team.

Some good reading I found: https://stackoverflow.com/questions/59335027/what-is-nan-not-a-number-in-the-words-of-a-beginner

stas00, Feb 28 '23

In my testing, training with stage 2 and bf16 also kept showing this message. Is this related to your question? Thanks.

[Screenshot 2023-03-03 19:53:30]

Sanster, Mar 03 '23

@stas00 or @Sanster, could you please share a simple repro to help fix this? In particular, I am interested in how this behavior is triggered for bf16.

tjruwase, Mar 04 '23

I don't have a small repro yet; if I get one, I will let you know.

My hypothesis is that grad clipping is not numerically stable: we checked with a backward hook that the grads seen in the hook are large but finite (e.g., 1e6 or 1e7), but once we use safe_get_full_grad we get inf.

On the other hand, I think we dump safe_get_full_grad before engine.step, so this would be before clipping.

It happens every few steps in the current experiment.

Please give us a bit of time to unravel it.

I think the issue of the warning being confusing is unrelated to the cause of the overflow.

cc: @VictorSanh + @HugoLaurencon

stas00, Mar 04 '23

After #2944, I think what remains is to improve the error message:

  1. Distinguishing the type of overflow: nans and/or infs
  2. Identifying the source of nans/infs: forward vs backward pass

I am not sure 2 is doable since the current overflow detection mechanisms involve inspecting the gradients. Do you have any thoughts on this?

By the way, are you familiar with https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly? It identifies the forward operation that caused the failing backward operation. It could be a useful debugging tool for your case.
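On point 2, one cheap partial answer is to inspect the loss as well as the gradients when an overflow is detected: if the loss (the forward output) is already non-finite, the problem starts in the forward pass; if the loss is finite but the gradients are not, it arose in backward or during gradient reduction. The sketch below uses plain Python floats as stand-ins for tensor elements (with tensors one would use `torch.isnan`/`torch.isinf`), and it assumes the loss value is available at overflow-check time, which may not hold in DeepSpeed's current code path. `describe_overflow` is a hypothetical name.

```python
import math
from typing import Iterable


def summarize_nonfinite(values: Iterable[float]) -> str:
    """Return 'finite' or an 'N NaN, M Inf' summary for a collection of floats."""
    values = list(values)
    n_nan = sum(1 for v in values if math.isnan(v))
    n_inf = sum(1 for v in values if math.isinf(v))
    if n_nan == 0 and n_inf == 0:
        return "finite"
    return f"{n_nan} NaN, {n_inf} Inf"


def describe_overflow(loss: float, grads: Iterable[float]) -> str:
    """Hypothetical overflow report separating NaN from Inf and loss from grads."""
    loss_status = summarize_nonfinite([loss])
    grad_status = summarize_nonfinite(grads)
    if loss_status != "finite":
        # The forward output is already non-finite, so the problem starts in forward.
        return f"forward: loss is non-finite ({loss_status}); grads: {grad_status}"
    if grad_status != "finite":
        # The loss is fine, so the Inf/NaN arose in backward or in gradient reduction.
        return f"backward: loss is finite but grads contain {grad_status}"
    return "no overflow"


inf, nan = float("inf"), float("nan")
print(describe_overflow(2.3, [0.1, inf, nan, inf]))
```

This cannot pinpoint which backward op misbehaved; that finer-grained localization is what anomaly detection is for.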

tjruwase, Mar 05 '23

@tjruwase , thanks for looking into this!

@HugoLaurencon and I isolated a dummy example where we see "inf" in the gradients (which therefore triggers an overflow). The gist is at https://gist.github.com/VictorSanh/44cac3c2b4118c35e0a1136afeadb561 and its README should give the commands to run. Only one GPU is necessary.

It is a dummy example for which we know the gradient of a: it should be 1e8. In pure PyTorch that is the case, but with ZeRO-3 it returns inf. This is pure fp32 with no mixed precision involved, which makes it even weirder.

VictorSanh, Mar 06 '23

So @HugoLaurencon figured it out: the culprit was communication_data_type==fp16 with dtype bf16. fp16 has the wrong numerical range (anything >64k is Inf), so large grads were becoming Inf during reduction.
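To see the effect numerically, here is a stdlib-only sketch that rounds Python floats to fp16 and bf16. It uses `struct`'s `'e'` format (IEEE 754 half precision) and emulates bf16 by rounding an fp32 bit pattern to its top 16 bits. The helper names are made up for illustration; in a real run the cast happens on GPU tensors inside the framework, not like this.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE 754 half precision (max finite value 65504); overflow -> +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses values that round past the fp16 range; an actual
        # fp16 cast produces +/-inf instead, which is what we return.
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round a finite, fp32-representable x to bfloat16 (round to nearest even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 keeps the top 16 bits of an fp32: same 8-bit exponent, 7 stored mantissa bits.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_running_sum(values):
    """Sum values, rounding to bf16 after every addition (an extreme illustration)."""
    total = 0.0
    for v in values:
        total = to_bf16(total + v)
    return total


grad = 1e8  # a large but finite gradient, like the expected value in the repro
print(to_fp16(grad))                    # inf: far above fp16's max of 65504
print(to_bf16(grad))                    # 100139008.0: finite, but only ~2-3 significant digits
print(bf16_running_sum([1.0] * 1000))   # 256.0: at 256 the bf16 spacing is 2, so +1 rounds away
```

The first line is the failure mode found here (fp16 communication turns finite grads into Inf); the last shows why bf16 communication, while not overflowing, can be lossy when many values are summed.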

To keep future users from falling into this subtle behavior, I propose the following.

While warnings usually go unnoticed, perhaps DeepSpeed should emit a warning when:

  1. the bf16 regime is used and communication_data_type is fp16, as it's likely to bite the user during training by turning finite grads into infinite ones;
  2. the bf16 regime is used and communication_data_type is bf16, as it's going to be lossy during reduction calls.

In both cases suggest to set communication_data_type=fp32 instead.
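A minimal sketch of the proposed check, written against a DeepSpeed-style config dict. It assumes the `bf16.enabled` and `communication_data_type` keys, with string values taken here to be `"fp16"`, `"bf16"` or `"fp32"`; the function name and message wording are hypothetical, and DeepSpeed's real config validation may live elsewhere and look different.

```python
import warnings
from typing import Any, Dict, List


def comm_dtype_warnings(config: Dict[str, Any]) -> List[str]:
    """Hypothetical check of a DeepSpeed-style config dict for risky comm dtypes."""
    bf16_enabled = config.get("bf16", {}).get("enabled", False)
    comm_dtype = config.get("communication_data_type")
    messages = []
    if bf16_enabled and comm_dtype == "fp16":
        # fp16's max finite value is 65504, while bf16 grads can be far larger.
        messages.append(
            "bf16 with communication_data_type=fp16: gradients beyond fp16's range "
            "(max 65504) become Inf during reduction; consider communication_data_type=fp32."
        )
    elif bf16_enabled and comm_dtype == "bf16":
        # bf16 has only 8 significant bits, so summing gradients in it loses precision.
        messages.append(
            "bf16 with communication_data_type=bf16: gradient reductions are lossy; "
            "consider communication_data_type=fp32."
        )
    for message in messages:
        warnings.warn(message)
    return messages


cfg = {"bf16": {"enabled": True}, "communication_data_type": "fp16"}
comm_dtype_warnings(cfg)  # emits one UserWarning
```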

stas00, Mar 06 '23

@VictorSanh, thanks for the repro. Can you please check if #2970 helps?

@stas00, did we agree that the communication data type for fp16 should remain fp16 for BC's sake?

tjruwase, Mar 08 '23

> @VictorSanh, thanks for the repro. Can you please check if #2970 helps?

Yes! Added to my todo list; will report back.

VictorSanh, Mar 08 '23

> @stas00, did we agree that the communication data type for fp16 should remain fp16 for BC's sake?

Yes, but also because it works: fp16 training is loss-scaled, so it shouldn't overflow in the reduction comms.
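For context, the mechanism behind the fp16 "Attempted loss scale" warning can be sketched as a toy dynamic loss scaler: when a step produces Inf/NaN gradients, the step is skipped and the scale is reduced; after a run of clean steps, the scale grows again. The class below is a simplified illustration with made-up names and defaults (halving/doubling, a fixed growth interval), not DeepSpeed's own loss scaler.

```python
import math
from typing import Iterable


class ToyDynamicLossScaler:
    """Toy dynamic loss scaler: halve on overflow, double after clean streaks."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 1000,
                 min_scale: float = 1.0):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.min_scale = min_scale
        self._clean_steps = 0

    @staticmethod
    def has_overflow(scaled_grads: Iterable[float]) -> bool:
        # Any Inf or NaN in the scaled gradients counts as an overflow.
        return any(not math.isfinite(g) for g in scaled_grads)

    def update(self, overflow: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should run."""
        if overflow:
            # Skip this step and back off so the next step is less likely to overflow.
            self.scale = max(self.scale / 2, self.min_scale)
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps % self.growth_interval == 0:
            # Enough clean steps in a row: try a larger scale again.
            self.scale *= 2
        return True


scaler = ToyDynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.update(scaler.has_overflow([1.0, float("inf")])), scaler.scale)  # skipped, scale halved
```

In a real fp16 loop the loss is multiplied by `scale` before backward and the gradients are divided by it before the optimizer step; that part is omitted here.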

stas00, Mar 08 '23

@VictorSanh, any update on validating this? Is it okay to close?

tjruwase, Mar 17 '23

Just tested: it works like a charm (i.e., we get the numbers we expect)! Fantastic, thank you @tjruwase, and sorry I was a bottleneck.

VictorSanh, Mar 17 '23