
scatter_add is slow with fp16 inputs when the index tensor is concentrated

Open · theophilegervet opened this issue 3 years ago · 2 comments

Potential Bug

When given fp16 inputs and a concentrated index tensor (i.e., one with many repeated indices), scatter_add becomes very slow. The same slowdown does not occur with fp32 inputs.

[Figure: scatter_add_issue.png, scatter_add runtime vs. index spread for fp32 and fp16]

Below is a minimal code sample that reproduces the problem by plotting the slowdown as the index tensor becomes more concentrated:

import torch
import time
import matplotlib.pyplot as plt
from torch_scatter import scatter_add

batch_size = 100
tgt_dim = 1000000
src_dim = 300000
max_spread_pow = 10
fp32_results = []
fp16_results = []

for spread_pow in range(1, max_spread_pow):
    for precision in ["fp32", "fp16"]:
        dtype = torch.float32 if precision == "fp32" else torch.float16
        tgt = torch.zeros(batch_size, tgt_dim).to(dtype).cuda()
        src = torch.rand(batch_size, src_dim).to(dtype).cuda()
        # Smaller spread -> indices drawn from a narrower window around
        # tgt_dim / 2, i.e. a more concentrated index tensor with more repeats.
        spread = 0.5 ** spread_pow
        index = torch.randint(
            int(tgt_dim / 2 - tgt_dim * spread),
            int(tgt_dim / 2 + tgt_dim * spread),
            (batch_size, src_dim)
        ).cuda()

        # Time 10 synchronized runs and record the mean.
        t_mean = 0
        for _ in range(10):
            torch.cuda.synchronize()
            t0 = time.time()
            scatter_add(src, index, 1, tgt)
            torch.cuda.synchronize()
            t1 = time.time()
            t_mean += t1 - t0
        t_mean = t_mean / 10

        if precision == "fp32":
            fp32_results.append(t_mean)
        else:
            fp16_results.append(t_mean)

fig = plt.figure()
plt.plot(range(1, max_spread_pow), fp32_results, label="fp32")
plt.plot(range(1, max_spread_pow), fp16_results, label="fp16")
plt.xlabel("spread (fraction of tgt_dim)")
plt.xticks(range(1, max_spread_pow),
           [f"{2 * 0.5 ** spread_pow:.5f}" for spread_pow in range(1, max_spread_pow)],
           rotation="vertical")
plt.ylabel("time (sec)")
plt.legend()
fig.subplots_adjust(bottom=0.2)
plt.savefig("scatter_add_issue.png")
plt.show()

Versions

PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.6 LTS (x86_64)
GCC version: (Ubuntu 6.5.0-2ubuntu1~16.04) 6.5.0 20181026
Clang version: 3.9.1-4ubuntu3~16.04.2 (tags/RELEASE_391/rc2)
CMake version: Could not collect
Libc version: glibc-2.17

Python version: 3.7.11 (default, Jul 27 2021, 14:32:16)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.15.0-142-generic-x86_64-with-debian-stretch-sid
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: TITAN V
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 440.33.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.7.1
[pip3] numpy==1.21.5
[pip3] numpy-quaternion==2022.2.10.14.20.39
[pip3] torch==1.9.0
[pip3] torch-scatter==2.0.8
[pip3] torchvision==0.10.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libblas                   3.9.0            13_linux64_mkl    conda-forge
[conda] libcblas                  3.9.0            13_linux64_mkl    conda-forge
[conda] liblapack                 3.9.0            13_linux64_mkl    conda-forge
[conda] mkl                       2022.0.1           h06a4308_117  
[conda] msgpack-numpy             0.4.7.1                  pypi_0    pypi
[conda] numpy                     1.21.5           py37hf2998dd_0    conda-forge
[conda] pytorch                   1.9.0           py3.7_cuda10.2_cudnn7.6.5_0    pytorch
[conda] torch-scatter             2.0.8                    pypi_0    pypi
[conda] torchvision               0.10.0               py37_cu102    pytorch

theophilegervet · Mar 21 '22

Thanks for reporting, this is an interesting finding! For scatter_add, we make use of PyTorch's internal torch.scatter_add_ functionality, so there is not much we can do to improve performance on our end other than reporting this issue to the PyTorch team as well.
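For reference, here is a minimal sketch (not part of the original report; it assumes the same shapes as the reproduction script above) that checks whether the slowdown also shows up when calling torch.Tensor.scatter_add_ directly, bypassing torch_scatter:

import time
import torch

batch_size, tgt_dim, src_dim = 100, 1000000, 300000

for dtype in (torch.float32, torch.float16):
    tgt = torch.zeros(batch_size, tgt_dim, dtype=dtype, device="cuda")
    src = torch.rand(batch_size, src_dim, dtype=dtype, device="cuda")
    # Highly concentrated indices: every value falls into a narrow window.
    index = torch.randint(tgt_dim // 2 - 1000, tgt_dim // 2 + 1000,
                          (batch_size, src_dim), device="cuda")

    torch.cuda.synchronize()
    t0 = time.time()
    tgt.scatter_add_(1, index, src)  # the core PyTorch op used internally
    torch.cuda.synchronize()
    print(dtype, time.time() - t0)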

Does the same also hold true for other reductions, e.g., via max or min?
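A quick way to check this, as a sketch that reuses src, index, and tgt_dim from the reproduction script above (scatter_max and scatter_min return a tuple of values and arg-indices):

from torch_scatter import scatter_max, scatter_min

for name, op in [("scatter_max", scatter_max), ("scatter_min", scatter_min)]:
    torch.cuda.synchronize()
    t0 = time.time()
    out, arg = op(src, index, dim=1, dim_size=tgt_dim)
    torch.cuda.synchronize()
    print(name, time.time() - t0)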

rusty1s · Mar 22 '22

This issue has had no activity for 6 months. It will be closed in 2 weeks unless there is new activity. Is this issue already resolved?

github-actions[bot] · Sep 19 '22