DDP + gloo + gpt2 crashes
System Info
- `transformers` version: 4.27.4
- Platform: macOS-12.6-arm64-arm-64bit (also tested on Ubuntu)
- Python version: 3.10.9
- Huggingface_hub version: 0.13.3
- PyTorch version (GPU?): 1.13.1 (False) (also tested on older torch versions)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: yes, see script
Who can help?
@ArthurZucker @younesbelkada
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformers
import torch.multiprocessing as mp
import os


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    # create the model (on CPU here) and wrap it in DDP
    gpt2 = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    module = DistributedDataParallel(gpt2)
    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == '__main__':
    world_size = 2
    run_demo(demo_basic, world_size)
```
gives:
```
Running basic DDP example on rank 1.
Running basic DDP example on rank 0.
NOTE: Redirects are currently not supported in Windows or MacOs.
NOTE: Redirects are currently not supported in Windows or MacOs.
Traceback (most recent call last):
  File "/Users/danielking/github/composer/scripts/gpt2-dist.py", line 36, in <module>
    run_demo(demo_basic, world_size)
  File "/Users/danielking/github/composer/scripts/gpt2-dist.py", line 29, in run_demo
    mp.spawn(demo_fn,
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/Users/danielking/github/composer/scripts/gpt2-dist.py", line 24, in demo_basic
    module = DistributedDataParallel(gpt2)
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 657, in __init__
    _sync_module_states(
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/distributed/utils.py", line 136, in _sync_module_states
    _sync_params_and_buffers(
  File "/Users/danielking/miniconda3/envs/composer-dev-3.10/lib/python3.10/site-packages/torch/distributed/utils.py", line 154, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Invalid scalar type
```
It looks like the attention bias was changed from `torch.uint8` in `transformers` 4.26.1 to `torch.bool` in 4.27.x. I'm not sure whether I'm doing something wrong, torch has a bug, or transformers has a bug. I don't use the gloo backend much, and discovered this error in our unit tests when upgrading the `transformers` version. Thanks for your help!
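For reference, one quick way to confirm the buffer dtype change is to list the model's registered buffers under each `transformers` version (a minimal sketch; exact buffer names vary by version):

```python
import transformers

gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

# On transformers 4.26.1 the attention-mask buffers show up as torch.uint8;
# on 4.27.x they are torch.bool, which gloo fails to broadcast.
for name, buf in gpt2.named_buffers():
    print(name, buf.dtype)
```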
Expected behavior
DDP wrapping gpt2 works on CPU
We had to change the bias to `torch.bool`, especially because `torch.where` operations are no longer supported with `uint8` masks in the most recent versions of PyTorch.
Use of `uint8` masks in `torch.where` has been deprecated for a couple of years, and though it still works in PyTorch eager mode (with a warning), support for it has been removed in `torch.compile`. It would be good to audit the places where `uint8` masks are used and replace them with `bool` masks.
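As a small illustration of the mask-dtype point (a sketch; the exact warning text depends on the PyTorch version):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([-1.0, -2.0, -3.0])
bool_mask = torch.tensor([True, False, True])
uint8_mask = bool_mask.to(torch.uint8)

print(torch.where(bool_mask, a, b))   # supported path
print(torch.where(uint8_mask, a, b))  # deprecated: still runs in eager mode, but emits a warning
```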
cc @sgugger as I am not sure about the support for DDP
Interesting. Does the bug persist on PyTorch 2.0? I'll ask in our channels with PyTorch about the support of bool tensors and DDP.
The issue does persist on torch 2.0
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I believe this remains an issue
Is there a fix for this issue yet?
I have pinged the PyTorch team multiple times but got no reply on this. You can try opening an issue on their repo.
They basically told us to move from `torch.uint8` to `torch.bool` because `torch.uint8` won't be supported in some operations like `torch.where`.
Any fix for this issue? I also hit it in the GPT-J test. I raised an issue on the PyTorch GitHub: https://github.com/pytorch/pytorch/issues/103585
Passing `broadcast_buffers=False` to `DistributedDataParallel` fixed this for me. I've opened a PR at #24326 to surface that argument to the `Trainer` user.
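Applied to the reproduction script above, the workaround is a single extra argument (a sketch, with `gpt2` being the model loaded inside `demo_basic`):

```python
from torch.nn.parallel import DistributedDataParallel

# In demo_basic above, skip DDP's buffer broadcast so the torch.bool
# mask buffers are never sent through the gloo backend.
module = DistributedDataParallel(gpt2, broadcast_buffers=False)
```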
I think there are two issues here:
- GLOO doesn't support `bool`. This requires an update in `torch.distributed` to get it to work:
```python
import os

import torch
from torch import distributed as dist


def initialize_torch_distributed():
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))
    backend = "gloo"

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method
    )
    return True


def main():
    initialize_torch_distributed()
    w = torch.randn(1, 3) > 0  # bool tensor
    dist.broadcast(w, src=0)   # Fails with `RuntimeError: Invalid scalar type`
    print(f"Success: {dist.get_rank()}/{dist.get_world_size()}")


if __name__ == "__main__":
    main()
```
I'm not completely sure, but that can probably be fixed by adding `Scalar::Bool` here to cast to `uint8` (or `bool`?): https://github.com/pytorch/pytorch/blame/dbc8eb2a8fd894fbc110bbb9f70037249868afa8/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp#L98
- Once you get past the distributed-call issue (by using NCCL, for example), you end up with autograd errors due to DDP broadcasting your buffers. I'm still not super clear as to why that gets triggered.
For both issues, `broadcast_buffers=False` can be a good workaround. The only issue is if you also have buffers that require DDP syncing, like BatchNorm's running statistics.
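For context on the BatchNorm caveat: its running statistics live in buffers that are updated during training, so they are exactly the kind of state that relies on DDP's buffer broadcast (a small illustrative sketch):

```python
import torch

bn = torch.nn.BatchNorm1d(4)

# running_mean, running_var and num_batches_tracked are buffers, not parameters;
# with broadcast_buffers=False DDP no longer keeps them in sync across ranks.
for name, buf in bn.named_buffers():
    print(name, buf.dtype, tuple(buf.shape))
```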
Although I have passed `broadcast_buffers=False` to `DistributedDataParallel`, it seems that the issue remains. Looking forward to your feedback @TevenLeScao
Hey @tianyil1, this looks like a different issue to me, and I'm not seeing it in my case. If you send your file here, it will be easier to run it and debug!
Thanks for your feedback @TevenLeScao. The script was similar to the one in the first post, but with `broadcast_buffers=False` added to the `DistributedDataParallel` call:
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformers
import torch.multiprocessing as mp
import os


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    # create the model (on CPU here) and wrap it in DDP without buffer broadcasting
    gptj = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    module = DistributedDataParallel(gptj, broadcast_buffers=False)
    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == '__main__':
    world_size = 2
    run_demo(demo_basic, world_size)
```
Okay there's a hack you can do:
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import transformers
import torch.multiprocessing as mp
import os
import torch


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)
    gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    # This is the trick: ask DDP to ignore every buffer stored as torch.bool,
    # because GLOO doesn't support broadcasting bool tensors.
    gpt2._ddp_params_and_buffers_to_ignore = [
        name for name, buffer in gpt2.named_buffers() if buffer.dtype == torch.bool
    ]
    module = DistributedDataParallel(gpt2)
    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == '__main__':
    world_size = 2
    run_demo(demo_basic, world_size)
```
Since you don't need to sync them, it should work for you, though the best fix would be to support `bool` in the GLOO backend.
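As a side note, recent PyTorch versions also ship a (private) helper that performs the same assignment; a sketch, assuming `DistributedDataParallel._set_params_and_buffers_to_ignore_for_model` is available in your torch version:

```python
import torch
from torch.nn.parallel import DistributedDataParallel
import transformers

gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
bool_buffers = [name for name, buf in gpt2.named_buffers() if buf.dtype == torch.bool]

# Equivalent to setting gpt2._ddp_params_and_buffers_to_ignore by hand.
DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(gpt2, bool_buffers)
```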
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.