Optimize kernel launch for triton 2.2.0 and triton 2.3.0
- triton 2.1.0 has the best performance.
- Parsing the signature (in 2.2.0 and 2.3.0) costs a lot.
- 2.3.0 does not accept device and stream.
Parsing the signature (in 2.2.0 and 2.3.0) costs a lot.
Hi @grimoire Does this mean the performance of the PyTorch backend with triton 2.2.0 is much lower than with triton 2.1.0?
@jjjjohnson Yes. The PyTorch engine awaits until the kernel launches finish, so the CPU can be used for other work (in most cases, the tokenizer). triton 2.2.0+ parses the signature from a dict, which greatly increases kernel launch time.
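For intuition, here is a minimal sketch (made-up coroutine names, not the engine's actual code) of why the CPU-side launch cost matters: while the forward step is awaited, the event loop can run other CPU work such as tokenization, so any extra time spent launching kernels eats into that budget.

```python
import asyncio

async def forward_step(batch):
    # stand-in for launching the model's kernels; the longer the CPU-side
    # launch takes, the less time the event loop has for other coroutines
    await asyncio.sleep(0)  # yield control back to the event loop
    return batch

async def tokenize(texts):
    # stand-in for tokenizer work that can overlap with the forward step
    return [t.split() for t in texts]

async def main():
    outputs, tokens = await asyncio.gather(
        forward_step([0, 1]), tokenize(["hello world", "foo bar"]))
    print(outputs, tokens)

asyncio.run(main())
```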
So with triton 2.2.0 and the @wrap_jit_func decorator, is the performance the same as triton 2.1.0? @grimoire
@jjjjohnson I am afraid not. The performance is better than the original 2.2.0 but still worse than 2.1.0. There is still work to do.
Hi @grimoire I cannot find triton 2.2.0 or triton 2.3.0 on the triton GitHub. Where can I find them?
They have release/2.x.x branches.
Hi @grimoire PyTorch and Triton both support AMD. If I want to run the LMDeploy PyTorch Engine on an AMD GPU, what extra work do I need to do? Thanks.
@zhyncs There are device casts and stream synchronizations in engine/engine.py, engine/model_agent.py and models/patch.py. The other modules should be device agnostic.
Hi @grimoire I am confused about how @wrap_jit_func reduces kernel launch time. The code in triton_utils.py is a bit hard to understand...
how the @wrap_jit_func reduces kernel launch time
Triton gives each compiled kernel a key composed of the data types, value alignment, const values and other meta info of its arguments. That means at every kernel launch you need to take all inputs, check their attributes (dtype, whether they are divisible by 8, etc.) and build the key.
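For intuition, here is a rough sketch (not Triton's actual code) of the kind of per-argument inspection that building such a key implies:

```python
import torch

def specialization_key(args):
    # illustrative only: collect the per-argument attributes that feed the
    # kernel cache key (dtype, pointer alignment, small-integer specialization)
    key = []
    for a in args:
        if isinstance(a, torch.Tensor):
            key.append((str(a.dtype), a.data_ptr() % 16 == 0))
        elif isinstance(a, int):
            key.append(("int", a == 1, a % 16 == 0))
        else:
            key.append(type(a).__name__)
    return tuple(key)

# this kind of inspection has to run on every launch, for every argument
x = torch.empty(1024)
print(specialization_key([x, x.numel(), 1]))
```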
Before triton 2.2.0, the launch function was generated as a string and executed with exec:
https://github.com/openai/triton/blob/da40a1e984bf57c4708daf603eb427442025f99b/python/triton/runtime/jit.py#L400
Argument parsing does not cost much time for a pre-compiled function with fixed arguments.
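A minimal sketch of that idea (names are made up; run_kernel stands in for the real compiled-kernel call): the launcher source is built once with the exact argument list and compiled with exec, so each launch is just a plain positional call.

```python
def run_kernel(grid, stream, *args):
    # stand-in for the call into the compiled kernel
    pass

def make_launcher(arg_names):
    # build a launcher with a fixed positional argument list as source code,
    # compile it once with exec, and reuse it for every launch
    src = (
        f"def _launcher({', '.join(arg_names)}, grid, stream):\n"
        f"    run_kernel(grid, stream, {', '.join(arg_names)})\n"
    )
    scope = {"run_kernel": run_kernel}
    exec(src, scope)
    return scope["_launcher"]

launcher = make_launcher(["x_ptr", "y_ptr", "n"])
launcher(0x1000, 0x2000, 1024, grid=(1,), stream=None)  # cheap fixed-arg call
```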
Since triton 2.2.0, the arguments are passed as a dict and parsed with signature.bind:
https://github.com/openai/triton/blob/0e7b97bd47fc4beb21ae960a516cd9a7ae9bc060/python/triton/runtime/jit.py#L426
This is more elegant than exec, but parsing the arguments costs more CPU time than calling a function with fixed arguments. Since the bottleneck of our PyTorch engine is on the host side, the pre-2.2.0 solution is better for us.
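The gap is easy to reproduce with a toy benchmark (the absolute numbers vary by machine; the point is the relative difference between a fixed-argument call and signature.bind):

```python
import inspect
import timeit

def launch(x_ptr, y_ptr, n, BLOCK=128):
    pass

sig = inspect.signature(launch)
kwargs = {"x_ptr": 1, "y_ptr": 2, "n": 1024}

fixed = timeit.timeit(lambda: launch(1, 2, 1024), number=100_000)
bound = timeit.timeit(lambda: sig.bind(**kwargs).arguments, number=100_000)
print(f"fixed-arg call: {fixed:.3f}s  signature.bind: {bound:.3f}s")
```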
@wrap_jit_func generates a launch function the way triton 2.1.0 does and bypasses the original launcher.
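Usage looks roughly like the sketch below (the import path and the kernel are assumptions for illustration; only the decorator name comes from triton_utils.py). The wrapper sits on top of @triton.jit, so the pre-generated fixed-argument launcher is used instead of Triton's own.

```python
import triton
import triton.language as tl

# assumed import path, for illustration only
from lmdeploy.pytorch.kernels.triton_utils import wrap_jit_func

@wrap_jit_func
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # made-up example kernel
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
```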
signature.bind
@grimoire Thanks for your great explanation! Looks like this issue is very relevant. BTW, why does parsing arguments require so much CPU computation? My understanding is that parsing arguments is a very CPU-light operation.
@jjjjohnson Yes, it is a CPU-light operation. But since a transformer is a stack of decoder layers and each layer requires multiple custom kernels (attention/rmsnorm/rope...), the expense is hard to ignore.
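For a rough sense of scale, a back-of-envelope sketch (every number here is an illustrative assumption, not a measurement):

```python
layers = 32                  # decoder layers in an 8B-class model
kernels_per_layer = 8        # attention, rmsnorm, rope, mlp, ... (assumed)
extra_us_per_launch = 20     # assumed extra CPU time per kernel launch

extra_ms_per_token = layers * kernels_per_layer * extra_us_per_launch / 1000
print(f"extra CPU time per decoded token: ~{extra_ms_per_token:.1f} ms")
```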
Test results on Llama 3:
python benchmark/profile_throughput.py ./ShareGPT_V3_unfiltered_cleaned_split.json ./Meta-Llama-3-8B-Instruct --backend pytorch
triton==2.3.1
2024-06-18 13:23:10,680 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(block_size=64, num_cpu_blocks=512, num_gpu_blocks=6100, window_size=-1, cache_max_entry_count=0.8, max_prefill_token_num=4096, enable_prefix_caching=False)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [05:41<00:00, 2.19s/it]
--------------------------------------------------
concurrency: 256
elapsed_time: 341.836s
first token latency(s)(min, max, ave): 1.838, 19.980, 3.325
per-token latency(s) percentile(50, 75, 95, 99): [0.046, 0.072, 0.295, 0.337]
number of prompt tokens: 1138200
number of completion tokens: 1006232
token throughput (completion token): 2943.609 token/s
token throughput (prompt + completion token): 6273.275 token/s
RPS (request per second): 14.627 req/s
RPM (request per minute): 877.613 req/min
triton==2.2.0
2024-06-18 14:01:40,598 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(block_size=64, num_cpu_blocks=512, num_gpu_blocks=6166, window_size=-1, cache_max_entry_count=0.8, max_prefill_token_num=4096, enable_prefix_caching=False)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [05:13<00:00, 1.24it/s]
--------------------------------------------------
concurrency: 256
elapsed_time: 313.438s
first token latency(s)(min, max, ave): 1.753, 16.526, 2.997
per-token latency(s) percentile(50, 75, 95, 99): [0.042, 0.043, 0.279, 0.331]
number of prompt tokens: 1136185
number of completion tokens: 1003966
token throughput (completion token): 3203.080 token/s
token throughput (prompt + completion token): 6827.995 token/s
RPS (request per second): 15.952 req/s
RPM (request per minute): 957.128 req/min
--------------------------------------------------
triton==2.1.0
2024-06-18 14:23:46,816 - lmdeploy - INFO - build CacheEngine with config:CacheConfig(block_size=64, num_cpu_blocks=512, num_gpu_blocks=6166, window_size=-1, cache_max_entry_count=0.8, max_prefill_token_num=4096, enable_prefix_caching=False)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [05:17<00:00, 1.37it/s]
--------------------------------------------------
concurrency: 256
elapsed_time: 317.658s
first token latency(s)(min, max, ave): 1.187, 11.992, 2.835
per-token latency(s) percentile(50, 75, 95, 99): [0.041, 0.044, 0.292, 0.348]
number of prompt tokens: 1136185
number of completion tokens: 1003966
token throughput (completion token): 3160.524 token/s
token throughput (prompt + completion token): 6737.279 token/s
RPS (request per second): 15.740 req/s
RPM (request per minute): 944.412 req/min
--------------------------------------------------