
Make all Transformer models compatible with model parallelism

sgugger opened this issue 1 year ago • 25 comments

Accelerate makes it easy to load a model on multiple GPUs with device_map="auto". This in turn allows users to train models with naive model parallelism if they have several GPUs.
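
For example, a minimal sketch (the checkpoint name is just an illustration; this requires `accelerate` to be installed):

```python
from transformers import AutoModelForSeq2SeqLM

# device_map="auto" lets Accelerate shard the model across all available
# GPUs, enabling naive model parallelism.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-3b", device_map="auto")
```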

A problem that happens in Transformers with models with heads (so not XxxModel, but for instance XxxModelForSequenceClassification) is that the labels end up on a different device than the logits, causing a device mismatch error.

Thankfully, there is an easy fix for that! #22535 shows how to fix this for T5 by simply moving the labels to the same device as the logits they are compared to. This is a no-op when the devices are the same and fixes the issue when they differ, as sketched below.
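
In code, the change amounts to one extra line before the loss computation. A minimal self-contained sketch of the pattern (toy tensors standing in for a real model; variable names follow T5's modeling code):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy stand-ins for a model head's logits and the user-provided labels;
# under naive model parallelism these can land on different GPUs.
lm_logits = torch.randn(2, 5, 32)      # (batch, seq_len, vocab_size)
labels = torch.randint(0, 32, (2, 5))  # (batch, seq_len)

# The fix: move the labels to the device of the logits they are compared
# to. This is a no-op when both already live on the same device.
labels = labels.to(lm_logits.device)

loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
```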

We would like help from the community to extend this to all models that support model parallelism, which are:

  • [x] BART
  • [x] BigBirdPegasus
  • [x] BLIP2
  • [x] BLOOM
  • [x] BridgeTower
  • [x] CamemBERT
  • [ ] CLIP
  • [x] CLIPSeg
  • [x] CodeGen
  • [x] Data2Vec Text
  • [x] Deit
  • [x] ESM
  • [x] GPT-2
  • [x] GPT-Neo
  • [x] GPT-NeoX
  • [x] GPT-NeoX Japanese
  • [x] GPT-J
  • [x] GPT-San
  • [ ] JukeBox
  • [ ] Lilt
  • [x] LLaMA (LlamaForSequenceClassification only)
  • [x] Longformer
  • [x] LongT5
  • [x] Luke
  • [x] M2M100
  • [x] mBART
  • [x] mT5
  • [ ] NLLB
  • [x] OPT
  • [ ] Owl-ViT
  • [x] Pix2Struct
  • [x] PLBART
  • [x] RoBERTa
  • [x] RoBERTa PreLayerNorm
  • [ ] SwitchTransformer
  • [x] T5
  • [x] Vilt
  • [x] ViT
  • [x] ViT-Hybrid
  • [ ] Whisper
  • [x] XLM-RoBERTa

If you would like to grab one of those models and apply the same fix as #22535 to all the models with heads, please leave a comment here!

sgugger avatar Apr 04 '23 13:04 sgugger

I think I can help with this Issue :)

muaid-mughrabi avatar Apr 04 '23 14:04 muaid-mughrabi

I would like to work on this issue - BART model :)

iamarunbrahma avatar Apr 04 '23 19:04 iamarunbrahma

Hi, I can take this up 🙌🏻

kausmeows avatar Apr 04 '23 19:04 kausmeows

Indeed, this fix is required for BLOOM. https://github.com/huggingface/transformers/compare/main...zsc:transformers:main (my fix is hacky and not PR-ready. Just FYI)

zsc avatar Apr 05 '23 04:04 zsc

Just to make sure, does LlamaForCausalLM support this feature already? (https://github.com/huggingface/transformers/issues/22546) It seems that there are still some errors when using device_map="auto" for this task.

TerryCM avatar Apr 05 '23 06:04 TerryCM

Hi, I'd like to pick up the GPT-2 model!

mollerup23 avatar Apr 05 '23 14:04 mollerup23

Hi! I am taking this up for LlamaForSequenceClassification.

xssChauhan avatar Apr 05 '23 21:04 xssChauhan

Just to make sure, does LlamaForCausalLM support this feature already? (#22546) It seems that there are still some errors when using device_map="auto" for this task.

It does (#22329). I have started seeing errors similar to #22546, but only after updating my drivers from 525 to 530, similar to https://github.com/huggingface/transformers/issues/22546#issuecomment-1498348442

(which is good news to me; I had no idea why that GPU started disappearing occasionally. It seems it can happen when that GPU is under any load, not just during training)

Edit: seems like the errors I was getting were actually caused by GPU sag. I haven't yet reproduced that exact error, but it has been reported elsewhere. It is certainly not consistent though.

kooshi avatar Apr 06 '23 17:04 kooshi

@younesbelkada @sgugger Is this fix (moving labels to the same device as the logits) supposed to enable model parallelism for all the models listed above, or is it just a crucial step toward it? Also, is this fix only for PyTorch models, and not for JAX or TF?

innat avatar Apr 07 '23 09:04 innat

I think it is supposed to work for all models listed above, as long as you are loading your model with device_map=xxx. And yes, this should be for PyTorch only, though I am not really aware of how model parallelism works on TF & JAX

younesbelkada avatar Apr 07 '23 10:04 younesbelkada

I think it is supposed to work for all models listed above, as long as you are loading your model with device_map=xxx

I tried such a fix here https://github.com/huggingface/transformers/pull/22591#issuecomment-1498013324 but sadly it didn't work out. Any catch?

innat avatar Apr 07 '23 10:04 innat

@sgugger As the goal of this ticket is to enable model parallelism with an easy fix, have the merged PR(s) been checked on multi-GPU? I couldn't find any test script regarding that in https://github.com/huggingface/transformers/pull/22663/.

innat avatar Apr 08 '23 02:04 innat

I would love to work on BridgeTower

shahad-mahmud avatar Apr 08 '23 11:04 shahad-mahmud

Hi. I would like to try with "Whisper"

trantuantdt avatar Apr 09 '23 13:04 trantuantdt

I'd like to claim OPT model if no one else has picked it up.

mollerup23 avatar Apr 10 '23 14:04 mollerup23

Taking this up for the remaining GPT models

mayankagarwals avatar Apr 11 '23 03:04 mayankagarwals

Hello, I just completed the GPT-J code. Just filling in the PR now.

jprivera44 avatar Apr 11 '23 22:04 jprivera44

Hello! I'd like to work on the Whisper model

oscar-garzon avatar Apr 14 '23 01:04 oscar-garzon

Hi, is there any model on which I can work, please? Thanks.

abhigyan631 avatar Apr 14 '23 07:04 abhigyan631

Is there any remaining model I can work on? Thanks.

Tanmaypatil123 avatar Apr 17 '23 15:04 Tanmaypatil123

@sgugger Hello, can I work on the JukeBox?

JuheonChu avatar Apr 18 '23 03:04 JuheonChu

Hello @sgugger , I'd like to work on m2m100

elabongaatuo avatar Apr 18 '23 13:04 elabongaatuo

@sgugger I would love to work on CodeGen if it is unclaimed

Batese2001 avatar Apr 18 '23 18:04 Batese2001

Hi @sgugger I can work on Luke if it has not been taken

katiele47 avatar Apr 18 '23 18:04 katiele47

@sgugger I would like to work on SwitchTransformer, if not taken.

VomV avatar Apr 23 '23 18:04 VomV

@sgugger I think all the models are covered; I have checked the others as well. For example, SwitchTransformer already has parallelism implemented, so I think we can close this issue. The only pending models are CLIP, JukeBox, Owl-ViT, and NLLB; maybe model parallelism is not applicable to some of these models.

sushmanthreddy avatar Apr 25 '23 13:04 sushmanthreddy

Indeed, all models have been covered. Thanks a lot everyone!

sgugger avatar Apr 25 '23 14:04 sgugger