Optimize kernel launch for triton 2.2.0 and triton 2.3.0

Open grimoire opened this issue 1 year ago • 12 comments

  • triton 2.1.0 has the best performance.
  • parsing the signature (in 2.2.0 and 2.3.0) costs a lot.
  • 2.3.0 does not accept device and stream (see the sketch below this list).
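
A minimal sketch of what the last point can mean at a call site (the kernel, the helper and the version gate here are illustrative, not lmdeploy's actual code; which launch kwargs are accepted should be checked against the installed triton):

import torch
import triton
import triton.language as tl
from packaging import version

TRITON_VERSION = version.parse(triton.__version__)

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def launch_add(x, y, out):
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    if TRITON_VERSION >= version.parse("2.3.0"):
        # 2.3.0 dropped the device/stream launch kwargs and resolves the current stream itself
        add_kernel[grid](x, y, out, n_elements, BLOCK=1024)
    else:
        # earlier releases accept an explicit stream at launch time
        stream = torch.cuda.current_stream().cuda_stream
        add_kernel[grid](x, y, out, n_elements, BLOCK=1024, stream=stream)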

grimoire avatar Apr 25 '24 12:04 grimoire

"parsing the signature (in 2.2.0 and 2.3.0) costs a lot."

Hi @grimoire, does this mean the performance of the pytorch backend with triton 2.2.0 is much lower than with triton 2.1.0?

jjjjohnson avatar Apr 29 '24 03:04 jjjjohnson

@jjjjohnson Yes. The pytorch engine awaits until the kernel launches have finished, so the CPU can be used for other work (in most cases, the tokenizer). triton 2.2.0+ parses the kernel signature from a dict, which greatly increases kernel launch time.
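
Roughly, the pattern is the one sketched below (a simplification with made-up names, not the engine's actual code): the blocking forward pass is awaited through an executor, so the event loop can spend the wait on host-side work. The more CPU the kernel launches themselves burn, the less is left for that other work.

import asyncio

async def decode_step(model, inputs, tokenizer, pending_texts):
    loop = asyncio.get_running_loop()
    # push the blocking forward pass (which launches the triton kernels) to a
    # worker thread and await it; while it runs, the event loop is free for
    # host-side work such as tokenization
    forward = loop.run_in_executor(None, model.forward, inputs)
    tokenized = loop.run_in_executor(None, tokenizer, pending_texts)
    logits, new_texts = await asyncio.gather(forward, tokenized)
    return logits, new_texts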

grimoire avatar Apr 29 '24 03:04 grimoire

"The pytorch engine awaits until the kernel launches have finished, so the CPU can be used for other work (in most cases, the tokenizer). triton 2.2.0+ parses the kernel signature from a dict, which greatly increases kernel launch time."

So with triton 2.2.0 and the @wrap_jit_func decorator, is the performance the same as with triton 2.1.0? @grimoire

jjjjohnson avatar Apr 29 '24 03:04 jjjjohnson

@jjjjohnson I am afraid not. The performance is better than plain 2.2.0 but still worse than 2.1.0. There is still work to do.

grimoire avatar Apr 29 '24 03:04 grimoire

Hi @grimoire, I cannot find triton 2.2.0 or triton 2.3.0 in the triton GitHub repo. Where can I find them?

jjjjohnson avatar Apr 30 '24 06:04 jjjjohnson

They have release/2.x.x branches.

grimoire avatar Apr 30 '24 07:04 grimoire

Hi @grimoire, PyTorch and Triton both support AMD. If I want to run the LMDeploy PyTorch Engine on an AMD GPU, what extra work do I need to do? Thanks.

zhyncs avatar Apr 30 '24 07:04 zhyncs

@zhyncs There are device casts and stream synchronizations in engine/engine.py, engine/model_agent.py and models/patch.py. Other modules should be device agnostic.
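
As a rough illustration of the "device agnostic" part (a sketch with hypothetical names, not lmdeploy's actual code), the idea is to take the device from configuration rather than hard-coding cuda calls:

import torch

def cast_inputs(batch: dict, device: torch.device) -> dict:
    # move tensors to whichever device the engine was configured with,
    # instead of calling .cuda() directly
    return {name: t.to(device, non_blocking=True) for name, t in batch.items()}

# ROCm builds of PyTorch expose HIP devices through the same "cuda" namespace,
# so a configured torch.device usually carries over unchanged
device = torch.device("cuda", 0)
inputs = cast_inputs({"input_ids": torch.zeros(1, 8, dtype=torch.long)}, device)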

grimoire avatar Apr 30 '24 08:04 grimoire

Hi @grimoire, I am confused about how @wrap_jit_func reduces the kernel launch time. The code in triton_utils.py is a bit hard to understand...

jjjjohnson avatar May 06 '24 03:05 jjjjohnson

"how @wrap_jit_func reduces the kernel launch time"

triton gives a key to each compiled kernel; the key is composed of the data types, value alignment, const values and other meta info of the arguments. That means that at every kernel launch you need to collect all the inputs, check their attributes (dtype, divisibility by 8, etc.) and compute the key.
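
A toy version of that key computation (illustrative only; the real key in triton covers more than this):

import torch

def specialization_key(*tensors, **scalars):
    # at every launch: look at each argument, record its dtype and
    # alignment/divisibility (divisibility by 8, as mentioned above),
    # and build the cache key from them
    dtype_part = tuple(str(t.dtype) for t in tensors)
    align_part = tuple(t.data_ptr() % 8 == 0 for t in tensors)
    value_part = tuple(isinstance(v, int) and v % 8 == 0 for v in scalars.values())
    return dtype_part + align_part + value_part

x = torch.empty(1024)
key = specialization_key(x, n_elements=1024)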

Before triton 2.2.0, the launch function was generated as a string and executed with exec (https://github.com/openai/triton/blob/da40a1e984bf57c4708daf603eb427442025f99b/python/triton/runtime/jit.py#L400). Argument parsing does not cost much for a pre-compiled function with fixed arguments.

Since triton 2.2.0, the arguments are passed as a dict and parsed with signature.bind (https://github.com/openai/triton/blob/0e7b97bd47fc4beb21ae960a516cd9a7ae9bc060/python/triton/runtime/jit.py#L426). This is more elegant than exec, but parsing the arguments costs more CPU than calling a function with a fixed argument list. As the bottleneck of our pytorch engine is on the host side, the pre-2.2.0 solution is better for us.
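
A tiny self-contained comparison of the two styles (nothing triton-specific; it only shows that signature.bind on a kwargs dict does noticeably more Python work than a plain call with fixed positional arguments):

import inspect
import timeit

def launcher(x_ptr, y_ptr, out_ptr, n_elements):
    return n_elements  # stand-in for the real launch body

sig = inspect.signature(launcher)
kwargs = {"x_ptr": 1, "y_ptr": 2, "out_ptr": 3, "n_elements": 4096}

t_fixed = timeit.timeit(lambda: launcher(1, 2, 3, 4096), number=100_000)
t_bind = timeit.timeit(lambda: sig.bind(**kwargs), number=100_000)
print(f"fixed positional call: {t_fixed:.3f}s, signature.bind: {t_bind:.3f}s")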

@wrap_jit_func generates the launch function the way triton 2.1.0 did and bypasses the original launcher; a heavily simplified sketch of the idea is below.
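
This sketch is not the code in triton_utils.py, just the general shape of the trick: read the kernel's parameter names once, generate a launcher with a fixed positional signature via exec, and dispatch through it instead of the dict-parsing path.

import inspect

def make_fixed_arg_launcher(py_fn, compiled_kernels):
    # py_fn is the python function that was @triton.jit-ed; compiled_kernels is
    # a {specialization_key: compiled_kernel} cache (both are stand-ins here)
    params = list(inspect.signature(py_fn).parameters)
    arg_list = ", ".join(params)
    key_expr = ", ".join(f"getattr({p}, 'dtype', type({p}))" for p in params)
    src = (
        f"def _launcher(grid, {arg_list}):\n"
        f"    key = ({key_expr},)\n"
        f"    kernel = compiled_kernels[key]\n"
        f"    return kernel(grid, {arg_list})\n"
    )
    namespace = {"compiled_kernels": compiled_kernels}
    exec(src, namespace)  # build the fixed-signature launcher once, as triton 2.1.0 did
    return namespace["_launcher"]

# hypothetical usage: launcher = make_fixed_arg_launcher(add_kernel.fn, cache)
#                     launcher(grid, x, y, out, n_elements)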

grimoire avatar May 06 '24 04:05 grimoire

"signature.bind"

@grimoire Thanks for your great explanation! Looks like this issue is very relevant. BTW, why does parsing arguments require so much CPU computation? My understanding is that parsing arguments is a very CPU-light operation.

jjjjohnson avatar May 06 '24 06:05 jjjjohnson

@jjjjohnson Yes, it is a CPU-light operation. But since a transformer is a stack of decode layers and each layer launches multiple custom kernels (attention/rmsnorm/rope/...), the expense is hard to ignore.
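
For a sense of scale, with purely illustrative numbers (not measurements):

# illustrative numbers only: a 32-layer decoder, ~6 triton kernels per layer,
# and ~20 extra microseconds of host-side argument parsing per launch
layers = 32
kernels_per_layer = 6
extra_us_per_launch = 20
extra_ms_per_token = layers * kernels_per_layer * extra_us_per_launch / 1000
print(extra_ms_per_token)  # ~3.8 ms of extra CPU time per generated token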

grimoire avatar May 06 '24 11:05 grimoire

Test results on llama3:

python benchmark/profile_throughput.py ./ShareGPT_V3_unfiltered_cleaned_split.json ./Meta-Llama-3-8B-Instruct --backend pytorch

triton==2.3.1

2024-06-18 13:23:10,680 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(block_size=64, num_cpu_blocks=512, num_gpu_blocks=6100, window_size=-1, cache_max_entry_count=0.8, max_prefill_token_num=4096, enable_prefix_caching=False)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [05:41<00:00,  2.19s/it]
--------------------------------------------------
concurrency: 256
elapsed_time: 341.836s

first token latency(s)(min, max, ave): 1.838, 19.980, 3.325
per-token latency(s) percentile(50, 75, 95, 99): [0.046, 0.072, 0.295, 0.337]

number of prompt tokens: 1138200
number of completion tokens: 1006232
token throughput (completion token): 2943.609 token/s
token throughput (prompt + completion token): 6273.275 token/s
RPS (request per second): 14.627 req/s
RPM (request per minute): 877.613 req/min

triton==2.2.0

2024-06-18 14:01:40,598 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(block_size=64, num_cpu_blocks=512, num_gpu_blocks=6166, window_size=-1, cache_max_entry_count=0.8, max_prefill_token_num=4096, enable_prefix_caching=False)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [05:13<00:00,  1.24it/s]
--------------------------------------------------
concurrency: 256
elapsed_time: 313.438s

first token latency(s)(min, max, ave): 1.753, 16.526, 2.997
per-token latency(s) percentile(50, 75, 95, 99): [0.042, 0.043, 0.279, 0.331]

number of prompt tokens: 1136185
number of completion tokens: 1003966
token throughput (completion token): 3203.080 token/s
token throughput (prompt + completion token): 6827.995 token/s
RPS (request per second): 15.952 req/s
RPM (request per minute): 957.128 req/min
--------------------------------------------------

triton==2.1.0

2024-06-18 14:23:46,816 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(block_size=64, num_cpu_blocks=512, num_gpu_blocks=6166, window_size=-1, cache_max_entry_count=0.8, max_prefill_token_num=4096, enable_prefix_caching=False)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [05:17<00:00,  1.37it/s]
--------------------------------------------------
concurrency: 256
elapsed_time: 317.658s

first token latency(s)(min, max, ave): 1.187, 11.992, 2.835
per-token latency(s) percentile(50, 75, 95, 99): [0.041, 0.044, 0.292, 0.348]

number of prompt tokens: 1136185
number of completion tokens: 1003966
token throughput (completion token): 3160.524 token/s
token throughput (prompt + completion token): 6737.279 token/s
RPS (request per second): 15.740 req/s
RPM (request per minute): 944.412 req/min
--------------------------------------------------

RunningLeon avatar Jun 18 '24 10:06 RunningLeon