Enable torch.autocast with ZeRO
DeepSpeed supports mixed precision training, but its behavior differs from torch.autocast. DeepSpeed keeps parameters and gradients in both FP32 and a lower precision (FP16/BF16), in the style of NVIDIA Apex AMP, and computes all modules in the lower precision, while torch.autocast keeps parameters in FP32 and computes only certain operators in the lower precision (illustrated after the list below).
This leads to differences in:
- performance: torch.autocast needs downcasts in forward/backward
- memory usage: DeepSpeed needs more memory to keep copies of parameters and gradients in the lower precision
- accuracy: torch.autocast has a list of modules that can safely be computed in lower precision. Some precision-sensitive operators (e.g. softmax) are computed in FP32.
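To illustrate the torch.autocast side of this comparison, here is a minimal sketch (not part of the PR; it assumes a CUDA device): parameters stay in FP32, autocast-eligible ops run in the lower precision, and precision-sensitive ops like softmax stay in FP32.

```python
import torch

layer = torch.nn.Linear(8, 8).cuda()           # parameters remain FP32
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = layer(x)                               # matmul-backed op -> BF16 output
    p = torch.softmax(y, dim=-1)               # precision-sensitive op -> FP32 output

print(layer.weight.dtype, y.dtype, p.dtype)    # torch.float32 torch.bfloat16 torch.float32
```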
To align DeepSpeed's behavior with torch.autocast when necessary, this PR adds an integration of torch.autocast with ZeRO. Here is an example of the configuration:
"torch_autocast": {
"enabled": true,
"dtype": "bfloat16",
"lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
}
Each configuration option works as follows (a usage sketch follows the list):
- enabled: Enables the integration with torch.autocast when set to true. You don't need to call torch.autocast in your code; the grad scaler is also applied in the DeepSpeed optimizer.
- dtype: Lower precision dtype passed to torch.autocast. Gradients for allreduce (reduce-scatter) and parameters for allgather (ZeRO3 only) of lower_precision_safe_modules are also downcast to this dtype.
- lower_precision_safe_modules: The downcast for allreduce (reduce-scatter) and allgather (ZeRO3) is applied only to modules specified in this list. (The precision of PyTorch operators in forward/backward follows torch.autocast's policy, not this list.) You can set class names with their packages. If you don't set this item, DeepSpeed uses the default list: [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d].
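Putting the options together, a minimal sketch of wiring such a config into deepspeed.initialize might look like the following; the model, optimizer settings, and batch size are placeholders rather than anything prescribed by the PR:

```python
import torch
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
    "zero_optimization": {"stage": 3},
    "torch_autocast": {
        "enabled": True,
        "dtype": "bfloat16",
        "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"],
    },
    # No "fp16"/"bf16" section: parameters stay in FP32 (see the note below).
}

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)
# No explicit `with torch.autocast(...)` is needed around forward; the engine
# applies autocast (and, for FP16, the grad scaler) internally.
```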
Note that only FP32 parameters are maintained when this feature is enabled. For consistency, you cannot enable fp16 or bf16 in the DeepSpeed config.
I have a question about lower_precision_safe_modules
https://pytorch.org/docs/stable/amp.html#torch.autocast doesn't have an option to specify lower_precision_safe_modules - why do we then put the onus on the deepspeed user? (and I'm aware that it's optional and there is a default list).
My question is - can we automatically retrieve that safe list from PyTorch? Or is this because what PyTorch considers safe isn't necessarily what DeepSpeed considers safe?
Hi @stas00, sorry, I missed your replies.
My question is - can we automatically retrieve that safe list from PyTorch? Or is this because what PyTorch considers safe isn't necessarily what DeepSpeed considers safe?
Great question. Ideally, we would retrieve the list directly from PyTorch. I found it in the following files:
- https://github.com/pytorch/pytorch/blob/18a7a04c4adecda3be17dd364d48d484fd1dcdba/aten/src/ATen/autocast_mode.cpp
- https://github.com/pytorch/pytorch/blob/18a7a04c4adecda3be17dd364d48d484fd1dcdba/aten/src/ATen/autocast_mode.h#L828
It does seem challenging to determine which module parameters are safe to use in lower precision. Hopefully, we can refine the default list for this feature based on real-world usage.
- could deepspeed detect that and probably assert if ds config isn't done right?
- or actually should it assert if this happens because then the behavior is unsupported?
I think it should be harmless to include torch.autocast in the training loop, though we should display a warning message. Let me double-check the behavior and add that message.
One confusing aspect is that it's valid to use PyTorch’s torch.autocast while disabling "torch_autocast" in the DeepSpeed config. In this case, DeepSpeed treats it as a pure FP32 setting. The main issue is that all communication operations (e.g., gradient reduction and all-gather) will also be performed in FP32. Since some modules are computed in BF16/FP16, this leads to inefficiencies in many cases.
One confusing aspect is that it's valid to use PyTorch’s torch.autocast while disabling "torch_autocast" in the DeepSpeed config. In this case, DeepSpeed treats it as a pure FP32 setting.
Dare I say this is not valid - nobody does fp32 training these days.
Such situations should assert IMHO.
And fp32 comms is a huge problem as well, not just because of time overhead, but because the comms buffer is 2x larger and this happens 2 times - so overall it'll consume 4x more memory for all comms. I happen to know this because I discovered this when porting Ulysses SP, where it automatically enables fp32 comms and I had huge memory spikes in backward because of that.
Dare I say this is not valid - nobody does fp32 training these days. Such situations should assert IMHO.
Good, I can add an assertion to detect that torch.autocast is enabled outside of DeepSpeed while ds_config doesn't enable torch_autocast. Or it might be better to automatically enable it.
And fp32 comms is a huge problem as well, not just because of time overhead, but because the comms buffer is 2x larger and this happens 2 times - so overall it'll consume 4x more memory for all comms. I happen to know this because I discovered this when porting Ulysses SP, where it automatically enables fp32 comms and I had huge memory spikes in backward because of that.
Sorry I didn't get this. I understand the comm overhead and buffer size will be 2x larger. Could you elaborate why the communication happens twice and why it leads to 4x memory?
Ulysses is actually a good example to discuss precisions. I remember we needed FP32 reduce for training stability. We need to make sure we can still have the option.
I'm talking about peak memory usage.
If you run the torch memory profiler around the reduction call, you will see that it uses reduce_bucket_size * 4 bytes * 2 copies - I think it's because one copy goes from 2 bytes to 4 bytes (1 copy) and then torch/nccl makes another copy?
I measured the default reduce_bucket_size of 5e8 to consume 4GB of peak memory when comms are in fp32, and only 1GB in bf16.
Since those reduction-op memory peaks happen during backward, they often lead to OOM if you want to push bs/seqlen.
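As a quick back-of-the-envelope check of those numbers (assuming the two-copies hypothesis above):

```python
bucket_elems = 5e8                 # default reduce_bucket_size, in elements
fp32_peak = bucket_elems * 4 * 2   # 4 bytes/elem, 2 copies -> ~4e9 bytes (~4 GB)
bf16_peak = bucket_elems * 2 * 1   # 2 bytes/elem, 1 copy   -> ~1e9 bytes (~1 GB)
```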
Hi @stas00, I tried adding detection of nested autocast. This validation is called before the engine's forward pass.
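For illustration only (this is not the actual code added in the PR), such a check could look roughly like the following, using torch.is_autocast_enabled():

```python
import torch

def _validate_no_external_autocast(ds_torch_autocast_enabled: bool) -> None:
    # Hypothetical check: if the user wrapped the engine's forward in their own
    # `with torch.autocast(...)` but did not enable "torch_autocast" in the
    # DeepSpeed config, fail loudly instead of silently running FP32 comms.
    if torch.is_autocast_enabled() and not ds_torch_autocast_enabled:
        raise AssertionError(
            "torch.autocast is active outside DeepSpeed, but 'torch_autocast' is "
            "not enabled in the DeepSpeed config. Enable it (or remove the outer "
            "autocast context) so communication dtypes stay consistent."
        )
```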
I measured the default reduce_bucket_size of 5e8 to consume 4GB of peak memory when comms are in fp32, and only 1GB in bf16.
I see, I didn't know this behavior. It seems very weird that they allocate an additional buffer only for FP32, not for BF16. Perhaps this is a separate topic from this PR, but I will investigate it more when I have a chance.
Thank you for looking into it, Masahiro. No problem doing it elsewhere.
Using the torch memory profiler will be very helpful to see the reduction memory spikes:
https://pytorch.org/blog/understanding-gpu-memory-1/ - it's very easy to set up; if you need help please let me know.
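For reference, a minimal sketch of the recording workflow described in that post; the training-step call is a placeholder:

```python
import torch

# Record allocator history around the suspect step, then dump a snapshot that
# can be inspected in the https://pytorch.org/memory_viz viewer.
torch.cuda.memory._record_memory_history(max_entries=100_000)

run_one_training_step()  # placeholder: forward + backward + step

torch.cuda.memory._dump_snapshot("reduce_bucket_spike.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```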
Good, I can add an assertion to detect that torch.autocast is enabled outside of DeepSpeed while ds_config doesn't enable torch_autocast. Or it might be better to automatically enable it.
If it has to be on and it breaks nothing then automatically enabling it is probably a better idea to help with ease of use.
Hi @stas00, thank you for your feedback!
If it has to be on and it breaks nothing then automatically enabling it is probably a better idea to help with ease of use.
After reviewing the design, I now feel that automatically enabling it wouldn't be straightforward. This autocast feature sets some flags on parameters before the optimizer is initialized. However, we only know whether torch.autocast is enabled right before a forward pass, since with torch.autocast(...) is typically placed around the forward call. Reinitializing parts of the optimizer at that point would complicate the code.
Given that, I think it’s better to throw an error with the explanation.
Then assert is the way to go, Masahiro
Then assert is the way to go, Masahiro
Thank you @stas00, then can you approve this PR?
Hmm, I can't just hit approve; that would defeat the purpose of doing the review.
We have only discussed one small aspect of this PR, which has been resolved, but I don't know the rest of the PR, and I'm currently rushing to finish porting Ulysses to HF/DS, so I won't have time to do a serious review until that is done.
I'm running into an issue where turning on this feature results in massive grad norms, using zero 2; have you seen this before?
[deepspeed.torch_autocast]
enabled = true # NOTE: turning this on makes grad norms explode in stage 2 (but not stage 3)
dtype = "bfloat16"
# [deepspeed.bf16]
# enabled = true # this works properly
[deepspeed.zero_optimization]
stage = 2
allgather_partitions = true
overlap_comm = false
reduce_scatter = true
contiguous_gradients = true
stage3_prefetch_bucket_size = 0
stage3_max_live_parameters = 0
stage3_max_reuse_distance = 0
stage3_gather_16bit_weights_on_model_save = true
Grad norms were reported via model_engine.get_global_grad_norm() and also observed via safe_get_full_grad after a backward call (but before step). Stage 3 seems to have reasonable grad norms. For some reason the loss curves also don't match exactly between stage 2 and stage 3 (but they do match exactly when using deepspeed.bf16 instead).
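For context, a sketch of the inspection described above (the engine and batch are placeholders): safe_get_full_grad after backward but before step, and get_global_grad_norm() after step:

```python
from deepspeed.utils import safe_get_full_grad

loss = model_engine(batch)             # placeholder forward returning the loss
model_engine.backward(loss)

# Per-parameter grad norms from the full FP32 gradients (after backward, before step).
grad_norms = {}
for name, p in model_engine.module.named_parameters():
    g = safe_get_full_grad(p)
    if g is not None:
        grad_norms[name] = g.norm().item()

model_engine.step()
print(model_engine.get_global_grad_norm(), grad_norms)
```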