Make all Transformer models compatible with model parallelism
Accelerate makes it easy to load a model on multiple GPUs with `device_map="auto"`. This in turn allows users to train models with naive model parallelism if they have several GPUs.
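For reference, loading a checkpoint this way is a one-liner (a minimal sketch; `t5-large` is just an example checkpoint):

```python
from transformers import AutoModelForSeq2SeqLM

# Accelerate infers a device map and shards the weights across all
# visible GPUs (spilling to CPU or disk if they don't fit).
model = AutoModelForSeq2SeqLM.from_pretrained("t5-large", device_map="auto")

# Inspect which device each submodule was placed on.
print(model.hf_device_map)
```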
A problem that happens in Transformers with models with heads (so not `XxxModel` but, for instance, `XxxModelForSequenceClassification`) is that the labels end up on a different device than the logits, and there is a device mismatch error.
Thankfully, there is an easy fix for that! #22535 shows how to fix this for T5 by just moving the labels to the same device as the logits they are compared to. This is a no-op when the devices are the same, and fixes the issue when they differ.
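In a nutshell, the patch boils down to one extra line right before the loss computation. The snippet below is a self-contained sketch of the pattern (the tensors are stand-ins for the `lm_logits` and `labels` locals inside the model's `forward`; the shapes are arbitrary):

```python
import torch
import torch.nn as nn

# Stand-ins: under device_map="auto" the head (and thus the logits) can
# land on a later GPU while the labels arrive on another device.
device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
lm_logits = torch.randn(2, 5, 32128, device=device)
labels = torch.randint(0, 32128, (2, 5))  # e.g. still on the CPU

loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
# The fix from #22535: move the labels to the logits' device. This is a
# no-op when they already match and resolves the mismatch otherwise.
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
```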
We would like help from the community to extend this to all models that support model parallelism, which are:
- [x] BART
- [x] BigBirdPegasus
- [x] BLIP2
- [x] BLOOM
- [x] BridgeTower
- [x] CamemBERT
- [ ] CLIP
- [x] CLIPSeg
- [x] CodeGen
- [x] Data2Vec Text
- [x] DeiT
- [x] ESM
- [x] GPT-2
- [x] GPT-Neo
- [x] GPT-NeoX
- [x] GPT-NeoX Japanese
- [x] GPT-J
- [x] GPT-San
- [ ] JukeBox
- [ ] LiLT
- [x] LLaMA (`LlamaForSequenceClassification` only)
- [x] Longformer
- [x] LongT5
- [x] Luke
- [x] M2M100
- [x] mBART
- [x] mT5
- [ ] NLLB
- [x] OPT
- [ ] OWL-ViT
- [x] Pix2Struct
- [x] PLBART
- [x] RoBERTa
- [x] RoBERTa PreLayerNorm
- [ ] SwitchTransformer
- [x] T5
- [x] ViLT
- [x] ViT
- [x] ViT-Hybrid
- [ ] Whisper
- [x] XLM-RoBERTa
If you would like to grab one of those models and apply the same fix as #22535 to all the models with heads, please leave a comment here!
I think I can help with this Issue :)
I would like to work on this issue - BART model :)
Hi, I can take this up 🙌🏻
Indeed, this fix is required for BLOOM. https://github.com/huggingface/transformers/compare/main...zsc:transformers:main (my fix is hacky and not PR-ready. Just FYI)
Just to make sure, does `LlamaForCausalLM` support this feature already? (https://github.com/huggingface/transformers/issues/22546) It seems there are still some errors when using `device_map="auto"` for this task.
Hi, I'd like to pick up the GPT-2 model!
Hi! I am taking this up for `LlamaForSequenceClassification`.
> Just to make sure, does `LlamaForCausalLM` support this feature already? (#22546) It seems there are still some errors when using `device_map="auto"` for this task.
It does (#22329). I have started seeing errors similar to #22546, but only after updating my drivers from 525 to 530, much like https://github.com/huggingface/transformers/issues/22546#issuecomment-1498348442 (which is good news to me; I had no idea why that GPU started disappearing occasionally. It seems it can happen whenever that GPU is under any load, not just during training).
Edit: it seems the errors I was getting were actually caused by GPU sag. I haven't yet reproduced that exact error, but it has been reported elsewhere. It is certainly not consistent, though.
@younesbelkada @sgugger Is this fix (moving the labels to the same device as the logits) supposed to make model parallelism work for all the models listed above, or is it just a crucial step toward that? Also, is this fix only for PyTorch models and not for JAX or TF?
I think it is supposed to work for all models listed above, as long as you are loading your model with `device_map=xxx`. And yes, this should be for PyTorch only, though I am not really aware of how model parallelism works on TF & JAX.
> I think it is supposed to work for all models listed above, as long as you are loading your model with `device_map=xxx`.

I tried such a fix here https://github.com/huggingface/transformers/pull/22591#issuecomment-1498013324 but sadly it didn't work out. Is there any catch?
@sgugger As the goal of this issue is to enable model parallelism with this easy fix, have the merged PR(s) been checked on multi-GPU? I couldn't find any test script covering that in https://github.com/huggingface/transformers/pull/22663/.
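If it helps anyone verifying one of these PRs, here is a quick smoke test (hypothetical; it assumes a multi-GPU machine and uses `roberta-base` purely as an example checkpoint): shard the model with `device_map="auto"` and pass labels that live on the CPU.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tok = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-base", device_map="auto"
)

inputs = tok("model parallelism smoke test", return_tensors="pt").to(model.device)
labels = torch.tensor([1])  # deliberately left on the CPU

# Before the fix this raised a device-mismatch error in the loss;
# with it, the labels are moved to the logits' device internally.
print(model(**inputs, labels=labels).loss)
```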
I would love to work on BridgeTower.
Hi. I would like to try with "Whisper"
I'd like to claim OPT model if no one else has picked it up.
Taking this up for the remaining GPT models
Hello, I just completed the GPT-J code. Just filling in the PR now.
Hello! I'd like to work on the Whisper model.
Hi, is there any model on which I can work, please? Thanks.
Is there any remaining model on which I can work? Thanks.
@sgugger Hello, can I work on JukeBox?
Hello @sgugger , I'd like to work on m2m100
@sgugger I would love to work on CodeGen if it is unclaimed
Hi @sgugger, I can work on Luke if it has not been taken.
@sgugger I would like to work on SwitchTransformer, if not taken.
@sgugger I think all the models are covered; I have checked the others as well. For example, SwitchTransformer already has parallelism implemented. I think we can close this issue. The only pending models are CLIP, JukeBox, OWL-ViT, and NLLB; maybe model parallelism is not applicable to some of these models.
Indeed, all models have been covered. Thanks a lot everyone!