Manually perform 8-bit dot product (__dp4a)
Hi. The candle-kernels crate throws an error on the Maxwell architecture when migrating code that works fine in PyTorch. The only problem seems to be '__dp4a', which performs four 8-bit integer dot-product operations and is used for quantization, but is only available on compute capability 6.1 and above. It would be great if the kernel crate had a fallback that computes the dot product for '__CUDA_ARCH__ < 610'.
#if __CUDA_ARCH__ < 610
// Manually perform the 8-bit dot product: treat each 32-bit operand as four
// packed signed bytes, multiply them pairwise, and accumulate into c.
__device__ int manual_dp4a(int a, int b, int c) {
    int result = c;
    for (int i = 0; i < 4; ++i) {
        int8_t a_byte = (a >> (i * 8)) & 0xFF;
        int8_t b_byte = (b >> (i * 8)) & 0xFF;
        result += a_byte * b_byte;
    }
    return result;
}
#endif
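For reference, here is a minimal sanity-check kernel for the fallback (a sketch with a hypothetical name, 'manual_dp4a_check'; it assumes the manual_dp4a definition above is in scope and that the build targets an architecture below 6.1 so the fallback is compiled). On sm_61+ hardware, __dp4a(a, b, 10) should return the same value.

__global__ void manual_dp4a_check(int* out) {
    // 0x04030201 packs the signed bytes (1, 2, 3, 4) from least to most
    // significant; 0x08070605 packs (5, 6, 7, 8).
    int a = 0x04030201;
    int b = 0x08070605;
    // Expected: 1*5 + 2*6 + 3*7 + 4*8 + 10 = 80
    *out = manual_dp4a(a, b, 10);
}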
The existing call sites would then dispatch on the architecture, selecting between the intrinsic and the fallback:

#if __CUDA_ARCH__ >= 610
    ...
    sumi = __dp4a(vi0, u[2*i+0], sumi);
    ...
#else
    ...
    sumi = manual_dp4a(vi0, u[2*i+0], sumi);
    ...
#endif