pytorch_scatter
Poor performance with __half and __nv_bfloat16
In the atomAdd overloads, the native atomicAdd should be used for __half and __nv_bfloat16 instead of AtomicAddDecimalImpl, like this:
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000))
// No native half-precision atomicAdd available: fall back to the CAS-based implementation.
static inline __device__ void atomAdd(__half* address, __half val)
{
  AtomicAddDecimalImpl<__half, sizeof(__half)>()(address, val);
}
#else
#if (__CUDA_ARCH__ >= 800)
// Native bfloat16 atomicAdd requires sm_80 or newer.
static inline __device__ void atomAdd(__nv_bfloat16* address, __nv_bfloat16 val)
{
  atomicAdd(address, val);
}
#endif
static inline __device__ void atomAdd(__half* address, __half val)
{
  atomicAdd(address, val);
}
#endif
Do you mind sending a PR to fix this?
I sure will! I have noticed, however, that when I install pytorch_scatter, I end up calling Torch's scatter_add instead anyway. Is the current scatter code here obsolete?
In any case, by reading the PyTorch code I learned that the matter is a bit more complicated: the native atomicAdd for __nv_bfloat16 is allegedly very slow, so they end up emulating it through __nv_bfloat162 with a questionable trick, and the performance was not great either. So I am going to investigate the options further. The most attractive option would be to use __nv_bfloat162 throughout, but that would require changes to the algorithm, and I am not sure it is even possible. What do you think about that?
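For reference, here is a rough sketch of the kind of pair-based emulation I mean; this is purely my own illustration (the helper name bf16_atomic_add_via_pair is made up), not the actual PyTorch code:
#include <cuda_bf16.h>
#include <cstdint>
// Illustration only: emulate a scalar bf16 atomic add with a single
// __nv_bfloat162 atomicAdd on the aligned pair that contains the target.
static inline __device__ void bf16_atomic_add_via_pair(__nv_bfloat16* address, __nv_bfloat16 val)
{
  // Round the address down to the 4-byte-aligned __nv_bfloat162 containing it.
  __nv_bfloat162* base = reinterpret_cast<__nv_bfloat162*>(
      reinterpret_cast<uintptr_t>(address) & ~uintptr_t(0x2));
  bool is_high = (reinterpret_cast<uintptr_t>(address) & 0x2) != 0;
  // Add `val` to the targeted lane and +0 to the other lane.
  // Note: adding +0 can turn a -0 in the untouched lane into +0,
  // which is one reason the trick is questionable.
  __nv_bfloat16 zero = __float2bfloat16(0.0f);
  __nv_bfloat162 inc = is_high ? __halves2bfloat162(zero, val)
                               : __halves2bfloat162(val, zero);
  atomicAdd(base, inc); // __nv_bfloat162 atomicAdd, sm_80+
}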
Yeah, that is correct. We just use the scatter_add implementation from PyTorch. As such, the scatter_add implementation in torch-scatter is indeed kinda obsolete by now.
Thanks for the confirmation! I guess there is no point in doing a PR then.
If you have an idea how to implement the same scatter semantics with the __nv_bfloat162 type, that may be something the PyTorch folks can use! As it stands now, the best thing to do is to convert bf16 to float before the scatter and then convert back, which is way faster than trying to do it in bf16.
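To make that concrete, here is a rough kernel-level sketch of the convert-to-float approach. It is only an illustration of the idea: the kernel names and the out_fp32 scratch buffer are made up, and out_fp32 is assumed to be zero-initialized before the first kernel runs.
#include <cuda_bf16.h>
#include <cstdint>
// Illustration only: scatter-add bf16 values by accumulating in fp32,
// then rounding back to bf16 in a second pass.
__global__ void scatter_add_bf16_via_fp32(const __nv_bfloat16* src,
                                          const int64_t* index,
                                          float* out_fp32, int64_t n)
{
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n)
    atomicAdd(out_fp32 + index[i], __bfloat162float(src[i])); // fast fp32 atomics
}
__global__ void fp32_to_bf16(const float* in, __nv_bfloat16* out, int64_t m)
{
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < m)
    out[i] = __float2bfloat16(in[i]);
}
The extra fp32 buffer costs memory and a second pass, but it avoids half-precision atomics entirely, which is why it ends up faster in my experience.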
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?