
The system hangs when I use multiple GPUs to perform streaming inference

Open · pansicheng opened this issue 1 year ago • 5 comments

System Info

  • GPU properties: 8 * NVIDIA GeForce RTX 4090
  • TensorRT-LLM branch: v0.7.1
  • NVIDIA Driver Version: 535.154.05
  • CUDA Version: 12.2
  • Container used: built from tensorrtllm_backend (https://github.com/triton-inference-server/tensorrtllm_backend/tree/v0.7.1?tab=readme-ov-file#option-3-build-via-docker)

Who can help?

@byshiue

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

  1. Build the TensorRT engine:

    cd examples/qwen

    python build.py \
      --hf_model_dir /path/to/qwen-72b-chat/ \
      --output_dir /path/to/Qwen/72B/trt_engines/int4_weight_only/8-gpu/ \
      --remove_input_padding \
      --use_gpt_attention_plugin float16 \
      --enable_context_fmha \
      --use_weight_only --weight_only_precision int4 \
      --paged_kv_cache \
      --use_inflight_batching \
      --world_size 8 \
      --tp_size 8

  2. Run streaming inference:

    cd examples/qwen

    mpirun -n 8 --allow-run-as-root \
      python3 ../run.py \
      --input_text "hello" \
      --max_output_len=2048 \
      --tokenizer_dir /path/to/qwen-72b-chat/ \
      --engine_dir=/path/to/Qwen/72B/trt_engines/int4_weight_only/8-gpu/ \
      --streaming --streaming_interval 1

Expected behavior

The system should perform streaming inference and print the output as it is generated.

Actual behavior

The system hangs

Additional notes

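I found that in streaming mode only the rank 0 process iterates over the output generator (https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/examples/run.py#L355). The other ranks skip the loop, so they never run the decoding steps, and rank 0 presumably blocks in the first collective operation (such as the tensor-parallel all-reduce) that requires all 8 ranks to participate: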

    if runtime_rank == 0:
        if args.streaming:
            for curr_outputs in throttle_generator(outputs,
                                                   args.streaming_interval):
                output_ids = curr_outputs['output_ids']
                sequence_lengths = curr_outputs['sequence_lengths']
                print_output(...)
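A minimal simulation of the suspected mechanism, assuming each step of the streaming generator runs a collective that every rank must join. Threads stand in for the MPI ranks and a `threading.Barrier` stands in for the all-reduce; `decode_steps` and `worker` are hypothetical names, not TensorRT-LLM APIs. The barrier timeout turns the real indefinite hang into a `BrokenBarrierError` so the demo terminates:

```python
import threading

WORLD_SIZE = 4  # stands in for tp_size / mpirun -n
STEPS = 3       # number of decoding steps the generator would run

# Stand-in for a tensor-parallel collective: no rank can pass until all
# WORLD_SIZE ranks have called wait(). The timeout turns the real
# indefinite hang into BrokenBarrierError so the demo terminates.
all_reduce = threading.Barrier(WORLD_SIZE, timeout=0.5)


def decode_steps():
    """Hypothetical streaming generator: every step joins one collective."""
    for step in range(STEPS):
        all_reduce.wait()
        yield step


def worker(rank, results):
    outputs = decode_steps()  # creating a generator runs no steps yet
    if rank == 0:  # buggy pattern: only rank 0 consumes the stream
        try:
            for _ in outputs:
                pass
            results[rank] = "done"
        except threading.BrokenBarrierError:
            results[rank] = "hung"
    else:
        results[rank] = "skipped"  # never steps the generator


results = {}
threads = [threading.Thread(target=worker, args=(r, results))
           for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(results.items()))
# [(0, 'hung'), (1, 'skipped'), (2, 'skipped'), (3, 'skipped')]
```

Rank 0 is the only participant in the collective, so it waits until the timeout fires; without a timeout it would wait forever, which matches the observed hang.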

Having all processes iterate over the generator seems to solve the problem:

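    if args.streaming:
        if runtime_rank == 0:
            # Rank 0 consumes the stream and prints the partial outputs.
            for curr_outputs in throttle_generator(outputs,
                                                   args.streaming_interval):
                output_ids = curr_outputs['output_ids']
                sequence_lengths = curr_outputs['sequence_lengths']
                print_output(...)
                torch.cuda.synchronize()
        else:
            # The other ranks must also step the generator so that they take
            # part in every decoding step; their outputs are discarded.
            for curr_outputs in throttle_generator(outputs,
                                                   args.streaming_interval):
                pass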
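The same simulation with the workaround applied, as a sketch under the same assumptions (threads as ranks, a `threading.Barrier` as the collective): every rank steps the generator, and only rank 0 uses the output. `throttle` is a hypothetical re-implementation of the idea behind `throttle_generator` (yield every `interval`-th step plus the last one), not its actual source:

```python
import threading

WORLD_SIZE = 4
STEPS = 5
STREAMING_INTERVAL = 2

# Generous timeout: here it only guards against a real deadlock.
all_reduce = threading.Barrier(WORLD_SIZE, timeout=5.0)


def decode_steps():
    """Hypothetical streaming generator: every step joins one collective."""
    for step in range(STEPS):
        all_reduce.wait()
        yield step


def throttle(generator, interval):
    """Hypothetical stand-in for throttle_generator: yields every
    `interval`-th item plus the last one, but consumes every item."""
    i, out = -1, None
    for i, out in enumerate(generator):
        if i % interval == 0:
            yield out
    if i >= 0 and i % interval != 0:
        yield out  # flush the final step if it was skipped above


def worker(rank, printed):
    outputs = decode_steps()
    if rank == 0:
        for out in throttle(outputs, STREAMING_INTERVAL):
            printed.append(out)  # stands in for print_output(...)
    else:
        # Workaround: non-zero ranks drain the stream too, so every rank
        # joins every collective and no rank is left waiting.
        for _ in throttle(outputs, STREAMING_INTERVAL):
            pass


printed = []
threads = [threading.Thread(target=worker, args=(r, printed))
           for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(printed)  # [0, 2, 4]
```

Throttling only changes which steps are yielded, not how many steps the underlying generator runs. Assuming `throttle_generator` likewise consumes every underlying step, draining through it on the non-zero ranks keeps them in lockstep just as iterating `outputs` directly would.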

pansicheng, Feb 05 '24 04:02

+1

noahnisbet, Feb 14 '24 14:02

I encountered the same issue on the 4090; the error message is as follows: [screenshot of the error message]. Have you resolved it?

PaulX1029, Mar 30 '24 10:03


Take a look at the pseudocode in the "Additional notes" section of the issue description. I modified examples/run.py (line 355 in v0.7.1) so that processes with a non-zero runtime_rank also iterate over the streaming generator, which solved the problem for me. Hope it helps.

pansicheng, Apr 02 '24 07:04

I encountered the same issue in non-streaming mode (Qwen-72B, 4 GPUs).

qfyinbd, May 24 '24 06:05

Could you try the latest main branch and disable use_custom_all_reduce when building the engine?

byshiue, May 27 '24 09:05