
Enable torch.autocast with ZeRO

Open tohtana opened this issue 10 months ago • 13 comments

DeepSpeed supports mixed precision training, but its behavior differs from torch.autocast. DeepSpeed keeps parameters and gradients in both FP32 and a lower precision (FP16/BF16), in the style of NVIDIA Apex AMP, and computes all modules in the lower precision, while torch.autocast keeps parameters in FP32 and computes only certain operators in the lower precision. This leads to differences in:

  • performance: torch.autocast has to downcast inputs on the fly in forward/backward
  • memory usage: DeepSpeed needs more memory to keep copies of parameters and gradients in lower precision
  • accuracy: torch.autocast maintains a list of operators that can safely be computed in lower precision; precision-sensitive operators (e.g. softmax) are kept in FP32 (see the short snippet below)
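
For illustration, here is a minimal sketch of the torch.autocast behavior described above (the specific module and operator are just examples; this assumes a CUDA device):

import torch

lin = torch.nn.Linear(8, 8).cuda()          # parameters stay in FP32
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = lin(x)                              # the matmul runs in BF16
    p = torch.softmax(y, dim=-1)            # softmax is on autocast's FP32 list

print(lin.weight.dtype, y.dtype, p.dtype)   # torch.float32 torch.bfloat16 torch.float32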

To align DeepSpeed's behavior with torch.autocast when necessary, this PR adds integration of torch.autocast with ZeRO. Here is an example of the configuration:

"torch_autocast": {
  "enabled": true,
  "dtype": "bfloat16",
  "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
}

Each configuration works as follows:

  • enabled: Enables the integration with torch.autocast when set to true. You don't need to wrap your code in torch.autocast yourself. The grad scaler is also applied inside the DeepSpeed optimizer.
  • dtype: The lower-precision dtype passed to torch.autocast. Gradients for allreduce (reduce-scatter) and, for ZeRO3, parameters for allgather of lower_precision_safe_modules are also downcast to this dtype.
  • lower_precision_safe_modules: The downcast for allreduce (reduce-scatter) and allgather (ZeRO3) is applied only to the modules listed here. (The precision of PyTorch operators in forward/backward follows torch.autocast's policy, not this list.) Specify class names with their module paths. If this item is not set, DeepSpeed uses the default list: [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d].

Note that only FP32 parameters are maintained when this feature is enabled. For consistency, you cannot enable fp16 or bf16 in the DeepSpeed config at the same time.
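
To illustrate the configuration above, here is a minimal sketch of a training step with this feature; the model, optimizer settings, and batch are placeholders, not part of this PR:

import torch
import deepspeed

# run with the deepspeed launcher, e.g. `deepspeed train.py`
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3},
    "torch_autocast": {"enabled": True, "dtype": "bfloat16"},
    # no "fp16"/"bf16" section: parameters are maintained in FP32
}

model = torch.nn.Linear(8, 8)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

x = torch.randn(1, 8, device=engine.device)
loss = engine(x).sum()   # no explicit torch.autocast context is needed
engine.backward(loss)    # grad scaling, if any, is handled by the engine
engine.step()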

tohtana avatar Feb 03 '25 07:02 tohtana

I have a question about lower_precision_safe_modules

https://pytorch.org/docs/stable/amp.html#torch.autocast doesn't have an option to specify lower_precision_safe_modules - why do we then put the onus on the deepspeed user? (and I'm aware that it's optional and there is a default list).

My question is - can we automatically retrieve that safe list from pytorch? Or is this because what pytorch considers safe isn't necessarily what deepspeed considers safe?

stas00 avatar Mar 21 '25 18:03 stas00

Hi @stas00, sorry I missed your replies.

My question is - can we automatically retrieve that safe list from pytorch? Or is this because what pytorch considers safe isn't necessarily what deepspeed considers safe?

Great question. Ideally, we would retrieve the list directly from PyTorch. I found it in the following files:

  • https://github.com/pytorch/pytorch/blob/18a7a04c4adecda3be17dd364d48d484fd1dcdba/aten/src/ATen/autocast_mode.cpp
  • https://github.com/pytorch/pytorch/blob/18a7a04c4adecda3be17dd364d48d484fd1dcdba/aten/src/ATen/autocast_mode.h#L828

It does seem challenging to determine which module parameters are safe to use in lower precision. Hopefully, we can refine the default list for this feature based on real-world usage.

tohtana avatar Apr 14 '25 16:04 tohtana

  1. could deepspeed detect that and probably assert if ds config isn't done right?
  2. or actually should it assert if this happens because then the behavior is unsupported?

I think it should be harmless to include torch.autocast in the training loop, though we should display a warning message. Let me double-check the behavior and add that message.

One confusing aspect is that it's valid to use PyTorch’s torch.autocast while disabling "torch_autocast" in the DeepSpeed config. In this case, DeepSpeed treats it as a pure FP32 setting. The main issue is that all communication operations (e.g., gradient reduction and all-gather) will also be performed in FP32. Since some modules are computed in BF16/FP16, this leads to inefficiencies in many cases.

tohtana avatar Apr 14 '25 16:04 tohtana

One confusing aspect is that it's valid to use PyTorch’s torch.autocast while disabling "torch_autocast" in the DeepSpeed config. In this case, DeepSpeed treats it as a pure FP32 setting.

Dare I say this is not valid - nobody does fp32 training these days.

Such situations should assert IMHO.

And fp32 comms is a huge problem as well, not just because of time overhead, but because the comms buffer is 2x larger and this happens 2 times - so overall it'll consume 4x more memory for all comms. I happen to know this because I discovered this when porting Ulysses SP, where it automatically enables fp32 comms and I had huge memory spikes in backward because of that.

stas00 avatar Apr 14 '25 17:04 stas00

Dare I say this is not valid - nobody does fp32 training these days. Such situations should assert IMHO.

Good, I can add an assertion to detect that torch.autocast is enabled outside of DeepSpeed while ds_config doesn't set torch_autocast's enabled. Or it might be better to automatically enable it.

And fp32 comms is a huge problem as well, not just because of time overhead, but because the comms buffer is 2x larger and this happens 2 times - so overall it'll consume 4x more memory for all comms. I happen to know this because I discovered this when porting Ulysses SP, where it automatically enables fp32 comms and I had huge memory spikes in backward because of that.

Sorry I didn't get this. I understand the comm overhead and buffer size will be 2x larger. Could you elaborate why the communication happens twice and why it leads to 4x memory?

Ulysses is actually a good example to discuss precisions. I remember we needed FP32 reduce for training stability. We need to make sure we can still have the option.

tohtana avatar Apr 14 '25 19:04 tohtana

I'm talking about peak memory usage.

If you run the torch memory profiler around the reduction call you will see that it uses reduce_bucket_size * 4 bytes * 2 copies - I think it's because one copy comes from the 2-byte to 4-byte upcast, and then torch/nccl makes another copy?

I measured the default reduce_bucket_size of 5e8 to consume 4GB of peak memory when comms are in fp32, and only 1GB in bf16.

Since those reduction memory peaks happen during backward, they often lead to OOM if you want to push batch size / seqlen.
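
A rough back-of-the-envelope check of those numbers (assuming 4 bytes/element for FP32, 2 bytes/element for BF16, and the copy counts described above):

reduce_bucket_size = 5e8                         # elements (DeepSpeed default)
fp32_peak_gb = reduce_bucket_size * 4 * 2 / 1e9  # 2 copies of a 4-byte buffer -> ~4 GB
bf16_peak_gb = reduce_bucket_size * 2 * 1 / 1e9  # 1 copy of a 2-byte buffer   -> ~1 GB
print(fp32_peak_gb, bf16_peak_gb)                # 4.0 1.0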

stas00 avatar Apr 14 '25 23:04 stas00

Hi @stas00, I tried adding detection of nested autocast. This validation is called before the engine's forward.

I measured the default reduce_bucket_size 5e8 to consume 4GB peak memory usage when comms are in fp32. and only 1GB in bf16.

I see, I didn't know this behavior. It seems very weird that they allocate an additional buffer only for FP32, not for BF16. Perhaps this is a separate topic from this PR, but I will investigate it more when I have a chance.

tohtana avatar Apr 22 '25 21:04 tohtana

Thank you for looking into it, Masahiro. No problem doing it elsewhere.

Using the torch memory profiler will be very helpful to see the reduction memory spikes.

https://pytorch.org/blog/understanding-gpu-memory-1/ - it's very easy to set up - if you need help please let me know.
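
For reference, a minimal setup along the lines of that blog post (the snapshot path and max_entries value are arbitrary):

import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run the training step(s) that include the reduction call ...

torch.cuda.memory._dump_snapshot("reduce_memory.pickle")
torch.cuda.memory._record_memory_history(enabled=None)   # stop recording
# open the snapshot at https://pytorch.org/memory_viz to see the spikes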

sfc-gh-sbekman avatar Apr 23 '25 23:04 sfc-gh-sbekman

Good, I can add an assertion to detect that torch.autocast is enabled outside of DeepSpeed while ds_config doesn't set torch_autocast's enabled. Or it might be better to automatically enable it.

If it has to be on and it breaks nothing then automatically enabling it is probably a better idea to help with ease of use.

sfc-gh-sbekman avatar May 20 '25 16:05 sfc-gh-sbekman

Hi @stas, thank you for your feedback!

If it has to be on and it breaks nothing then automatically enabling it is probably a better idea to help with ease of use.

After reviewing the design, I now feel automatically enabling it wouldn't be straightforward. This autocast feature sets flags on parameters before the optimizer is initialized, but we only know whether torch.autocast is enabled right before a forward pass, since with torch.autocast(...) is typically placed around the forward call. Reinitializing parts of the optimizer at that point would complicate the code.

Given that, I think it's better to raise an error with an explanation.
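
For what it's worth, a rough sketch of what such a check before forward could look like (illustrative only, not the actual PR code; the function name is made up):

import torch

def validate_autocast_settings(ds_torch_autocast_enabled: bool) -> None:
    # torch.is_autocast_enabled() reports whether an enclosing
    # torch.autocast(...) context is currently active.
    if torch.is_autocast_enabled() and not ds_torch_autocast_enabled:
        raise AssertionError(
            "torch.autocast is active, but 'torch_autocast' is not enabled in the "
            "DeepSpeed config; enable it so that communication dtypes are handled correctly.")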

tohtana avatar May 23 '25 07:05 tohtana

Then assert is the way to go, Masahiro

stas00 avatar May 23 '25 16:05 stas00

Then assert is the way to go, Masahiro

Thank you @stas00, then can you approve this PR?

tohtana avatar May 24 '25 03:05 tohtana

Hmm, I can't just hit approve; that would defeat the purpose of doing the review.

We have only discussed one small aspect of this PR, which has been resolved, but I don't know the rest of the PR, and I'm currently rushing to finish porting Ulysses to HF/DS, so until that is done I won't have time to do a serious review.

stas00 avatar May 27 '25 05:05 stas00

I'm running into an issue where turning on this feature results in massive grad norms when using ZeRO stage 2; have you seen this before?

[deepspeed.torch_autocast]
enabled = true  # NOTE: turning this on makes grad norms explode in stage 2 (but not stage 3)
dtype = "bfloat16"

# [deepspeed.bf16]
# enabled = true  # this works properly

[deepspeed.zero_optimization]
stage = 2
allgather_partitions = true
overlap_comm = false
reduce_scatter = true
contiguous_gradients = true
stage3_prefetch_bucket_size = 0
stage3_max_live_parameters = 0
stage3_max_reuse_distance = 0
stage3_gather_16bit_weights_on_model_save = true

Grad norms were reported via model_engine.get_global_grad_norm() and also observed via safe_get_full_grad after a backward call (but before step). Stage 3 seems to have reasonable grad norms. For some reason the loss curves also don't match exactly between stage 2 and stage 3 (but they do match exactly when using deepspeed.bf16 instead).
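
For reference, a minimal sketch of the kind of inspection described above, using deepspeed.utils.safe_get_full_grad and the engine's get_global_grad_norm() (engine/batch setup omitted):

from deepspeed.utils import safe_get_full_grad

loss = model_engine(batch).sum()
model_engine.backward(loss)

# after backward, before step: per-parameter grads gathered from the ZeRO partitions
for name, param in model_engine.module.named_parameters():
    g = safe_get_full_grad(param)
    if g is not None:
        print(name, g.norm().item())

model_engine.step()
print("global grad norm:", model_engine.get_global_grad_norm())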

Vervious avatar Aug 08 '25 04:08 Vervious