Add half rmsnorm kernel
I noticed that the RMSNorm kernel has room for further optimization, so I tried writing a half-precision kernel for it. Below is the comparison data; I tested on an A100 80G with 10,000 iterations per configuration.
Benchmark
| num_tokens | hidden_size | Elapsed time (base) | Elapsed time (opt) |
|---|---|---|---|
| 2048 | 64 | 1.470 seconds | 1.594 seconds |
| 2048 | 768 | 0.163 seconds | 0.066 seconds |
| 2048 | 1024 | 0.171 seconds | 0.076 seconds |
| 2048 | 5120 | 0.531 seconds | 0.280 seconds |
| 4096 | 64 | 0.116 seconds | 0.090 seconds |
| 4096 | 768 | 0.256 seconds | 0.073 seconds |
| 4096 | 1024 | 0.311 seconds | 0.084 seconds |
| 4096 | 5120 | 0.991 seconds | 0.497 seconds |
| 10192 | 64 | 0.208 seconds | 0.158 seconds |
| 10192 | 768 | 0.707 seconds | 0.178 seconds |
| 10192 | 1024 | 0.851 seconds | 0.278 seconds |
| 10192 | 5120 | 2.383 seconds | 1.164 seconds |
Analysis
Except for the num_tokens=2048 && hidden_size=64 case, the kernel gives a good speedup. In some configurations the elapsed time is very small, so you may see fluctuations when reproducing the numbers. The regression in the num_tokens=2048 && hidden_size=64 case is likely because too few threads are launched at that size, but this configuration should not occur in real workloads.
Since this is my first time registering a kernel with PyTorch, the current way of registering the kernel may not be elegant. To benchmark the baseline, I had to comment out the half-precision kernel and recompile before measuring the elapsed time. If you have a better approach, please let me know and I will change it.
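To make the idea concrete, below is a minimal sketch of a packed half-precision RMSNorm kernel. It is illustrative only, not the PR's code: the PR packs 8 elements per step, while this sketch uses `half2` (2 elements) for brevity, and `blockReduceSum` is assumed to be available from vLLM's reduction utilities.

```cuda
// Sketch of a half2-vectorized RMSNorm kernel (illustrative, not the PR code).
// Launch with one block per token; blockReduceSum<float> is assumed to exist.
#include <cuda_fp16.h>

__global__ void rms_norm_half2_kernel(
    half* __restrict__ out,            // [num_tokens, hidden_size]
    const half* __restrict__ input,    // [num_tokens, hidden_size]
    const half* __restrict__ weight,   // [hidden_size]
    const float epsilon,
    const int hidden_size) {           // assumed: hidden_size % 2 == 0
  const int token_idx = blockIdx.x;
  const half2* in2 = reinterpret_cast<const half2*>(input + token_idx * hidden_size);
  const half2* w2  = reinterpret_cast<const half2*>(weight);
  half2* out2      = reinterpret_cast<half2*>(out + token_idx * hidden_size);

  // Accumulate the sum of squares in FP32 for numerical stability.
  float sq_sum = 0.0f;
  for (int i = threadIdx.x; i < hidden_size / 2; i += blockDim.x) {
    const float2 x = __half22float2(in2[i]);
    sq_sum += x.x * x.x + x.y * x.y;
  }
  __shared__ float s_inv_rms;
  sq_sum = blockReduceSum<float>(sq_sum);
  if (threadIdx.x == 0) {
    s_inv_rms = rsqrtf(sq_sum / hidden_size + epsilon);
  }
  __syncthreads();

  // Normalize and scale, two elements per iteration.
  for (int i = threadIdx.x; i < hidden_size / 2; i += blockDim.x) {
    const float2 x = __half22float2(in2[i]);
    const float2 w = __half22float2(w2[i]);
    out2[i] = __floats2half2_rn(x.x * s_inv_rms * w.x, x.y * s_inv_rms * w.y);
  }
}
```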
@sleepcoo Awesome! Thanks for your contribution! Before I get into review, could you double-check the new kernel produces correct outputs? When I tested it out, it didn't match our reference implementation.
```
$ python tests/kernels/test_layernorm.py
Testing RMS kernel with dtype=torch.float16, num_tokens=7, hidden_size=13
Custom kernel output: tensor([[-9.7609e-04, -3.6955e-05, 7.8440e-04, -2.2471e-04, 9.1970e-05,
-4.5037e-04, -3.4380e-04, 1.5125e-03, 9.8801e-04, 1.1921e-07,
6.8140e-04, -7.1526e-06, 5.1618e-05],
[-1.7557e-03, 6.4802e-04, 2.0659e-04, 1.3149e-04, 3.1948e-05,
1.8585e-04, -2.8157e-04, 9.7156e-05, 1.0767e-03, -1.7681e-03,
3.4094e-04, 7.8106e-04, 1.8895e-05],
[ 7.9870e-04, 3.7408e-04, -6.6161e-05, -5.2691e-04, -1.5068e-03,
-1.7691e-04, -4.9639e-04, -1.5974e-05, -1.9038e-04, 3.7837e-04,
1.8263e-04, 1.1005e-03, 1.0719e-03],
[-2.9087e-04, -1.0366e-03, -1.0133e-05, -1.7273e-04, 4.2558e-04,
-1.2118e-04, -9.0957e-05, 1.5087e-03, 3.1114e-05, -1.7595e-03,
4.4346e-05, 1.6057e-04, -1.3852e-04],
[-2.4557e-04, -4.5848e-04, 1.3590e-05, -1.9002e-04, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00]], device='cuda:0',
dtype=torch.float16)
Reference output: tensor([[-1.1072e-03, -4.1962e-05, 8.8978e-04, -2.5511e-04, 1.0437e-04,
-5.1117e-04, -3.9005e-04, 1.7157e-03, -1.0691e-03, -8.3447e-07,
5.4407e-04, -1.1623e-05, 1.9526e-04],
[-2.3575e-03, 3.5405e-05, 2.0337e-04, 2.3007e-05, -7.6830e-05,
1.3483e-04, -8.1921e-04, -3.5286e-04, -8.4686e-04, 4.2200e-04,
1.4520e-04, 3.8004e-04, -2.9397e-04],
[ 8.2016e-04, 6.2704e-05, 2.4962e-04, -1.3065e-04, 2.0885e-04,
-1.3328e-04, -5.4169e-04, -2.9898e-04, 2.3282e-04, -5.6314e-04,
-6.0558e-04, 4.8757e-04, -6.1846e-04],
[-4.5466e-04, -5.6326e-05, -1.6999e-04, -5.7459e-05, -3.1781e-04,
6.1369e-04, -1.2141e-04, 1.4067e-03, -3.4094e-05, 5.5599e-04,
4.2844e-04, 6.2644e-05, 2.8181e-04],
[ 1.1368e-03, -2.2769e-05, 8.4639e-06, -4.2379e-05, 1.2100e-05,
-6.5851e-04, 1.7996e-03, 1.1816e-03, 1.1883e-03, -1.8001e-04,
-4.8351e-04, 4.5156e-04, -1.5652e-04],
[ 1.6296e-04, -7.1526e-07, 1.2875e-03, 3.0470e-04, 1.0806e-04,
-1.2960e-03, -1.5748e-04, -2.6870e-04, -5.6219e-04, -5.8889e-04,
6.4313e-05, -1.5581e-04, 4.9067e-04],
[-5.7602e-04, 1.0252e-05, -1.2989e-03, -2.8872e-04, 3.3796e-05,
6.0439e-05, 7.3671e-04, -4.2677e-04, 2.2531e-04, 7.4768e-04,
1.5032e-04, -1.0544e-04, -6.8283e-04]], device='cuda:0',
       dtype=torch.float16)
```
I'm really sorry. Because the kernel processes 8 elements of the hidden dimension at a time, it produces wrong results when hidden_size % 8 != 0. I have fixed the bug and pushed the change; please review again. Now, when hidden_size % 8 != 0, it falls back to the original kernel. Also, is there any progress on the sampler optimization? I can help if needed. And what is your code-format configuration? I use VSCode with Google style, and it conflicts with your code format in many places.
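A rough sketch of the fallback dispatch described here (the launcher names are illustrative, not the PR's actual code):

```cuda
// Sketch of the host-side dispatch: use the packed half kernel only when
// hidden_size is a multiple of 8, otherwise fall back to the generic kernel.
#include <torch/extension.h>

// Assumed to be defined elsewhere (e.g. in the layernorm kernel file);
// these names are made up for illustration.
void launch_half_rms_norm(torch::Tensor& out, const torch::Tensor& input,
                          const torch::Tensor& weight, float epsilon);
void launch_generic_rms_norm(torch::Tensor& out, const torch::Tensor& input,
                             const torch::Tensor& weight, float epsilon);

void rms_norm(torch::Tensor& out, const torch::Tensor& input,
              const torch::Tensor& weight, float epsilon) {
  const int64_t hidden_size = input.size(-1);
  const bool use_half_kernel =
      input.scalar_type() == at::ScalarType::Half && hidden_size % 8 == 0;
  if (use_half_kernel) {
    launch_half_rms_norm(out, input, weight, epsilon);    // packed fast path
  } else {
    launch_generic_rms_norm(out, input, weight, epsilon); // original kernel
  }
}
```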
Hi @sleepcoo, Is the bug fixed now? We will add the code format checker later. 🙏 Could you wrap up this PR first?
I fixed it; you can review it when you have time. I also completed the fusion of softmax and div in the sampler, which gives roughly a 2x improvement over the original. I will submit another PR after you provide the code formatting tool.
I suggest using OneFlow's RMSNorm implementation: https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/core/cuda/rms_norm.cuh. When the hidden size is small, it uses thread registers to cache the input data; it can also use shared memory to cache data.
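Roughly, the register-caching idea looks like this. This is only a sketch, not OneFlow's actual code; it assumes each thread handles at most `MAX_ELEMS_PER_THREAD` elements and that a `blockReduceSum` utility is available.

```cuda
// Sketch of register caching: when hidden_size is small enough, each thread
// stages its elements in a local (register) array, so the input row is read
// from global memory only once. Not OneFlow's actual implementation.
template <int MAX_ELEMS_PER_THREAD>
__global__ void rms_norm_reg_cache_kernel(
    float* __restrict__ out, const float* __restrict__ input,
    const float* __restrict__ weight, float epsilon, int hidden_size) {
  const float* row = input + blockIdx.x * hidden_size;
  float buf[MAX_ELEMS_PER_THREAD];  // lives in registers

  // Pass 1: load into registers and accumulate the sum of squares.
  float sq_sum = 0.0f;
  int n = 0;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x, ++n) {
    buf[n] = row[i];
    sq_sum += buf[n] * buf[n];
  }
  __shared__ float s_inv_rms;
  sq_sum = blockReduceSum<float>(sq_sum);  // assumed block reduce utility
  if (threadIdx.x == 0) s_inv_rms = rsqrtf(sq_sum / hidden_size + epsilon);
  __syncthreads();

  // Pass 2: reuse the cached values instead of reloading from global memory.
  n = 0;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x, ++n) {
    out[blockIdx.x * hidden_size + i] = buf[n] * s_inv_rms * weight[i];
  }
}
```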
I also think OneFlow's implementation is better, but at present the vLLM library lacks some utilities, such as vectorized reads and a templated reduce (the current reduce is neither efficient nor easy to use; for example, it does not support max). Directly copying OneFlow would gradually bloat the library, so I only made minimal changes here. I also referred to the OneFlow implementation for a softmax fusion (OneFlow's fuse_softmax). If you are interested, we can commit some utilities first and then change the kernel.
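For instance, a block reduce templated on the operator (so that max works just like sum) could look roughly like the following. This is only a sketch, not vLLM's or OneFlow's actual utility.

```cuda
// Sketch of a block reduce parameterized by the reduction op, so the same
// utility covers sum, max, etc. The block-wide result is valid in warp 0.
template <typename T> struct SumOp {
  __device__ __forceinline__ T operator()(T a, T b) const { return a + b; }
};
template <typename T> struct MaxOp {
  __device__ __forceinline__ T operator()(T a, T b) const { return a > b ? a : b; }
};

template <typename T, typename Op>
__device__ __forceinline__ T warpReduce(T val, Op op) {
  #pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    val = op(val, __shfl_xor_sync(0xffffffff, val, offset));
  }
  return val;
}

// `identity` is the neutral element of the op (0 for sum, -FLT_MAX for max).
template <typename T, typename Op>
__device__ __forceinline__ T blockReduce(T val, Op op, T identity) {
  static __shared__ T shared[32];        // one partial result per warp
  const int lane = threadIdx.x & 31;
  const int warp = threadIdx.x >> 5;
  val = warpReduce(val, op);
  if (lane == 0) shared[warp] = val;
  __syncthreads();
  const int num_warps = (blockDim.x + 31) >> 5;
  val = (lane < num_warps) ? shared[lane] : identity;
  if (warp == 0) val = warpReduce(val, op);
  return val;
}
```

A softmax kernel could then call `blockReduce(x, MaxOp<float>(), -FLT_MAX)` for the max pass and `blockReduce(x, SumOp<float>(), 0.0f)` for the normalization pass with the same utility.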
Hi @sleepcoo, thanks for submitting the PR and sorry for the delay in my review. I left some comments on the code style.
BTW, could you update your PR branch so that I can run our code formatter? Thanks.
I modified the code style and submitted it. Can you provide some formatter tooling later, like a .clang-format file? It's too painful to adjust the code format manually :sob:
Recently, while studying the vLLM code, I found `typename Vec<scalar_t, pack_size>::Type` and realized that the half kernel I implemented before is not very elegant. This submission only touches the kernel: it uses vLLM's "attention_utils" to rewrite the RMSNorm kernel and supports all of vLLM's data types. At present the code style still differs somewhat from the other kernels in the library. Can you provide some configuration files to help me format it? @WoosukKwon
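For reference, the `Vec<scalar_t, pack_size>::Type` trick boils down to mapping a scalar type and pack size to a wide load type. Below is a sketch: the trait itself lives in vLLM's attention utilities, but the specializations and the `load_packed_row` helper shown here are illustrative, not copied from the library.

```cuda
// Sketch of how a Vec<scalar_t, pack_size>::Type trait enables packed loads
// for any dtype. Specializations below are illustrative only.
#include <cuda_fp16.h>

template <typename T, int N> struct Vec {};
template <> struct Vec<float, 4> { using Type = float4; };  // 4 floats = 16 bytes
template <> struct Vec<half, 8>  { using Type = uint4;  };  // 8 halves = 16 bytes

template <typename scalar_t, int pack_size>
__device__ void load_packed_row(const scalar_t* __restrict__ row,
                                scalar_t* __restrict__ dst, int hidden_size) {
  using vec_t = typename Vec<scalar_t, pack_size>::Type;
  const vec_t* row_vec = reinterpret_cast<const vec_t*>(row);
  vec_t* dst_vec = reinterpret_cast<vec_t*>(dst);
  // One wide (16-byte) transaction per pack instead of one per element.
  for (int i = threadIdx.x; i < hidden_size / pack_size; i += blockDim.x) {
    dst_vec[i] = row_vec[i];
  }
}
```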
Closing the PR since it is pretty old and we'd like to stick to the current RMSNorm implementation, which upcasts the data type to FP32 during computation.