[Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support
The first PR for #4532.
Task list:
- [x] Add NVIDIA e4m3.
- [x] Refactor cache_kernel.cu.
- [x] Unit tests for reshape and cache.
- [x] Refactor attention_kernel.cu.
- [x] Unit tests for attention.
- [x] Refactor AMD.
- [x] Compatibility with FP8 disabled.
BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE
PR Checklist
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
- [Bugfix] for bug fixes.
- [CI/Build] for build or continuous integration improvements.
- [Doc] for documentation fixes and improvements.
- [Model] for adding a new model or improving an existing model. Model name should appear in the title.
- [Frontend] for changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
- [Kernel] for changes affecting CUDA kernels or other compute kernels.
- [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
- [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
- [Misc] for PRs that do not fit the above categories. Please use this sparingly.
Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR needs to meet the following code quality standards:
- We adhere to the Google Python style guide and Google C++ style guide.
- Pass all linter checks. Please use format.sh to format your code.
- The code needs to be well-documented to ensure future contributors can easily understand it.
- Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
- Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.
Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not review the PR.
What to Expect for the Reviews
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
- After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
- After the PR is assigned, the reviewer will provide status updates every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
- After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
- Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!
Per offline discussion, this PR only includes backend refactoring for FP8 kv-cache related kernels and utilities. A follow-up PR will then cover the scaling factor loading. Thus, this PR is ready for review.
cc @pcmoritz @robertgshaw2-neuralmagic @HaiShaw @WoosukKwon
It is a bummer that GitHub doesn't render the diff between the old and new NVIDIA quant_utils.cuh -- for ease of reviewing, here is the diff:
(base) pcmoritz@pcmoritz-DQ44HV60WX /tmp % diff quant_utils_old.cuh quant_utils_new.cuh
2a3,6
> #include "../../../attention/attention_dtypes.h"
> #include "../../../attention/dtype_bfloat16.cuh"
> #include "../../../attention/dtype_float16.cuh"
> #include "../../../attention/dtype_float32.cuh"
4d7
< #include <stdint.h>
5a9
> #include <stdint.h>
7,10d10
< #include "../../attention/attention_dtypes.h"
< #include "../../attention/dtype_float32.cuh"
< #include "../../attention/dtype_float16.cuh"
< #include "../../attention/dtype_bfloat16.cuh"
12d11
<
14,15d12
< #ifdef ENABLE_FP8_E5M2
< namespace fp8_e5m2_unscaled {
17,20c14,20
< template<typename Tout, typename Tin>
< __inline__ __device__ Tout vec_conversion(const Tin& x)
< {
< return x;
---
> namespace fp8 {
> #ifdef ENABLE_FP8
>
> template <typename Tout, typename Tin>
> __inline__ __device__ Tout
> vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
> return x;
24,28c24,28
< template<>
< __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
< {
< __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
< return res.x;
---
> template <>
> __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
> const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
> __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> return res.x;
32,42c32,42
< template<>
< __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
< {
< union {
< uint16_t u16[2];
< uint32_t u32;
< } tmp;
< __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
< tmp.u16[0] = res.x;
< tmp.u16[1] = res.y;
< return tmp.u32;
---
> template <>
> __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint16_t u16[2];
> uint32_t u32;
> } tmp;
> __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
> tmp.u16[0] = res.x;
> tmp.u16[1] = res.y;
> return tmp.u32;
46,55c46,56
< template<>
< __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
< {
< union {
< uint2 u32x2;
< uint32_t u32[2];
< } tmp;
< tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
< tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
< return tmp.u32x2;
---
> template <>
> __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint2 u32x2;
> uint32_t u32[2];
> } tmp;
> tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
> tmp.u32[1] =
> vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
> return tmp.u32x2;
59,68c60,69
< template<>
< __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
< {
< union {
< uint4 u64x2;
< uint2 u64[2];
< } tmp;
< tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
< tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
< return tmp.u64x2;
---
> template <>
> __inline__ __device__ uint4 vec_conversion<uint4, uint2>(
> const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint4 u64x2;
> uint2 u64[2];
> } tmp;
> tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
> tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
> return tmp.u64x2;
72,80c73,81
< template<>
< __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
< {
< // Note there is no direct convert function from fp8 to bf16.
< // fp8 -> half
< __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
< // half -> float -> bf16
< float tmp = half_to_float(res.x);
< return __float2bfloat16(tmp);
---
> template <>
> __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
> const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
> // Note there is no direct convert function from fp8 to bf16.
> // fp8 -> half
> __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> // half -> float -> bf16
> float tmp = half_to_float(res.x);
> return __float2bfloat16(tmp);
84,90c85,91
< template<>
< __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
< {
< __nv_bfloat162 res;
< res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
< res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
< return res;
---
> template <>
> __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> __nv_bfloat162 res;
> res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
> res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
> return res;
94,100c95,102
< template<>
< __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
< {
< bf16_4_t res;
< res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
< res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
< return res;
---
> template <>
> __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t res;
> res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
> res.y =
> vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
> return res;
104,115c106,117
< template<>
< __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
< {
< bf16_4_t tmp1, tmp2;
< tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
< tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
< bf16_8_t res;
< res.x = tmp1.x;
< res.y = tmp1.y;
< res.z = tmp2.x;
< res.w = tmp2.y;
< return res;
---
> template <>
> __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
> const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t tmp1, tmp2;
> tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
> tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
> bf16_8_t res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
119,125c121,128
< template<>
< __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
< {
< // fp8 -> half
< uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
< // half -> float
< return half_to_float(tmp);
---
> template <>
> __inline__ __device__ float
> vec_conversion<float, uint8_t>(const uint8_t &a,
> const __nv_fp8_interpretation_t fp8_type) {
> // fp8 -> half
> uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
> // half -> float
> return half_to_float(tmp);
129,135c132,138
< template<>
< __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
< {
< // fp8x2 -> half2
< uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
< // half2 -> float2
< return half2_to_float2(tmp);
---
> template <>
> __inline__ __device__ float2 vec_conversion<float2, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> // fp8x2 -> half2
> uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
> // half2 -> float2
> return half2_to_float2(tmp);
139,145c142,148
< template<>
< __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
< {
< Float4_ res;
< res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
< res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
< return res;
---
> template <>
> __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> Float4_ res;
> res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
> res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
> return res;
149,160c152,163
< template<>
< __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
< {
< Float4_ tmp1, tmp2;
< tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
< tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
< Float8_ res;
< res.x = tmp1.x;
< res.y = tmp1.y;
< res.z = tmp2.x;
< res.w = tmp2.y;
< return res;
---
> template <>
> __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
> const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp1, tmp2;
> tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
> tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
> Float8_ res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
163d165
<
165,171c167,174
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
< {
< __half_raw tmp;
< tmp.x = a;
< __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
< return (uint8_t)res;
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
> const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
> __half_raw tmp;
> tmp.x = a;
> __nv_fp8_storage_t res =
> __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
175,177c178,180
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
< {
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
> const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
179c182
< assert(false);
---
> assert(false);
181,182c184,186
< __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
< return (uint8_t)res;
---
> __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
> __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
187,191c191,195
< template<>
< __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
< {
< __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
< return (uint8_t)res;
---
> template <>
> __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
> const float &a, const __nv_fp8_interpretation_t fp8_type) {
> __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
195,200c199,204
< template<>
< __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
< {
< Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
< float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
< return res;
---
> template <>
> __inline__ __device__ float4 vec_conversion<float4, uint32_t>(
> const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
> float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
> return res;
202a207,213
> template <>
> __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
> const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
> union {
> half2 float16;
> uint32_t uint32;
> };
204,210c215,217
< template<>
< __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
< {
< union {
< half2 float16;
< uint32_t uint32;
< };
---
> float16 = __float22half2_rn(a);
> return uint32;
> }
212,213c219,232
< float16 = __float22half2_rn(a);
< return uint32;
---
> template <>
> __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
> const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
> uint2 b;
> float2 val;
> val.x = a.x.x;
> val.y = a.x.y;
> b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
>
> val.x = a.y.x;
> val.y = a.y.y;
> b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
>
> return b;
216,223c235,244
< template<>
< __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
< {
< uint2 b;
< float2 val;
< val.x = a.x.x;
< val.y = a.x.y;
< b.x = vec_conversion<uint32_t, float2>(val);
---
> template <>
> __inline__ __device__ float4 vec_conversion<float4, Float4_>(
> const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
> float4 b;
> b.x = a.x.x;
> b.y = a.x.y;
> b.z = a.y.x;
> b.w = a.y.y;
> return b;
> }
225,227c246,255
< val.x = a.y.x;
< val.y = a.y.y;
< b.y = vec_conversion<uint32_t, float2>(val);
---
> template <>
> __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
> const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
> uint4 b;
> b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
> b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
> b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
> b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
> return b;
> }
229c257,262
< return b;
---
> template <>
> __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
> const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
> __nv_bfloat162 b;
> from_float(b, a);
> return b;
232,240c265,270
< template<>
< __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
< {
< float4 b;
< b.x = a.x.x;
< b.y = a.x.y;
< b.z = a.y.x;
< b.w = a.y.y;
< return b;
---
> template <>
> __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
> const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t b;
> from_float(b, a);
> return b;
243,251c273,278
< template<>
< __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
< {
< uint4 b;
< b.x = vec_conversion<uint32_t, float2>(a.x);
< b.y = vec_conversion<uint32_t, float2>(a.y);
< b.z = vec_conversion<uint32_t, float2>(a.z);
< b.w = vec_conversion<uint32_t, float2>(a.w);
< return b;
---
> template <>
> __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
> const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
> bf16_8_t b;
> from_float(b, a);
> return b;
254,258c281,290
< template<>
< __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
< __nv_bfloat162 b;
< from_float(b, a);
< return b;
---
> /* Scaled and vectorized conversions, for data exchange between high and low
> precision domains Convention of the scale in API, e.g: FP8_data =
> Quantization( High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8
> Dequant(FP8) * scale => HP
> */
>
> template <typename Tout, typename Tin>
> __inline__ __device__ Tout scaled_vec_conversion(
> const Tin &x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
> return x;
261,265c293,299
< template<>
< __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
< bf16_4_t b;
< from_float(b, a);
< return b;
---
> // fp8 -> half
> template <>
> __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
> const uint8_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> return float_to_half(half_to_float(tmp.x) * scale);
268,272c302,314
< template<>
< __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
< bf16_8_t b;
< from_float(b, a);
< return b;
---
> // fp8x2 -> half2
> template <>
> __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint16_t u16[2];
> uint32_t u32;
> } tmp;
> __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
> tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
> tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
> return tmp.u32;
275,276c317,576
< } // namespace fp8_e5m2_unscaled
< #endif // ENABLE_FP8_E5M2
---
> // fp8x4 -> half2x2
> template <>
> __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint2 u32x2;
> uint32_t u32[2];
> } tmp;
> tmp.u32[0] =
> scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
> tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
> scale, fp8_type);
> return tmp.u32x2;
> }
>
> // fp8x8 -> half2x4
> template <>
> __inline__ __device__ uint4
> scaled_vec_conversion<uint4, uint2>(const uint2 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> union {
> uint4 u64x2;
> uint2 u64[2];
> } tmp;
> tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
> tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
> return tmp.u64x2;
> }
>
> // fp8 -> __nv_bfloat16
> template <>
> __inline__ __device__ __nv_bfloat16
> scaled_vec_conversion<__nv_bfloat16, uint8_t>(
> const uint8_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> // Note there is no direct convert function from fp8 to bf16.
> // fp8 -> half
> __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
> // half -> float -> bf16
> float tmp = half_to_float(res.x);
> return __float2bfloat16(tmp * scale);
> }
>
> // fp8x2 -> __nv_bfloat162
> template <>
> __inline__ __device__ __nv_bfloat162
> scaled_vec_conversion<__nv_bfloat162, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __nv_bfloat162 res;
> res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
> fp8_type);
> res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
> scale, fp8_type);
> return res;
> }
>
> // fp8x4 -> bf16_4_t
> template <>
> __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t res;
> res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
> fp8_type);
> res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
> scale, fp8_type);
> return res;
> }
>
> // fp8x8 -> bf16_8_t
> template <>
> __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
> const uint2 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> bf16_4_t tmp1, tmp2;
> tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
> tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
> bf16_8_t res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
> }
>
> // fp8 -> float
> template <>
> __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
> const uint8_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
>
> // fp8 -> half
> uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
> // half -> float
> return half_to_float(tmp) * scale;
> }
>
> // fp8x2 -> float2
> template <>
> __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> // fp8x2 -> half2
> uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
> // half2 -> float2
> return half2_to_float2(tmp);
> }
>
> // fp8x4 -> float4
> template <>
> __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> Float4_ res;
> res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
> res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
> fp8_type);
> return res;
> }
>
> // fp8x8 -> float8
> template <>
> __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
> const uint2 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp1, tmp2;
> tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
> tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
> Float8_ res;
> res.x = tmp1.x;
> res.y = tmp1.y;
> res.z = tmp2.x;
> res.w = tmp2.y;
> return res;
> }
>
> // half -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
> const uint16_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __nv_fp8_storage_t res =
> __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
> }
>
> // bf16 -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
> const __nv_bfloat16 &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
> assert(false);
> #else
> __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
> __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
> #endif
> }
>
> // float -> fp8
> template <>
> __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
> const float &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> __nv_fp8_storage_t res =
> __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
> return (uint8_t)res;
> }
>
> // fp8x4 -> float4
> template <>
> __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
> const uint32_t &a, const float scale,
> const __nv_fp8_interpretation_t fp8_type) {
> Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
> float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
> return res;
> }
> #endif // ENABLE_FP8
>
> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
> __inline__ __device__ Tout convert(const Tin &x) {
> switch (kv_dt) {
> #ifdef ENABLE_FP8
> case Fp8KVCacheDataType::kAuto:
> // When the type is auto, Tin should be able to be converted to
> // Tout directly. Thus, the corresponding vec_conversion function
> // should ignore the last argument (e.g. __NV_E4M3).
> case Fp8KVCacheDataType::kFp8E4m3:
> return vec_conversion<Tout, Tin>(x, __NV_E4M3);
> case Fp8KVCacheDataType::kFp8E5m2:
> return vec_conversion<Tout, Tin>(x, __NV_E5M2);
> #endif
> default:
> assert(false);
> }
> }
>
> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
> __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
> switch (kv_dt) {
> #ifdef ENABLE_FP8
> case Fp8KVCacheDataType::kAuto:
> // When the type is auto, Tin should be able to be converted to
> // Tout directly. Thus, the corresponding vec_conversion function
> // should ignore the last argument (e.g. __NV_E4M3).
> case Fp8KVCacheDataType::kFp8E4m3:
> return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
> case Fp8KVCacheDataType::kFp8E5m2:
> return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
> #endif
> default:
> assert(false);
> }
> }
>
> // The following macro is used to dispatch the conversion function based on the
> // data type of the key and value cache. The FN is a macro that calls a function
> // with template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
> #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
> if (KV_DTYPE == "auto") { \
> if (SRC_DTYPE == at::ScalarType::Float) { \
> FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \
> } else if (SRC_DTYPE == at::ScalarType::Half) { \
> FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \
> } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
> FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \
> } else { \
> TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
> } \
> } else { \
> if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
> if (SRC_DTYPE == at::ScalarType::Float) { \
> FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3); \
> } else if (SRC_DTYPE == at::ScalarType::Half) { \
> FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3); \
> } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
> FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4m3); \
> } else { \
> TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
> } \
> } else if (KV_DTYPE == "fp8_e5m2") { \
> if (SRC_DTYPE == at::ScalarType::Float) { \
> FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2); \
> } else if (SRC_DTYPE == at::ScalarType::Half) { \
> FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2); \
> } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
> FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5m2); \
> } else { \
> TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
> } \
> } else { \
> TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
> } \
> }
>
> } // namespace fp8
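To make the new dispatch entry point concrete, here is a hedged usage sketch of DISPATCH_BY_KV_CACHE_DTYPE. The launcher macro, kernel name, and argument list below are illustrative rather than copied from this PR:

```cpp
// Illustrative only: a caller defines a small launch macro, and the dispatch
// macro expands it once with the (scalar_t, cache_t, kv_dt) combination picked
// from the tensor dtype and the "auto" / "fp8" / "fp8_e4m3" / "fp8_e5m2" string.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)                  \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>                \
      <<<grid, block, 0, stream>>>(                                      \
          reinterpret_cast<KV_T*>(key.data_ptr())                        \
          /* ...remaining kernel arguments elided... */);

// key is a torch::Tensor holding the new keys; kv_cache_dtype is the
// user-facing cache dtype string from the engine config.
DISPATCH_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype,
                           CALL_RESHAPE_AND_CACHE);
```

The attention kernels would presumably follow the same pattern, with FN swapped for their own launch macro.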
Did you investigate the performance impact of passing __nv_fp8_interpretation_t around at runtime? Have you considered making the format a template parameter of the vec_conversion and related functions (e.g. by reusing Fp8KVCacheDataType)?
Good question. It would be tedious to make this type a template parameter, because we have roughly 30 overloaded functions. Since C++ doesn't allow partial specialization of function templates, we would have to manually duplicate them into roughly 60 functions to cover both formats...
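For illustration, a hypothetical sketch (not code from this PR) of what promoting the format to a template parameter would look like; because function templates only allow full specializations, every existing (Tout, Tin) specialization would have to be spelled out once per format:

```cpp
// Hypothetical: the fp8 interpretation as a non-type template parameter.
template <typename Tout, typename Tin, __nv_fp8_interpretation_t kFp8Type>
__inline__ __device__ Tout vec_conversion(const Tin &x);

// Function templates cannot be partially specialized, so each of the ~30
// existing specializations would be written twice, once per format:
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t, __NV_E4M3>(const uint8_t &a) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E4M3);
  return res.x;
}

template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t, __NV_E5M2>(const uint8_t &a) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
  return res.x;
}
```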
Why don't we test whether there is a performance overhead? Probably the compiler is already smart enough to optimize that -- it should be, since the argument is constant in https://github.com/vllm-project/vllm/pull/4535/files#diff-97c4751eafe4ec7333fe2f140e29c84ea054f43d17d4286cc8c4e69a095d09aaR502, and similarly for scaled_convert.
If the performance is ok, we can go forward with this.
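For reference, a simplified sketch of the reasoning (mirroring the convert<> shown in the diff above, with the switch written as if constexpr; the effect should be the same): kv_dt is a template parameter, so the fp8 interpretation reaching the inlined vec_conversion is a compile-time literal rather than a true runtime value.

```cpp
// Simplified sketch, not the exact code from the PR: because kv_dt is known at
// compile time for every instantiation, each branch passes a literal
// interpretation to the inlined callee, leaving no runtime dispatch behind.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin &x) {
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5m2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);  // constant at this call site
  } else {
    // kAuto and kFp8E4m3: the interpretation is either ignored or the
    // constant __NV_E4M3.
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  }
}
```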
Thanks a lot for cleaning this up @comaniac ❤️ This code was not pretty and now it is much nicer!
The only thing I'm not a fan of is {nvidia, amd}/quant_utils.cuh. If anybody has ideas how to do that better, that would be very much appreciated!
I'll verify the performance. For naming, another way I could think of is {cuda,rocm/hip}/quant_utils.cuh, but I'm open to any proposal.
It is not about naming, more about having all these special cases and little conversion utilities :)
I benchmarked on an L4 GPU and the latency difference is within 1-2%, which should be acceptable.
@HaiShaw @AdrianAbeyta I keep seeing the following error when building this PR with ROCm. It seems like the same attention_generic.cuh header is included twice from different places. Since I don't get this problem with nvcc, do you have any clue about how to resolve this for ROCm? Thanks.
#20 37.37 [3/13] Building HIP object CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o
#20 37.37 FAILED: CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o
#20 37.37 /opt/rocm/llvm/bin/clang++ -DTORCH_EXTENSION_NAME=_C -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_PROF_API=1 -DUSE_RPC -DUSE_TENSORPIPE -D_C_EXPORTS -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1 -I/vllm-workspace/csrc -isystem /opt/conda/envs/py_3.9/include/python3.9 -isystem /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include -isystem /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /opt/rocm-6.0.0/include/hiprand -O2 -g -DNDEBUG -std=gnu++17 --offload-arch=gfx906 --offload-arch=gfx908 --offload-arch=gfx90a --offload-arch=gfx1030 --offload-arch=gfx1100 --offload-arch=gfx940 --offload-arch=gfx941 --offload-arch=gfx942 -fPIC -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DUSE_ROCM -DENABLE_FP8 -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -fno-gpu-rdc -mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -D_GLIBCXX_USE_CXX11_ABI=1 -DTORCH_HIP_VERSION=600 -Wno-shift-count-negative -Wno-shift-count-overflow -Wno-duplicate-decl-specifier -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -MD -MT CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o -MF CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o.d -o CMakeFiles/_C.dir/csrc/attention/attention_kernels.hip.o -x hip -c /vllm-workspace/build/temp.linux-x86_64-cpython-39/csrc/attention/attention_kernels.hip
#20 37.37 In file included from /vllm-workspace/build/temp.linux-x86_64-cpython-39/csrc/attention/attention_kernels.hip:31:
#20 37.37 In file included from /vllm-workspace/csrc/quantization/fp8/amd/quant_utils.cuh:8:
#20 37.37 In file included from /vllm-workspace/csrc/quantization/fp8/amd/../../../attention/attention_dtypes.h:3:
#20 37.37 /vllm-workspace/csrc/quantization/fp8/amd/../../../attention/attention_generic.cuh:26:8: error: redefinition of 'Vec'
#20 37.37 struct Vec {};
#20 37.37 ^
#20 37.37 /vllm-workspace/build/temp.linux-x86_64-cpython-39/csrc/attention/attention_generic.cuh:26:8: note: previous definition is here
#20 37.37 struct Vec {};
#20 37.37 ^
Can you rm -rf build and try again? I may give it a try soon.
I cannot do that since it's on the CI instead of my local workspace. The issue remains even after I removed the dtype_fp8.cuh header from quant_utils.cuh...
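One possible direction, purely as an assumption about the root cause rather than a confirmed fix: the hipify step copies headers under build/temp/..., so the "same" attention_generic.cuh exists at two paths, and #pragma once (which keys on file identity) no longer deduplicates it, whereas a classic macro guard would be path-independent:

```cpp
// attention_generic.cuh -- hypothetical sketch of a path-independent include
// guard. This is an assumption about the ROCm failure above, not the change
// that actually landed in the PR.
#ifndef CSRC_ATTENTION_ATTENTION_GENERIC_CUH_
#define CSRC_ATTENTION_ATTENTION_GENERIC_CUH_

namespace vllm {

// A vector type to store Q, K, V elements.
template <typename T, int VEC_SIZE>
struct Vec {};

}  // namespace vllm

#endif  // CSRC_ATTENTION_ATTENTION_GENERIC_CUH_
```

Another option under the same assumption would be to make all attention includes resolve to a single path so only one copy of the header is ever seen.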
CI passed, so we should be good to go. Regarding the comment about not mixing quant_utils with references to Fp8KVCacheDataType kv_dt in one file: we are also aware of this issue but haven't found a better solution yet, due to the tight coupling between quant_utils and attention_kernel. If needed, a follow-up refactor could separate the attention kernels entirely.