
DDP + gloo + gpt2 crashes


System Info

  • transformers version: 4.27.4
  • Platform: macOS-12.6-arm64-arm-64bit (also tested on Ubuntu)
  • Python version: 3.10.9
  • Huggingface_hub version: 0.13.3
  • PyTorch version (GPU?): 1.13.1 (False) (also tested on older torch versions)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: yes, see script

Who can help?

@ArthurZucker @younesbelkada

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [X] My own task or dataset (give details below)

Reproduction

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformers
import torch.multiprocessing as mp
import os

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    gpt2 = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    module = DistributedDataParallel(gpt2)

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == '__main__':
    world_size = 2
    run_demo(demo_basic, world_size)

gives

Running basic DDP example on rank 1.
Running basic DDP example on rank 0.
NOTE: Redirects are currently not supported in Windows or MacOs.
NOTE: Redirects are currently not supported in Windows or MacOs.
Traceback (most recent call last):
  File "/Users/danielking/github/composer/scripts/gpt2-dist.py", line 36, in <module>
    run_demo(demo_basic, world_size)
  File "/Users/danielking/github/composer/scripts/gpt2-dist.py", line 29, in run_demo
    mp.spawn(demo_fn,
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/Users/danielking/github/composer/scripts/gpt2-dist.py", line 24, in demo_basic
    module = DistributedDataParallel(gpt2)
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 657, in __init__
    _sync_module_states(
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/distributed/utils.py", line 136, in _sync_module_states
    _sync_params_and_buffers(
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/distributed/utils.py", line 154, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Invalid scalar type

It looks like the attention bias was changed from torch.uint8 in transformers 4.26.1 to torch.bool in transformers 4.27.x. I'm not sure whether I'm doing something wrong, torch has a bug, or transformers has a bug. I don't use the gloo backend much, and discovered this error in our unit tests when upgrading the transformers version. Thanks for your help!
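
For reference, a quick way to confirm which buffers are involved is to print their dtypes (a minimal sketch of my own, not part of the original report; the exact buffer names depend on the transformers version):

import transformers

# List buffer names and dtypes for gpt2; on transformers 4.27.x the per-layer
# causal-mask "bias" buffers show up as torch.bool, whereas 4.26.1 used torch.uint8.
gpt2 = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
for name, buf in gpt2.named_buffers():
    print(name, buf.dtype)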

Expected behavior

DDP wrapping gpt2 works on CPU

dakinggg avatar Mar 31 '23 01:03 dakinggg

We had to change the bias to torch.bool mainly because torch.where no longer supports uint8 masks in the most recent versions of PyTorch.

Use of uint8 masks in torch.where has been deprecated for a couple of years, and although it still works in PyTorch eager mode (with a warning), support for it has been removed in torch.compile. It would be good to audit the places where uint8 masks are used and replace them with bool masks.
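
For illustration, a minimal sketch of the two mask dtypes (my own example, not from the issue):

import torch

scores = torch.randn(2, 4)
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, True]])

# Supported: bool mask.
out = torch.where(mask, scores, torch.tensor(float('-inf')))

# Deprecated: a uint8 mask still runs in eager mode (with a UserWarning on
# recent PyTorch versions) but is not supported under torch.compile.
out_legacy = torch.where(mask.to(torch.uint8), scores, torch.tensor(float('-inf')))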

cc @sgugger as I am not sure about the support for DDP

ArthurZucker avatar Mar 31 '23 12:03 ArthurZucker

Interesting. Does the bug persist on PyTorch 2.0? I'll ask the PyTorch team in our channels about support for bool tensors with DDP.

sgugger avatar Mar 31 '23 13:03 sgugger

The issue does persist on torch 2.0

dakinggg avatar Mar 31 '23 16:03 dakinggg

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 30 '23 15:04 github-actions[bot]

I believe this remains an issue

dakinggg avatar Apr 30 '23 17:04 dakinggg

Is there a fix for this issue yet?

ApoorvaBeedu avatar May 23 '23 19:05 ApoorvaBeedu

I have pinged the PyTorch team multiple times but got no reply on this. You can try opening an issue on their repo. They basically told us to move from torch.uint8 to torch.bool because torch.uint8 won't be supported in some operations like torch.where.

sgugger avatar May 23 '23 19:05 sgugger

Any fix for this issue? I also hit it in a GPT-J test. I raised an issue in the PyTorch GitHub: https://github.com/pytorch/pytorch/issues/103585

tianyil1 avatar Jun 14 '23 08:06 tianyil1

Passing broadcast_buffers=False to DistributedDataParallel fixed this for me. I've opened a PR at #24326 to surface that argument to Trainer users.
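
For anyone landing here, the workaround looks like this when wrapping the model yourself (a minimal sketch; it assumes the process group is already initialized, as in the reproduction script above):

from torch.nn.parallel import DistributedDataParallel
import transformers

# Assumes dist.init_process_group("gloo", ...) has already been called.
gpt2 = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
# Skip the initial buffer broadcast, which is what trips over the bool buffers on gloo.
module = DistributedDataParallel(gpt2, broadcast_buffers=False)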

TevenLeScao avatar Jun 16 '23 18:06 TevenLeScao

I think there are two issues here:

  • GLOO doesn't support bool. This requires an update in torch.distributed to get it to work:
import os

import torch
from torch import distributed as dist

def initialize_torch_distributed():
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))
    backend = "gloo"

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method
    )
    return True

def main():
    initialize_torch_distributed()

    w = torch.randn(1, 3) > 0  # bool tensor
    dist.broadcast(w, src=0)  # Fails with `RuntimeError: Invalid scalar type`

    print(f"Success: {dist.get_rank()}/{dist.get_world_size()}")


if __name__ == "__main__":
    main()

I'm not completely sure, but that could probably be fixed by adding a Scalar::Bool case here to cast to uint8 (or bool?): https://github.com/pytorch/pytorch/blame/dbc8eb2a8fd894fbc110bbb9f70037249868afa8/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp#L98 (a userland cast-to-uint8 workaround is sketched right after this list).

  • Once you get past the collective-call issue (by using nccl, for example), you end up with autograd errors due to DDP broadcasting your buffers. I'm still not entirely clear on why that gets triggered.

For both issues, broadcast_buffers=False can be a good workaround. The only catch is if you also have buffers that do need DDP syncing, like BatchNorm's running statistics.
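
To make the first point concrete, a userland workaround (my own sketch, not from the original comment) is to cast the bool tensor to uint8 around the collective:

import torch
import torch.distributed as dist

# Assumes the process group has already been initialized with the gloo backend.
w = torch.randn(1, 3) > 0       # bool tensor; broadcasting it directly fails on gloo

w_uint8 = w.to(torch.uint8)     # cast to a dtype gloo understands
dist.broadcast(w_uint8, src=0)  # this collective succeeds
w = w_uint8.to(torch.bool)      # cast back after the broadcast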

thomasw21 avatar Jun 17 '23 15:06 thomasw21

Although I have passed broadcast_buffers=False to DistributedDataParallel, it seems that the issue remains (see the error screenshot attached to the original comment).

Expect your feedback @TevenLeScao

tianyil1 avatar Jun 19 '23 05:06 tianyil1

Hey @tianyil1, this looks like a different issue to me, and I'm not seeing it in my case. If you send your script here, it will be easier to run it and debug!

TevenLeScao avatar Jun 20 '23 20:06 TevenLeScao

Thanks for your feedback @TevenLeScao. The script was similar to the one in the first post, but with broadcast_buffers=False added to the DistributedDataParallel call:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformers
import torch.multiprocessing as mp
import os

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    gptj = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    module = DistributedDataParallel(gptj, broadcast_buffers=False)

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == '__main__':
    world_size = 2
    run_demo(demo_basic, world_size)

tianyil1 avatar Jun 21 '23 00:06 tianyil1

Okay there's a hack you can do:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformers
import torch.multiprocessing as mp
import os
import torch

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    # This is the trick: ask DDP to ignore every buffer that is torch.bool,
    # because GLOO doesn't support bool.
    gpt2._ddp_params_and_buffers_to_ignore = [
        name for name, buffer in gpt2.named_buffers() if buffer.dtype == torch.bool
    ]
    module = DistributedDataParallel(gpt2)

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == '__main__':
    world_size = 2
    run_demo(demo_basic, world_size)

Since you don't need to sync them, this should work for you, though the best fix would be to support bool in the GLOO backend.
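
If you want to double-check what will be skipped before wrapping, a small sketch (my addition; the exact buffer names depend on the transformers version, so print them to confirm):

import torch
import transformers

gpt2 = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
# These are the names the hack above puts into _ddp_params_and_buffers_to_ignore.
bool_buffers = [name for name, buf in gpt2.named_buffers() if buf.dtype == torch.bool]
print(bool_buffers)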

thomasw21 avatar Jun 21 '23 09:06 thomasw21

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Jul 15 '23 15:07 github-actions[bot]