tvm icon indicating copy to clipboard operation
tvm copied to clipboard

[TIR][CUDA] Add native FP8 support to codegen

Open csullivan opened this issue 1 year ago • 1 comments

Adds native FP8 type support for CUDA. The e4m3/e5m2 struct types provide explicit type conversions that target hardware native conversion ops.

* Conditionally run storage and compute legalization only for targets that don't support FP8. This could later be changed to natively support only the conversion operators and to legalize every compute operation other than builtin WMMA calls.

* Implement support for float16x4 (half4) for use with e4m3_float8x4 (__nv_fp8x4_e4m3)

e.g.

#include <cuda_fp8.h>
// Short aliases for the CUDA FP8 types from <cuda_fp8.h>, grouped by format.
// The `_2` / `_4` infix selects the x2 / x4 packed variant of the same format.

// e4m3 format.
using fp8_e4_t = __nv_fp8_e4m3;
using fp8_e4_2_t = __nv_fp8x2_e4m3;
using fp8_e4_4_t = __nv_fp8x4_e4m3;

// e5m2 format.
using fp8_e5_t = __nv_fp8_e5m2;
using fp8_e5_2_t = __nv_fp8x2_e5m2;
using fp8_e5_4_t = __nv_fp8x4_e5m2;

// Four __half lanes (x, y, z, w) with explicit conversions to and from the
// packed four-lane e4m3 FP8 type __nv_fp8x4_e4m3.
//
// Packing convention used by both conversions: the low 16 bits of the packed
// 32-bit FP8 storage word carry lanes (x, y) and the high 16 bits carry
// lanes (z, w). Each 16-bit half is converted through the two-lane
// __nv_fp8x2_e4m3 <-> __half2 conversions provided by <cuda_fp8.h>.
struct __align__(4) half4 {
  __half x, y, z, w;

  // Zero-initializes all four lanes.
  __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}

  // Lane-wise constructor.
  __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}

  // Widens four packed e4m3 values to four __half lanes.
  __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
    // Split the 32-bit storage word into two 16-bit two-lane FP8 pairs.
    __nv_fp8x2_e4m3 lo_part, hi_part;
    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
    const __half2 lo_half2 = static_cast<__half2>(lo_part);
    const __half2 hi_half2 = static_cast<__half2>(hi_part);
    // Read the lanes through __half2's public members rather than by
    // reinterpreting the __half2 object as an array of __half.
    x = lo_half2.x;
    y = lo_half2.y;
    z = hi_half2.x;
    w = hi_half2.y;
  }

  // Narrows the four __half lanes to four packed e4m3 values.
  __host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
    // Build the lane pairs by value instead of reading adjacent members
    // through a __half2 pointer.
    const __half2 lo_half2(x, y);
    const __half2 hi_half2(z, w);
    const __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
    __nv_fp8x4_e4m3 result;
    // __nv_fp8x4_storage_t is the declared type of __nv_fp8x4_e4m3::__x; it
    // replaces the non-portable glibc-internal __uint32_t typedef.
    result.__x = static_cast<__nv_fp8x4_storage_t>(lo_part.__x) |
                 (static_cast<__nv_fp8x4_storage_t>(hi_part.__x) << 16);
    return result;
  }
};


// Element-wise add of two buffers of packed e4m3x4 values: C[i] = A[i] + B[i].
//
// Each thread handles one packed element (four FP8 lanes): it widens both
// inputs to half4, adds the lanes in __half precision, and narrows the sum
// back to packed FP8.
//
// Launch layout: the element index is blockIdx.x * 32 + threadIdx.x, so the
// per-block stride is hard-coded to 32, matching __launch_bounds__(32).
// There is no bounds check; the grid must cover exactly the number of packed
// elements in A, B and C.
// NOTE(review): presumably launched with blockDim.x == 32 -- confirm at the
// launch site.
extern "C" __global__ void __launch_bounds__(32) add_kernel(fp8_e4_4_t* __restrict__ A, fp8_e4_4_t* __restrict__ B, fp8_e4_4_t* __restrict__ C) {
  // Index of the packed element owned by this thread.
  const int idx = (((int)blockIdx.x) * 32) + ((int)threadIdx.x);

  // Widen the packed FP8 inputs to four __half lanes each.
  const half4 lhs = static_cast<half4>(A[idx]);
  const half4 rhs = static_cast<half4>(B[idx]);

  // Lane-wise sum in __half precision.
  half4 sum;
  sum.x = (lhs.x + rhs.x);
  sum.y = (lhs.y + rhs.y);
  sum.z = (lhs.z + rhs.z);
  sum.w = (lhs.w + rhs.w);

  // Narrow back to packed e4m3x4 and store.
  C[idx] = static_cast<fp8_e4_4_t>(sum);
}

csullivan avatar Feb 08 '24 20:02 csullivan

Please check in on the CI issues; we likely need a `requires_cuda` tag?

tqchen avatar Feb 15 '24 13:02 tqchen

@tvm-bot rerun

tqchen avatar Mar 11 '24 17:03 tqchen