[BUG] Context parallelism gives NCCL Error 5 (invalid usage)
Describe the bug
I am using the train_gpt3_175b_distributed.sh script to launch training on a single node with 4 A100 80GB GPUs. Training runs fine with tensor parallelism or pipeline parallelism, but fails when I enable context parallelism. This is my script:
#!/bin/bash
# Runs the "175B" parameter model
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=4
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
CHECKPOINT_PATH=$1 #<Specify path>
TENSORBOARD_LOGS_PATH=$2 #<Specify path>
VOCAB_FILE=$3 #<Specify path to file>/gpt2-vocab.json
MERGE_FILE=$4 #<Specify path to file>/gpt2-merges.txt
DATA_PATH=$5 #<Specify path and file prefix>_text_document
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
GPT_MODEL_ARGS=(
--num-layers 32
--hidden-size 2048
--ffn-hidden-size 8192
--num-attention-heads 16
--seq-length 2048
--max-position-embeddings 2048
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 1
--train-iters 500000
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
--init-method-std 0.006
--clip-grad 1.0
--fp16
--lr 6.0e-5
--lr-decay-style cosine
--min-lr 6.0e-6
--lr-warmup-fraction .001
--lr-decay-iters 430000
--recompute-activations
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--context-parallel-size 4
)
DATA_ARGS=(
--data-path "/temp_document"
--vocab-file "/gpt2-vocab.json"
--merge-file "/gpt2-merges.txt"
--split 949,50,1
)
EVAL_AND_LOGGING_ARGS=(
--log-interval 100
--save-interval 10000
--eval-interval 1000
--save "/temp/"
--load "/temp/"
--eval-iters 10
--tensorboard-dir "/Megatron-LM/examples/gpt3"
)
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${GPT_MODEL_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${DATA_ARGS[@]} \
${EVAL_AND_LOGGING_ARGS[@]}
The output log is:
building GPT model ...
> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 1718685696
> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 1718685696
> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 1718685696
> number of parameters on (tensor, pipeline) model parallel rank (0, 0): 1718685696
WARNING: could not find the metadata file /temp/latest_checkpointed_iteration.txt
will not load any checkpoints and will start from random
/pytorch/torch/distributed/c10d_logger.py:83: FutureWarning: `torch.distributed._all_gather_base` is a private function and will be deprecated. Please use `torch.distributed.all_gather_into_tensor` instead.
return func(*args, **kwargs)
/pytorch/torch/distributed/c10d_logger.py:83: FutureWarning: `torch.distributed._all_gather_base` is a private function and will be deprecated. Please use `torch.distributed.all_gather_into_tensor` instead.
return func(*args, **kwargs)
/pytorch/torch/distributed/c10d_logger.py:83: FutureWarning: `torch.distributed._all_gather_base` is a private function and will be deprecated. Please use `torch.distributed.all_gather_into_tensor` instead.
return func(*args, **kwargs)
/pytorch/torch/distributed/c10d_logger.py:83: FutureWarning: `torch.distributed._all_gather_base` is a private function and will be deprecated. Please use `torch.distributed.all_gather_into_tensor` instead.
return func(*args, **kwargs)
(min, max) time across ranks (ms):
load-checkpoint ................................: (0.73, 0.78)
[after model, optimizer, and learning rate scheduler are built] datetime: 2024-09-19 15:56:43
> building train, validation, and test datasets ...
> datasets target sizes (minimum size):
train: 500000
validation: 5010
test: 10
> building train, validation, and test datasets for GPT ...
> finished creating GPT datasets ...
[after dataloaders are built] datetime: 2024-09-19 15:56:43
done with setup ...
training ...
(min, max) time across ranks (ms):
model-and-optimizer-setup ......................: (300.04, 309.25)
train/valid/test-data-iterators-setup ..........: (288.56, 334.41)
[before the start of training step] datetime: 2024-09-19 15:56:43
WARNING:megatron.core.utils:NCCL Error 5: invalid usage (run with NCCL_DEBUG=WARN for details)
(The warning prints a traceback serialized as a Python list of strings; its frames match the rank 2 traceback below from forward_step onward and end with the same RuntimeError: NCCL Error 5: invalid usage (run with NCCL_DEBUG=WARN for details).)
[rank2]: Traceback (most recent call last):
[rank2]: File "/Megatron-LM/pretrain_gpt.py", line 264, in <module>
[rank2]: pretrain(
[rank2]: File "/Megatron-LM/megatron/training/training.py", line 355, in pretrain
[rank2]: iteration, num_floating_point_operations_so_far = train(
[rank2]: File "/Megatron-LM/megatron/training/training.py", line 1234, in train
[rank2]: train_step(forward_step_func,
[rank2]: File "/Megatron-LM/megatron/training/training.py", line 718, in train_step
[rank2]: losses_reduced = forward_backward_func(
[rank2]: File "/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 468, in forward_backward_no_pipelining
[rank2]: output_tensor, num_tokens = forward_step(
[rank2]: File "/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 273, in forward_step
[rank2]: output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank2]: File "/Megatron-LM/pretrain_gpt.py", line 192, in forward_step
[rank2]: output_tensor = model(tokens, position_ids, attention_mask,
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 305, in forward
[rank2]: return self.module(*inputs, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/Megatron-LM/megatron/legacy/model/module.py", line 189, in forward
[rank2]: outputs = self.module(*inputs, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 217, in forward
[rank2]: hidden_states = self.decoder(
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/Megatron-LM/megatron/core/transformer/transformer_block.py", line 496, in forward
[rank2]: hidden_states, context = layer(
[rank2]: File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 377, in __call__
[rank2]: return super(MegatronModule, self).__call__(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 281, in forward
[rank2]: attention_output_with_bias = self.self_attention(
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/Megatron-LM/megatron/core/transformer/attention.py", line 291, in forward
[rank2]: core_attn_out = self._checkpointed_attention_forward(
[rank2]: File "/Megatron-LM/megatron/core/transformer/attention.py", line 143, in _checkpointed_attention_forward
[rank2]: hidden_states = tensor_parallel.checkpoint(
[rank2]: File "/Megatron-LM/megatron/core/tensor_parallel/random.py", line 308, in checkpoint
[rank2]: return CheckpointFunction.apply(function, distribute_saved_activations, *args)
[rank2]: File "/pytorch/torch/autograd/function.py", line 575, in apply
[rank2]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank2]: File "/Megatron-LM/megatron/core/tensor_parallel/random.py", line 247, in forward
[rank2]: outputs = run_function(*args)
[rank2]: File "/Megatron-LM/megatron/core/transformer/attention.py", line 130, in custom_forward
[rank2]: output_ = self.core_attention(
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/Megatron-LM/megatron/core/extensions/transformer_engine.py", line 589, in forward
[rank2]: core_attn_out = super().forward(
[rank2]: File "/miniconda3/envs/megatron_lm/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py", line 6854, in forward
[rank2]: return self.flash_attention(
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/miniconda3/envs/megatron_lm/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py", line 4328, in forward
[rank2]: output = attn_forward_func_with_cp(
[rank2]: File "/miniconda3/envs/megatron_lm/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py", line 3379, in attn_forward_func_with_cp
[rank2]: out = AttnFuncWithCPAndKVP2P.apply(
[rank2]: File "/pytorch/torch/autograd/function.py", line 575, in apply
[rank2]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank2]: File "/miniconda3/envs/megatron_lm/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py", line 1473, in forward
[rank2]: send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
[rank2]: File "/miniconda3/envs/megatron_lm/lib/python3.10/site-packages/transformer_engine/pytorch/attention.py", line 1249, in flash_attn_p2p_communicate
[rank2]: send_op = torch.distributed.isend(send_tensor, send_dst, cp_group)
[rank2]: File "/pytorch/torch/distributed/distributed_c10d.py", line 2071, in isend
[rank2]: return pg.send([tensor], dst, tag)
[rank2]: RuntimeError: NCCL Error 5: invalid usage (run with NCCL_DEBUG=WARN for details)
To Reproduce
Run the script above on a single node with 4 A100 80GB GPUs; the error appears only when --context-parallel-size is set to 4 (training works with tensor or pipeline parallelism).
Expected behavior
Training runs without error, as it does with tensor or pipeline parallelism.
Stack trace/logs
See the output log above.
Environment (please complete the following information):
- Megatron-LM commit ID: commit 6b35ca80e8baca6f357b97304979e6b1c9a31899
- PyTorch version: 2.6.0a0+git803ce50
- CUDA version: 12.3.0
- NCCL version: 2.21.5
- TransformerEngine: 1.10.0
- nvidia-smi topo -m:
Proposed fix
None.
Additional context
None.
Marking as stale. No activity in 60 days.
I ran into the same problem.
Setting the environment variable NVTE_BATCH_MHA_P2P_COMM to 1 makes the error go away. See the TransformerEngine code here: https://github.com/NVIDIA/TransformerEngine/blob/303c6d16203b3cb01675f7adb7c21956f140e0ee/transformer_engine/pytorch/attention.py#L1869
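For reference, a minimal sketch of the workaround applied to the launch script above. It assumes exporting the flag before torchrun is enough for every rank to inherit it; per the linked TransformerEngine code, the flag switches the context-parallel KV exchange from individual isend/irecv calls to a batched P2P call.
# Workaround sketch (assumption: exporting before torchrun is sufficient for all ranks).
# NVTE_BATCH_MHA_P2P_COMM=1 tells TransformerEngine to batch the context-parallel
# P2P send/recv pairs instead of issuing separate isend/irecv calls (see the link above).
export NVTE_BATCH_MHA_P2P_COMM=1
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
    ${GPT_MODEL_ARGS[@]} \
    ${TRAINING_ARGS[@]} \
    ${MODEL_PARALLEL_ARGS[@]} \
    ${DATA_ARGS[@]} \
    ${EVAL_AND_LOGGING_ARGS[@]}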
Marking as stale. No activity in 60 days.
This issue was closed because it has been inactive for 7 days since being marked as stale.