pytorch_scatter

Poor performance with __half and __nv_bfloat16

Open borisfom opened this issue 1 year ago • 5 comments

In atomicAdd overloads, native atomicAdd should be used for __half and __nv_bfloat16, instead of AtomicAddDecimalImpl. Like this:

#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000))
// No native atomicAdd(__half*) here: fall back to the generic implementation.
static inline __device__ void atomAdd(__half* address, __half val)
{
    AtomicAddDecimalImpl<__half, sizeof(__half)>()(address, val);
}
#else
#if (__CUDA_ARCH__ >= 800)
// Native atomicAdd(__nv_bfloat16*) requires compute capability 8.0 or higher.
static inline __device__ void atomAdd(__nv_bfloat16* address, __nv_bfloat16 val)
{
    atomicAdd(address, val);
}
#endif
// Native atomicAdd(__half*) requires compute capability 7.0 or higher.
static inline __device__ void atomAdd(__half* address, __half val)
{
    atomicAdd(address, val);
}
#endif

borisfom avatar Jan 30 '24 18:01 borisfom

Do you mind sending a PR to fix?

rusty1s avatar Jan 31 '24 18:01 rusty1s

I sure will! I have noticed, however, that when I install pytorch_scatter, I end up calling Torch's scatter_add anyway. Is the current scatter code here obsolete? In any case, from reading the PyTorch code I learned that the matter is a bit more complicated: native atomicAdd(__nv_bfloat16) is allegedly very slow, so they end up using __nv_bfloat162 for it with a questionable trick, and perf was not great either. So I am going to investigate the options further. The most attractive option is to use __nv_bfloat162 throughout, but that would require changes to the algorithm, and I am not sure it is even possible. What do you think about that?

borisfom avatar Jan 31 '24 21:01 borisfom

Yeah, that is correct. We just use the scatter_add implementation from PyTorch. As such, the scatter_add implementation in torch-scatter is indeed kinda obsolete by now.

rusty1s avatar Feb 01 '24 07:02 rusty1s

Thanks for the confirmation! I guess there is no point in doing a PR then.

borisfom avatar Feb 01 '24 08:02 borisfom

If you have an idea how to implement the same scatter semantics with the __nv_bfloat162 type, that may be something the PyTorch folks can use! As it stands now, the best thing to do is to convert bf16 to float before the scatter and then convert back, which is way faster than trying to do it in bf16.
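
For illustration only, the float round trip described above can be sketched as a host-side C++ reference. This is a minimal sketch of the data flow, not the torch-scatter or PyTorch implementation: the names bf16_bits, bf16_to_float, float_to_bf16 and scatter_add_via_float are hypothetical, bf16 values are held as raw 16-bit patterns, and the scatter is assumed to be one-dimensional.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical storage type: a bfloat16 kept as its raw bits,
// i.e. the upper 16 bits of an IEEE-754 float32.
using bf16_bits = std::uint16_t;

// Widen bf16 to float. This is exact: the bits move into the high half.
inline float bf16_to_float(bf16_bits h) {
    std::uint32_t u = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Narrow float to bf16 with round-to-nearest-even.
// NaN handling is omitted to keep the sketch short.
inline bf16_bits float_to_bf16(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    u += 0x7FFFu + ((u >> 16) & 1u);
    return static_cast<bf16_bits>(u >> 16);
}

// Hypothetical helper: 1-D scatter-add, out[index[i]] += src[i].
// All accumulation happens in a float buffer; bf16 appears only
// at the input and at the final conversion.
std::vector<bf16_bits> scatter_add_via_float(const std::vector<bf16_bits>& src,
                                             const std::vector<std::size_t>& index,
                                             std::size_t dim_size) {
    std::vector<float> acc(dim_size, 0.0f);
    for (std::size_t i = 0; i < src.size(); ++i) {
        acc[index[i]] += bf16_to_float(src[i]);  // convert, then accumulate in float
    }
    std::vector<bf16_bits> out(dim_size);
    for (std::size_t j = 0; j < dim_size; ++j) {
        out[j] = float_to_bf16(acc[j]);          // convert back once per output slot
    }
    return out;
}
```

On the GPU, the accumulation loop would presumably map to atomicAdd on a float buffer, so the bf16 atomics are avoided entirely and only the two conversions touch bf16. In PyTorch terms this corresponds to calling .float() on the source before scatter_add and .to(torch.bfloat16) on the result.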

borisfom avatar Feb 01 '24 08:02 borisfom

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?

github-actions[bot] avatar Jul 31 '24 01:07 github-actions[bot]