
`freqs` in the updated Rotary encoder is of type `torch.complex64` which is not supported in DistributedDataParallel

Open · chaoweihuang opened this issue 2 years ago · 1 comment

Describe the bug: The updated RotaryEncoder implementation introduced in this PR changes the dtype of freqs to torch.complex64. However, complex tensors are not supported by PyTorch's DistributedDataParallel; specifically, DDP errors out during _sync_module_states in its __init__.
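For context, the complex dtype comes from how rotary frequencies are commonly precomputed as unit-magnitude complex numbers. A rough sketch of that pattern (not necessarily fairseq2's exact code):

import torch

# Typical rotary-embedding precomputation (illustrative values for dim/seq_len/theta):
dim, max_seq_len, theta = 128, 2048, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
# torch.polar builds abs * exp(i * angle), so the resulting buffer is complex64.
freqs = torch.polar(torch.ones_like(angles), angles)
print(freqs.dtype)  # torch.complex64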

Describe how to reproduce: Here's a simple script to reproduce the error. Make sure that you're in a GPU-enabled environment.

import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairseq2.models.llama import load_llama_model


# Initialize a NCCL process group from the environment variables set by torchrun.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group(
    backend="nccl",
    init_method="env://",
    rank=rank,
    world_size=world_size,
)
torch.cuda.set_device(rank)

# Wrapping the model fails because DDP tries to broadcast the complex64 freqs buffers.
model = load_llama_model("llama_7b", device=torch.device("cuda"))
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

Running it with torchrun --nproc_per_node 1 test.py produces the following error:

Traceback (most recent call last):
  File "/mnt/fsx-home/chaoweihuang/fairseq2-public-fork/test.py", line 19, in <module>
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
  File "/fsx-ust/chaoweihuang/miniconda3/envs/seamless-fairseq2/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 676, in __init__
    _sync_module_states(
  File "/fsx-ust/chaoweihuang/miniconda3/envs/seamless-fairseq2/lib/python3.10/site-packages/torch/distributed/utils.py", line 142, in _sync_module_states
    _sync_params_and_buffers(
  File "/fsx-ust/chaoweihuang/miniconda3/envs/seamless-fairseq2/lib/python3.10/site-packages/torch/distributed/utils.py", line 160, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: Input tensor data type is not supported for NCCL process group: ComplexFloat

A workaround is to set the _ddp_params_and_buffers_to_ignore attribute on the root module that will be wrapped with DDP, since freqs does not need to be synchronized across ranks:

model._ddp_params_and_buffers_to_ignore = [
    f"decoder.layers.{i}.self_attn.pos_encoder.freqs"
    for i in range(len(model.decoder.layers))
]
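If hard-coding the module path is undesirable, an equivalent sketch (assuming the only complex-valued buffers in the model are the rotary freqs) is to collect every complex buffer by name:

# Skip all complex-valued buffers during DDP synchronization.
model._ddp_params_and_buffers_to_ignore = [
    name for name, buf in model.named_buffers() if buf.is_complex()
]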

chaoweihuang · Oct 04 '23 20:10

Thanks for the report @chaoweihuang. We have a similar issue in the Gumbel-Softmax quantizer as well. As you already found out, the right solution would be to let DDP ignore those non-persistent buffers. Let me think of a clean way to do that. The last time I checked, the DDP API for this was still experimental.
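(For reference, the experimental hook in question is, to my knowledge, a private static method on DistributedDataParallel that populates the same attribute; a rough sketch, with the caveat that the name and signature may change across PyTorch releases:)

from torch.nn.parallel import DistributedDataParallel as DDP

# Private/experimental helper: records the given buffer names on the module so DDP
# skips them; equivalent to setting _ddp_params_and_buffers_to_ignore by hand.
ignored = [name for name, buf in model.named_buffers() if buf.is_complex()]
DDP._set_params_and_buffers_to_ignore_for_model(model, ignored)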

cbalioglu · Oct 06 '23 17:10