fastertransformer_backend

Multi-instance inference fails in (n-1)/n runs (where n is the number of GPUs/instances)

Open timofeev1995 opened this issue 2 years ago • 29 comments

Hello. Thank you for your work and the framework!

My goal is to host N instances of GPT-J-6B on N graphics cards, one model per GPU. My setup is 3x RTX 3090 (another host has 5x RTX 3090, but everything else is the same), a Docker container built according to the repository README, and a FasterTransformer config like this (for the 3-GPU setup):

name: "fastertransformer"
backend: "fastertransformer"
default_model_filename: "gpt-j-6b"
max_batch_size: 8

model_transaction_policy {
  decoupled: False
}

input [
  {
    name: "input_ids"
    data_type: TYPE_UINT32
    dims: [ -1 ]
  },
  {
    name: "start_id"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "end_id"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "input_lengths"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "request_output_len"
    data_type: TYPE_UINT32
    dims: [ -1 ]
  },
  {
    name: "runtime_top_k"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "beam_search_diversity_rate"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "len_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "is_return_log_probs"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "beam_width"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "bad_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
  },
  {
    name: "stop_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
  }
]
output [
  {
    name: "output_ids"
    data_type: TYPE_UINT32
    dims: [ -1, -1 ]
  },
  {
    name: "sequence_length"
    data_type: TYPE_UINT32
    dims: [ -1 ]
  },
  {
    name: "cum_log_probs"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "output_log_probs"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  }
]
instance_group [
  {
    count: 3
    kind: KIND_CPU
  }
]
parameters {
  key: "tensor_para_size"
  value: {
    string_value: "1"
  }
}
parameters {
  key: "pipeline_para_size"
  value: {
    string_value: "1"
  }
}
parameters {
  key: "data_type"
  value: {
    string_value: "fp16"
  }
}
parameters {
  key: "model_type"
  value: {
    string_value: "GPT-J"
  }
}
parameters {
  key: "model_checkpoint_path"
  value: {
    string_value: "/workspace/models/j6b_ckpt/1-gpu"
  }
}
parameters {
  key: "enable_custom_all_reduce"
  value: {
    string_value: "0"
  }
}
dynamic_batching {
  max_queue_delay_microseconds: 300000
}

I start Triton with: CUDA_VISIBLE_DEVICES="0,1,2" /opt/tritonserver/bin/tritonserver --http-port 8282 --log-verbose 10 --model-repository=/workspace/triton-models

The problem: I can see 3 instances running on GPUs 0, 1 and 2, but when I send an inference request, I get a correct generation result in only 1 out of 3 cases:

I1026 11:01:20.566547 93 libfastertransformer.cc:1647] model fastertransformer, instance fastertransformer_0_0, executing 1 requests
I1026 11:01:20.566598 93 libfastertransformer.cc:812] TRITONBACKEND_ModelExecute: Running fastertransformer_0_0 with 1 requests
I1026 11:01:20.566614 93 libfastertransformer.cc:886] get total batch_size = 1
I1026 11:01:20.566632 93 libfastertransformer.cc:1296] get input count = 14
I1026 11:01:20.566660 93 libfastertransformer.cc:1368] collect name: stop_words_list size: 8 bytes
I1026 11:01:20.566683 93 libfastertransformer.cc:1368] collect name: input_lengths size: 4 bytes
I1026 11:01:20.566701 93 libfastertransformer.cc:1368] collect name: request_output_len size: 4 bytes
I1026 11:01:20.566718 93 libfastertransformer.cc:1368] collect name: temperature size: 4 bytes
I1026 11:01:20.566735 93 libfastertransformer.cc:1368] collect name: random_seed size: 4 bytes
I1026 11:01:20.566751 93 libfastertransformer.cc:1368] collect name: bad_words_list size: 8 bytes
I1026 11:01:20.566766 93 libfastertransformer.cc:1368] collect name: runtime_top_k size: 4 bytes
I1026 11:01:20.566782 93 libfastertransformer.cc:1368] collect name: runtime_top_p size: 4 bytes
I1026 11:01:20.566799 93 libfastertransformer.cc:1368] collect name: input_ids size: 2048 bytes
I1026 11:01:20.566815 93 libfastertransformer.cc:1368] collect name: start_id size: 4 bytes
I1026 11:01:20.566831 93 libfastertransformer.cc:1368] collect name: end_id size: 4 bytes
I1026 11:01:20.566845 93 libfastertransformer.cc:1368] collect name: beam_width size: 4 bytes
I1026 11:01:20.566860 93 libfastertransformer.cc:1368] collect name: beam_search_diversity_rate size: 4 bytes
I1026 11:01:20.566876 93 libfastertransformer.cc:1368] collect name: repetition_penalty size: 4 bytes
I1026 11:01:20.566896 93 libfastertransformer.cc:1379] the data is in CPU
I1026 11:01:20.566908 93 libfastertransformer.cc:1386] the data is in CPU
I1026 11:01:20.566945 93 libfastertransformer.cc:1244] before ThreadForward 0
I1026 11:01:20.567101 93 libfastertransformer.cc:1252] after ThreadForward 0
I1026 11:01:20.567169 93 libfastertransformer.cc:1090] Start to forward
I1026 11:01:22.420238 93 libfastertransformer.cc:1098] Stop to forward
I1026 11:01:22.420410 93 libfastertransformer.cc:1411] Get output_tensors 0: output_ids
I1026 11:01:22.420460 93 libfastertransformer.cc:1421]     output_type: UINT32
I1026 11:01:22.420477 93 libfastertransformer.cc:1443]     output shape: [1, 1, 612]
I1026 11:01:22.420506 93 infer_response.cc:166] add response output: output: output_ids, type: UINT32, shape: [1,1,612]
I1026 11:01:22.420539 93 http_server.cc:1068] HTTP: unable to provide 'output_ids' in GPU, will use CPU
I1026 11:01:22.420560 93 http_server.cc:1088] HTTP using buffer for: 'output_ids', size: 2448, addr: 0x7fdf67d28ca0
I1026 11:01:22.420579 93 pinned_memory_manager.cc:161] pinned memory allocation: size 2448, addr 0x7fe3ea000090
I1026 11:01:22.420663 93 libfastertransformer.cc:1411] Get output_tensors 1: sequence_length
I1026 11:01:22.420676 93 libfastertransformer.cc:1421]     output_type: INT32
I1026 11:01:22.420688 93 libfastertransformer.cc:1443]     output shape: [1, 1]
I1026 11:01:22.420703 93 infer_response.cc:166] add response output: output: sequence_length, type: INT32, shape: [1,1]
I1026 11:01:22.420718 93 http_server.cc:1068] HTTP: unable to provide 'sequence_length' in GPU, will use CPU
I1026 11:01:22.420732 93 http_server.cc:1088] HTTP using buffer for: 'sequence_length', size: 4, addr: 0x7fd949e73600
I1026 11:01:22.420745 93 pinned_memory_manager.cc:161] pinned memory allocation: size 4, addr 0x7fe3ea000a30
I1026 11:01:22.420785 93 libfastertransformer.cc:1458] PERFORMED GPU copy: NO
I1026 11:01:22.420797 93 pinned_memory_manager.cc:190] pinned memory deallocation: addr 0x7fe3ea000090
I1026 11:01:22.420810 93 pinned_memory_manager.cc:190] pinned memory deallocation: addr 0x7fe3ea000a30
I1026 11:01:22.420829 93 libfastertransformer.cc:1001] get response size = 1
I1026 11:01:22.420902 93 http_server.cc:1140] HTTP release: size 2448, addr 0x7fdf67d28ca0
I1026 11:01:22.420917 93 http_server.cc:1140] HTTP release: size 4, addr 0x7fd949e73600
I1026 11:01:22.420930 93 libfastertransformer.cc:1016] response is sent

But in the other 2 out of 3 cases (requests routed not to fastertransformer_0_0 but to fastertransformer_0_1 / fastertransformer_0_2), I get the following:

I1026 11:01:40.172315 93 libfastertransformer.cc:1647] model fastertransformer, instance fastertransformer_0_2, executing 1 requests
I1026 11:01:40.172367 93 libfastertransformer.cc:812] TRITONBACKEND_ModelExecute: Running fastertransformer_0_2 with 1 requests
I1026 11:01:40.172393 93 libfastertransformer.cc:886] get total batch_size = 1
I1026 11:01:40.172417 93 libfastertransformer.cc:1296] get input count = 14
I1026 11:01:40.172446 93 libfastertransformer.cc:1368] collect name: stop_words_list size: 8 bytes
I1026 11:01:40.172473 93 libfastertransformer.cc:1368] collect name: input_lengths size: 4 bytes
I1026 11:01:40.172494 93 libfastertransformer.cc:1368] collect name: request_output_len size: 4 bytes
I1026 11:01:40.172515 93 libfastertransformer.cc:1368] collect name: temperature size: 4 bytes
I1026 11:01:40.172538 93 libfastertransformer.cc:1368] collect name: random_seed size: 4 bytes
I1026 11:01:40.172554 93 libfastertransformer.cc:1368] collect name: bad_words_list size: 8 bytes
I1026 11:01:40.172569 93 libfastertransformer.cc:1368] collect name: runtime_top_k size: 4 bytes
I1026 11:01:40.172584 93 libfastertransformer.cc:1368] collect name: runtime_top_p size: 4 bytes
I1026 11:01:40.172607 93 libfastertransformer.cc:1368] collect name: input_ids size: 2048 bytes
I1026 11:01:40.172624 93 libfastertransformer.cc:1368] collect name: start_id size: 4 bytes
I1026 11:01:40.172639 93 libfastertransformer.cc:1368] collect name: end_id size: 4 bytes
I1026 11:01:40.172659 93 libfastertransformer.cc:1368] collect name: beam_width size: 4 bytes
I1026 11:01:40.172681 93 libfastertransformer.cc:1368] collect name: beam_search_diversity_rate size: 4 bytes
I1026 11:01:40.172703 93 libfastertransformer.cc:1368] collect name: repetition_penalty size: 4 bytes
I1026 11:01:40.172726 93 libfastertransformer.cc:1379] the data is in CPU
I1026 11:01:40.172743 93 libfastertransformer.cc:1386] the data is in CPU
I1026 11:01:40.172785 93 libfastertransformer.cc:1244] before ThreadForward 2
I1026 11:01:40.172948 93 libfastertransformer.cc:1252] after ThreadForward 2
I1026 11:01:40.173059 93 libfastertransformer.cc:1090] Start to forward
I1026 11:01:42.007168 93 libfastertransformer.cc:1098] Stop to forward
I1026 11:01:42.007373 93 libfastertransformer.cc:1411] Get output_tensors 0: output_ids
I1026 11:01:42.007431 93 libfastertransformer.cc:1421]     output_type: UINT32
I1026 11:01:42.007449 93 libfastertransformer.cc:1443]     output shape: [1, 1, 612]
I1026 11:01:42.007468 93 infer_response.cc:166] add response output: output: output_ids, type: UINT32, shape: [1,1,612]
I1026 11:01:42.007500 93 http_server.cc:1068] HTTP: unable to provide 'output_ids' in GPU, will use CPU
I1026 11:01:42.007523 93 http_server.cc:1088] HTTP using buffer for: 'output_ids', size: 2448, addr: 0x7fe2d40a94b0
I1026 11:01:42.007542 93 pinned_memory_manager.cc:161] pinned memory allocation: size 2448, addr 0x7fe3ea000090
I1026 11:01:42.007616 93 http_server.cc:1140] HTTP release: size 2448, addr 0x7fe2d40a94b0
I1026 11:01:42.007637 93 libfastertransformer.cc:1411] Get output_tensors 1: sequence_length
I1026 11:01:42.007648 93 libfastertransformer.cc:1421]     output_type: INT32
I1026 11:01:42.007661 93 libfastertransformer.cc:1443]     output shape: [1, 1]
I1026 11:01:42.007673 93 libfastertransformer.cc:1458] PERFORMED GPU copy: NO
I1026 11:01:42.007686 93 pinned_memory_manager.cc:190] pinned memory deallocation: addr 0x7fe3ea000090
I1026 11:01:42.007701 93 libfastertransformer.cc:1001] get response size = 1
W1026 11:01:42.007713 93 libfastertransformer.cc:1019] response is nullptr

On the client side, the Python Triton HTTP client raises: tritonclient.utils.InferenceServerException: pinned buffer: failed to perform CUDA copy: invalid argument

Is there an error in my setup or pipeline? What is the correct way to set up N models on N GPUs?

Thank you in advance!

timofeev1995 • Oct 26 '22

@timofeev1995 I cannot reproduce this on 4x A40 machines. Can you build the latest ft_triton image and try it on other machines (like A100 and A40)? It would also be great if you could share the config.ini that is generated when you convert the model.

PerkzZheng • Oct 27 '22

@PerkzZheng, thank you for the reply!

I have the following config.ini:

model_name = gptj-6B
head_num = 16
size_per_head = 256
inter_size = 16384
num_layer = 28
rotary_embedding = 64
vocab_size = 50400
start_id = 50256
end_id = 50256
weight_data_type = fp32
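As a back-of-the-envelope check of the one-model-per-GPU plan, the config.ini values above can be turned into a rough weight-memory estimate. This is a sketch that counts only the large weight matrices (it ignores biases, layer norms, and runtime buffers such as the KV cache), not FasterTransformer's own memory accounting; the struct and function names are illustrative.

```cpp
#include <cstdint>

// Values taken from the config.ini shown above.
struct GptjConfig {
    std::int64_t head_num = 16;
    std::int64_t size_per_head = 256;
    std::int64_t inter_size = 16384;
    std::int64_t num_layer = 28;
    std::int64_t vocab_size = 50400;
};

// Rough weight count: per layer, the Q/K/V/output attention projections
// (4 * hidden^2) plus the two MLP matrices (2 * hidden * inter_size); on top
// of that, the input embedding and the LM head (vocab_size * hidden each).
// Biases and layer-norm parameters are ignored, so this slightly undercounts.
std::int64_t approxParams(const GptjConfig& c) {
    const std::int64_t hidden = c.head_num * c.size_per_head;
    const std::int64_t attention = 4 * hidden * hidden;
    const std::int64_t mlp = 2 * hidden * c.inter_size;
    const std::int64_t embeddings = 2 * c.vocab_size * hidden;
    return c.num_layer * (attention + mlp) + embeddings;
}

// Bytes needed to hold those weights in fp16 (2 bytes each), excluding
// activations, KV cache, and workspace buffers.
std::int64_t approxFp16Bytes(const GptjConfig& c) {
    return approxParams(c) * 2;
}
```

With these values the estimate is about 6.05e9 weights, or roughly 12.1 GB in fp16. That leaves headroom for runtime buffers on a 24 GB RTX 3090, so one instance per card is plausible.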

Can you try to build the latest ft_triton image and try it on other machines (like A100, and A40)?

Do you mean there could be an issue with the connection between GPUs? I have the following topology:

nvidia-smi topo -m
	GPU0	GPU1	GPU2	GPU3	GPU4	CPU Affinity	NUMA Affinity
GPU0	 X 	PIX	NODE	NODE	NODE	0-15,32-47	0
GPU1	PIX	 X 	NODE	NODE	NODE	0-15,32-47	0
GPU2	NODE	NODE	 X 	PIX	PIX	0-15,32-47	0
GPU3	NODE	NODE	PIX	 X 	PIX	0-15,32-47	0
GPU4	NODE	NODE	PIX	PIX	 X 	0-15,32-47	0

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Thank you!

timofeev1995 • Oct 27 '22

Do you mean there can be some issue with connection between gpus?

No, there is no communication between GPUs. I am just trying to understand if this happens only in certain circumstances. Also, please add --ipc=host if you are using docker containers.

PerkzZheng • Oct 27 '22

Unfortunately, with --ipc=host passed to docker run, the result is the same (nullptr for every instance except the main fastertransformer_0_0).

timofeev1995 • Oct 27 '22

Okay. Have you tried building the image with the latest Triton? I will also try the GPT-J config.ini you shared.

PerkzZheng • Oct 27 '22

I'm going to take the latest Triton server tag and try it on a GCP 2x A100 instance. Thank you!

timofeev1995 • Oct 27 '22

Got it. Everything works as expected there. Do you have any idea what the problem with my initial setup could be? Thank you!

timofeev1995 • Oct 27 '22

I am not quite sure. Have you tried to run the inference with the latest image on 3 RTX3090s?

PerkzZheng • Oct 28 '22

I tried the latest tritonserver container (22.09) on the 3090s and get the same issue.

timofeev1995 • Oct 31 '22

Here is what you could try to help us eliminate some potential factors:

  1. launch model instances on GPU1 and GPU2 only, by setting CUDA_VISIBLE_DEVICES
  2. build in debug mode

PerkzZheng • Nov 01 '22

I'm facing the same issue on 2xT4

lakshaykc • Feb 21 '23

@lakshaykc Hi, I have the same issue with two T4s. Have you solved this problem?

DunZhang • Feb 28 '23

No, I haven't been able to fix that issue yet.

lakshaykc • Feb 28 '23

@DunZhang @lakshaykc does it work if two model instances are on the same GPU?

PerkzZheng • Feb 28 '23

Yes, that works. Even though there is no communication between the two model instances or GPUs, something related to cross-GPU interaction seems to be involved, but I have no idea what it could be. Running two separate instances on GPU1 and GPU2, as @PerkzZheng suggested above, also works, but that is obviously not a clean solution.

lakshaykc • Feb 28 '23

I cannot reproduce this on 2x T4 machines, so maybe you can try adding the following code at libfastertransformer.cc#L1793:

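    // Ask the CUDA runtime which memory space output_buffer belongs to.
    cudaPointerAttributes attributes{};  // zero-initialized in case the query fails
    cudaError_t status = cudaPointerGetAttributes(&attributes, output_buffer);
    LOG_MESSAGE(
        TRITONSERVER_LOG_VERBOSE,
        (std::string("    Memory_pointer_type: ") +
         std::to_string(attributes.type) + " (query status: " +
         cudaGetErrorString(status) + ")")
            .c_str());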

and check whether the pointers are device pointers (type 2, i.e. cudaMemoryTypeDevice).
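To interpret the integer that snippet logs, the CUDA runtime's cudaMemoryType enum uses 0 for unregistered host memory, 1 for host (pinned) memory, 2 for device memory, and 3 for managed memory. Below is a small helper that decodes those values without including CUDA headers. It assumes these enum values, which match recent CUDA runtime headers.

```cpp
#include <string>

// Mirrors the numeric values of the CUDA runtime's cudaMemoryType enum so the
// integer printed by the debug log can be decoded without CUDA headers.
std::string memoryTypeName(int type) {
    switch (type) {
        case 0: return "cudaMemoryTypeUnregistered (host memory not registered with CUDA)";
        case 1: return "cudaMemoryTypeHost (pinned host memory)";
        case 2: return "cudaMemoryTypeDevice (device memory)";
        case 3: return "cudaMemoryTypeManaged (managed memory)";
        default: return "unknown";
    }
}

// The check suggested above: is the logged pointer a device pointer?
bool isDevicePointer(int type) { return type == 2; }
```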

PerkzZheng • Feb 28 '23

Also, add cudaDeviceSynchronize() at that point to see if anything changes. It looks like there may be a race condition where the pointers haven't been allocated yet when they are passed to cudaMemcpyAsync.

PerkzZheng • Feb 28 '23

@lakshaykc @DunZhang try this branch and see if the errors persist. Let me know if you find anything helpful.

PerkzZheng • Feb 28 '23

Ok, thanks, I'll try it and let you know what happens.

lakshaykc • Feb 28 '23

Hi @PerkzZheng! I'm also facing the same issue, and my setup is similar to @timofeev1995's, with several 3090s. My topology is:

GPU0	GPU1	GPU2	GPU3	GPU4	CPU Affinity	NUMA Affinity
GPU0	 X 	PIX	PIX	NODE	NODE	0-15,32-47	0
GPU1	PIX	 X 	PIX	NODE	NODE	0-15,32-47	0
GPU2	PIX	PIX	 X 	NODE	NODE	0-15,32-47	0
GPU3	NODE	NODE	NODE	 X 	PIX	0-15,32-47	0
GPU4	NODE	NODE	NODE	PIX	 X 	0-15,32-47	0

I've been debugging this behaviour for some time, and in my scenario the problem is in this cudaMemcpyAsync. While in debug mode, I found that switching the GPU with cudaSetDevice fixed the problem, so I implemented a crutch that cycles through the GPUs until it manages to copy the data. I hope this helps you understand the bug better.

I'll also check your bugfix asap and come back with the results.

hawkeoni • Mar 01 '23

I've tested your fix, and unfortunately it did not resolve the problem for me. I double-checked that the code in the Docker image does contain your fix, but I still get the same errors. I hope my previous comment helps.

For reference, the error I get is: pinned buffer: failed to perform CUDA copy: invalid argument

hawkeoni • Mar 01 '23

Thanks, @hawkeoni. Do you mean calling cudaSetDevice for each model instance? As far as I know, cudaMemcpyAsync doesn't need a cudaSetDevice before it, since the stream carries the device ID. Anyway, thanks. I will push a quick fix that you can validate.

PerkzZheng • Mar 02 '23

We do have that for each model instance. Do you mean doing this before responder.processTensor?

PerkzZheng • Mar 02 '23

It might be that the model instance constructor and the forwarder run on different threads (and therefore with different current-device contexts).
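This hypothesis depends on the fact that the CUDA runtime's current device (what cudaSetDevice changes) is tracked per host thread. The sketch below models that state with a thread_local variable instead of real CUDA calls. fakeSetDevice, fakeGetDevice, and deviceSeenByOtherThread are hypothetical stand-ins. It shows that a device selected on the constructor's thread is not visible on a different worker thread unless that thread selects it again. It does not reproduce the bug itself.

```cpp
#include <thread>

// Hypothetical stand-in for the CUDA runtime's per-host-thread
// "current device"; each thread starts with device 0.
thread_local int g_current_device = 0;

void fakeSetDevice(int device) { g_current_device = device; }
int fakeGetDevice() { return g_current_device; }

// Models an instance whose constructor selects its device on one thread
// while the forward/response path runs on another. Returns the device the
// worker thread sees, optionally after re-selecting it itself.
int deviceSeenByOtherThread(int instance_device, bool set_in_worker) {
    fakeSetDevice(instance_device);  // "constructor" thread
    int seen = -1;
    std::thread worker([&] {
        if (set_in_worker) {
            fakeSetDevice(instance_device);  // re-select on this thread
        }
        seen = fakeGetDevice();  // what the "copy" thread would use
    });
    worker.join();
    return seen;
}
```

Under this model, deviceSeenByOtherThread(2, false) returns 0: the worker still sees the default device. With set_in_worker = true, it returns 2. That is why calling cudaSetDevice on the thread that issues the copy is a natural thing to test.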

PerkzZheng • Mar 02 '23

Hi @PerkzZheng, I apologize for not answering sooner. I tested your fix again on setups with 2 to 5 GPUs, and it didn't work (I had my doubts, because I had originally tried it on 5 GPUs).

Thanks. @hawkeoni You mean doing cudaSetDevice for each model instance ? as far as I know, cudaMemcpyAsync doesn't need a setDevice before it as the stream has the device id info. Anyway, thanks. And I will push a quick fix that you guys can validate it.

I also believe that CUDA copy commands do not require setting the device. Moreover, I've written a simple test script on my machine that just copies from different GPUs to the CPU, and it works regardless of which device is active. So I'm not entirely sure why my fix works, but it does.

we do have that for each model instance.. Do you mean doing this before responder.processTensor ?

I've been debugging this problem and found that the copy fails even at that point.

Here is the call stack as best I can reconstruct it:

  1. ProcessRequest
  2. ReadOutputTensors
  3. ProcessTensor
  4. FlushPendingPinned
  5. CopyBuffer (Pinned)
  6. cudaMemcpyAsync

This cudaMemcpyAsync fails with cudaErrorInvalidValue every time it runs for any GPU other than the first. I've made a small fix that loops over all GPUs and retries the memcpy until it succeeds or runs out of GPUs, and it works. As you said before, and as I've checked on my setup, CUDA copies work correctly without setting a device, and there is no device-to-device communication, so I'm at a loss for now. Hope this helps.
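Here is a minimal sketch of the retry-over-devices workaround described above. It is written generically so it compiles without CUDA. setDevice and tryCopy are hypothetical callbacks: in the backend they would wrap cudaSetDevice and the failing cudaMemcpyAsync call in the pinned-buffer copy path. This mirrors the described crutch and is not an upstream fix.

```cpp
#include <functional>

// Tries an operation with each device made current in turn and stops at
// the first success. Returns the index of the device that worked, or -1
// if every attempt failed.
int retryOnEachDevice(int device_count,
                      const std::function<void(int)>& setDevice,
                      const std::function<bool()>& tryCopy) {
    for (int device = 0; device < device_count; ++device) {
        setDevice(device);  // e.g. wraps cudaSetDevice(device)
        if (tryCopy()) {    // e.g. wraps cudaMemcpyAsync(...) == cudaSuccess
            return device;
        }
    }
    return -1;
}
```

A cleaner version would restore the previously current device afterwards. With real CUDA calls, it would also need to clear the error left by failed attempts (for example with cudaGetLastError()) so that later error checks do not pick it up.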

hawkeoni • Mar 03 '23

Can you try this branch fix/multi_instance?

byshiue • Mar 31 '23

I switched to the fix/multi_instance branch, and the problem appears to be solved (my setup is 3x RTX 3090). Note: you should tune the dynamic batching queue delay (max_queue_delay_microseconds) to get full GPU utilization.

huyphan168 • May 01 '23

@byshiue Hi! I'm sorry it took me so long to answer, but I finally got to check your fix, and it works on my setup (RTX 3090)! Is it possible to merge it into the main branch?

hawkeoni • May 06 '23

@byshiue @PerkzZheng Hi! The fix from this comment works, are there any plans to merge it?

hawkeoni • Jul 10 '23