pytorch_scatter
scatter_add is slow with fp16 inputs when the index tensor is concentrated
Potential Bug
When given fp16 inputs and a concentrated index tensor (i.e., one with many repeated indices), scatter_add becomes very slow. The same call with fp32 inputs is not affected.

Below is a minimal code sample that reproduces the problem by plotting the slowdown as the index tensor becomes more concentrated:
import time

import matplotlib.pyplot as plt
import torch
from torch_scatter import scatter_add

batch_size = 100
tgt_dim = 1000000
src_dim = 300000
max_spread_pow = 10
n_runs = 10

fp32_results = []
fp16_results = []

for spread_pow in range(1, max_spread_pow):
    for precision in ["fp32", "fp16"]:
        dtype = torch.float32 if precision == "fp32" else torch.float16
        tgt = torch.zeros(batch_size, tgt_dim).to(dtype).cuda()
        src = torch.rand(batch_size, src_dim).to(dtype).cuda()

        # The smaller the spread, the more the indices repeat around tgt_dim / 2.
        spread = 0.5 ** spread_pow
        index = torch.randint(
            int(tgt_dim / 2 - tgt_dim * spread),
            int(tgt_dim / 2 + tgt_dim * spread),
            (batch_size, src_dim)
        ).cuda()

        t_mean = 0
        for _ in range(n_runs):
            torch.cuda.synchronize()
            t0 = time.time()
            scatter_add(src, index, 1, tgt)
            torch.cuda.synchronize()
            t1 = time.time()
            t_mean += t1 - t0
        t_mean = t_mean / n_runs

        if precision == "fp32":
            fp32_results.append(t_mean)
        else:
            fp16_results.append(t_mean)

fig = plt.figure()
plt.plot(range(1, max_spread_pow), fp32_results, label="fp32")
plt.plot(range(1, max_spread_pow), fp16_results, label="fp16")
plt.xlabel("spread (fraction of tgt_dim)")
plt.xticks(range(1, max_spread_pow),
           [f"{2 * 0.5 ** spread_pow:.5f}" for spread_pow in range(1, max_spread_pow)],
           rotation="vertical")
plt.ylabel("time (sec)")
plt.legend()
fig.subplots_adjust(bottom=0.2)
plt.savefig("scatter_add_issue.png")
plt.show()
Versions
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 6.5.0-2ubuntu1~16.04) 6.5.0 20181026
Clang version: 3.9.1-4ubuntu3~16.04.2 (tags/RELEASE_391/rc2)
CMake version: Could not collect
Libc version: glibc-2.17
Python version: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.15.0-142-generic-x86_64-with-debian-stretch-sid
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: TITAN V
GPU 1: GeForce GTX 1080 Ti
Nvidia driver version: 440.33.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.21.5
[pip3] numpy-quaternion==2022.2.10.14.20.39
[pip3] torch==1.9.0
[pip3] torch-scatter==2.0.8
[pip3] torchvision==0.10.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libblas 3.9.0 13_linux64_mkl conda-forge
[conda] libcblas 3.9.0 13_linux64_mkl conda-forge
[conda] liblapack 3.9.0 13_linux64_mkl conda-forge
[conda] mkl 2022.0.1 h06a4308_117
[conda] msgpack-numpy 0.4.7.1 pypi_0 pypi
[conda] numpy 1.21.5 py37hf2998dd_0 conda-forge
[conda] pytorch 1.9.0 py3.7_cuda10.2_cudnn7.6.5_0 pytorch
[conda] torch-scatter 2.0.8 pypi_0 pypi
[conda] torchvision 0.10.0 py37_cu102 pytorch
Thanks for reporting. Interesting finding! For scatter_add, we make use of PyTorch's internal torch.scatter_add_ functionality, so there is not much we can do to improve performance on our end other than reporting this issue to the PyTorch team as well.
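If the slowdown really originates in that built-in kernel, it should be reproducible without torch-scatter by calling torch.Tensor.scatter_add_ directly. A minimal sketch, reusing the shapes from the script above (no particular timings are claimed here):

import time

import torch

batch_size, tgt_dim, src_dim = 100, 1000000, 300000

for dtype in (torch.float32, torch.float16):
    tgt = torch.zeros(batch_size, tgt_dim, dtype=dtype, device="cuda")
    src = torch.rand(batch_size, src_dim, dtype=dtype, device="cuda")
    # Concentrated indices: every value falls into a narrow band of the target.
    index = torch.randint(tgt_dim // 2 - 1000, tgt_dim // 2 + 1000,
                          (batch_size, src_dim), device="cuda")

    torch.cuda.synchronize()
    t0 = time.time()
    tgt.scatter_add_(1, index, src)  # the built-in op scatter_add dispatches to
    torch.cuda.synchronize()
    print(dtype, time.time() - t0)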
Does the same also hold true for other reductions, e.g., via max or min?
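A rough way to check, reusing the shapes from the reproduction script; note that scatter_max and scatter_min return a (values, argindex) tuple, and whether they show the same fp16 behaviour is exactly the open question here, not a claim:

import time

import torch
from torch_scatter import scatter_add, scatter_max, scatter_min

batch_size, tgt_dim, src_dim = 100, 1000000, 300000
# Concentrated indices, as in the original report.
index = torch.randint(tgt_dim // 2 - 1000, tgt_dim // 2 + 1000,
                      (batch_size, src_dim), device="cuda")

def timed(fn, n_runs=10):
    # Average wall-clock time over n_runs synchronized calls.
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n_runs):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / n_runs

for dtype in (torch.float32, torch.float16):
    src = torch.rand(batch_size, src_dim, dtype=dtype, device="cuda")
    print(dtype, "add:", timed(lambda: scatter_add(src, index, dim=1, dim_size=tgt_dim)))
    print(dtype, "max:", timed(lambda: scatter_max(src, index, dim=1, dim_size=tgt_dim)))
    print(dtype, "min:", timed(lambda: scatter_min(src, index, dim=1, dim_size=tgt_dim)))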
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?