Detect Accelerate's DeepSpeed ZeRO stage 3 env var and warn if synced_gpus is False
Feature request
If `ACCELERATE_DEEPSPEED_ZERO_STAGE == 3` and `generate` is called without `synced_gpus`, it would be reasonable to warn the user that, when making a distributed call to `generate` with a DeepSpeed model, they need to pass the `synced_gpus` argument.
Motivation
Background
DeepSpeed ZeRO stage 3 shards the model parameters, so it requires that `model.forward` be called the same number of times on every process, even at inference time, so that the parameter shards can be gathered in sync. During generation, `model.forward` is called once per generated token. If one process stops generating before the others, DeepSpeed ZeRO stage 3 breaks, because `model.forward` is no longer called on the processes where generation is over. That's why the `synced_gpus` argument exists in `model.generate`: with it, `model.forward` keeps getting called until all processes are done generating.
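For illustration, here is a minimal sketch of a distributed generation call with `synced_gpus=True`. It assumes a process group has already been initialized and the model has been wrapped with DeepSpeed ZeRO stage 3 elsewhere (e.g. via `accelerate launch`); the model name and prompts are placeholders.

```python
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; in a real run it would be wrapped by DeepSpeed ZeRO stage 3
# (e.g. launched with `accelerate launch` and a config setting zero_stage: 3).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Different ranks may get prompts of different lengths, so they would normally
# finish generating at different steps.
prompts = ["Hello world", "A much longer prompt that keeps the model generating"]
inputs = tokenizer(prompts[dist.get_rank() % len(prompts)], return_tensors="pt")

# synced_gpus=True keeps model.forward running on every rank until *all* ranks
# have finished generating, which is what ZeRO stage 3 requires.
outputs = model.generate(**inputs, max_new_tokens=32, synced_gpus=True)
```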
Accelerate Has an Env Var that Indicates Stage 3
When using DeepSpeed, Accelerate sets an env var called `ACCELERATE_DEEPSPEED_ZERO_STAGE` that contains the stage. While `ACCELERATE_DEEPSPEED_ZERO_STAGE` being set to 3 doesn't guarantee that the model call is distributed, in practice it is a strong indication, and it would be reasonable to emit a warning if `model.generate` (and possibly `model.greedy_search`, etc.) is called without `synced_gpus`, as new users will probably not know about this.
If there is a more reliable way for `model.generate` to know whether the model is distributed with DeepSpeed ZeRO stage 3, then that could of course be used to warn the user instead.
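To make the proposal concrete, here is a minimal sketch of the env-var check being suggested; the helper name and exact message are hypothetical, not existing transformers code.

```python
import os
import warnings

def _maybe_warn_missing_synced_gpus(synced_gpus: bool) -> None:
    """Hypothetical helper: warn when Accelerate's env var indicates ZeRO stage 3
    but the caller did not ask for synced_gpus."""
    if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3" and not synced_gpus:
        warnings.warn(
            "ACCELERATE_DEEPSPEED_ZERO_STAGE=3 detected but synced_gpus is not set. "
            "For a distributed call to generate() with a DeepSpeed ZeRO stage 3 "
            "model, pass synced_gpus=True so every process keeps calling forward "
            "until all processes have finished generating."
        )
```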
Your contribution
I can do it, but for these nuanced, low-code-quantity changes, you folks are probably better placed than I am.
cc @stas00 and @pacman100
Totally. Thank you for bringing it up, @JulesGM
The API for checking this situation is already available and is being used in the HF Trainer:
https://github.com/huggingface/transformers/blob/bec075612a293a66022937f65ba0c0df25224d29/src/transformers/trainer_seq2seq.py#L180-L188
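For reference, a sketch of using that same API outside the Trainer, assuming `is_deepspeed_zero3_enabled` is importable from `transformers.deepspeed` (as it was at the time); this is illustrative, not necessarily line-for-line what the linked Trainer code does.

```python
from transformers.deepspeed import is_deepspeed_zero3_enabled

# Roughly how the Trainer decides whether generation needs synced GPUs:
# True iff the active HF DeepSpeed integration is configured with ZeRO stage 3.
synced_gpus = True if is_deepspeed_zero3_enabled() else False

# Usage (model and inputs defined elsewhere):
# outputs = model.generate(**inputs, synced_gpus=synced_gpus)
```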
For DIY integration we can
- document it here: https://huggingface.co/docs/transformers/main/main_classes/deepspeed#nontrainer-deepspeed-integration
- and add an assert inside `generate` if it is called w/o this flag and WORLD_SIZE>1 and zero3. No warnings please - nobody sees those. (Need to think about how to check world_size inside `generate`, but checking for DeepSpeed first will enable a definite use of `torch.distributed.get_world_size()`, so it should be easy.) A sketch of such a check is shown after this list.
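A sketch of what that check could look like, under the assumptions above; the helper name, placement inside `generate`, and error message are hypothetical.

```python
import torch.distributed as dist
from transformers.deepspeed import is_deepspeed_zero3_enabled

def _assert_synced_gpus(synced_gpus: bool) -> None:
    """Hypothetical check inside generate(): fail loudly instead of hanging."""
    if not is_deepspeed_zero3_enabled():
        return
    # Per the reasoning above, running under DeepSpeed implies torch.distributed
    # is set up, so get_world_size() can be called safely once zero3 is confirmed.
    if dist.get_world_size() > 1 and not synced_gpus:
        raise ValueError(
            "DeepSpeed ZeRO stage 3 with world_size > 1 detected. Pass "
            "synced_gpus=True to generate() so all processes keep calling "
            "forward until every process has finished generating."
        )
```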
Would you like to work on that, @JulesGM? I'd be happy to support you or I might find time to do it myself some time later. Totally up to you.
That's great to hear, Stas.
Honestly, I'm kind of working night and day toward my thesis deadline right now, so if you want to do it, it would be much appreciated.
Thank you for letting me know your preference. Please try this PR and let me know if it solves the problem for you, @JulesGM:
https://github.com/huggingface/transformers/pull/22242
I decided to just set it automatically if it wasn't set.
The docs were already correct, so no need to change them.
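For context, a rough sketch of that behaviour, i.e. defaulting `synced_gpus` when the caller leaves it unset; the actual code in the PR may differ in its details.

```python
import torch.distributed as dist
from transformers.deepspeed import is_deepspeed_zero3_enabled

def resolve_synced_gpus(synced_gpus=None) -> bool:
    # If the caller didn't choose, turn synced_gpus on automatically when running
    # under DeepSpeed ZeRO stage 3 with more than one process; otherwise keep it off.
    if synced_gpus is None:
        return (
            is_deepspeed_zero3_enabled()
            and dist.is_available()
            and dist.is_initialized()
            and dist.get_world_size() > 1
        )
    return synced_gpus
```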