
scatter or scatter_min fails when using torch.compile

Open gardiens opened this issue 1 year ago • 6 comments

Hello,

I can't compile any model that uses scatter or scatter_min from torch_scatter. For example, this minimal script:

import torch
import torch_geometric
from torch_scatter import scatter_min

print("the version of torch", torch.__version__)
print("torch_geometric version", torch_geometric.__version__)


def get_x(n_points=100):
    x_min = [0, 10]
    y_min = [0, 10]
    z_min = [0, 10]

    x = torch.rand((n_points, 3))
    x[:, 0] = x[:, 0] * (x_min[1] - x_min[0]) + x_min[0]
    x[:, 1] = x[:, 1] * (y_min[1] - y_min[0]) + y_min[0]
    x[:, 2] = x[:, 2] * (z_min[1] - z_min[0]) + z_min[0]

    return x


device = "cuda"
x = get_x(n_points=10)
se = torch.randint(low=0, high=10, size=(10,))

model = scatter_min
compiled_model = torch.compile(model)

expected, expected_argmin = model(x, se, dim=0)
out, argmin = compiled_model(x, se, dim=0)
assert torch.allclose(out, expected, atol=1e-6)

The code fails with:

 torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_scatter.scatter_min(*(FakeTensor(..., size=(10, 3)), FakeTensor(..., size=(10,), dtype=torch.int64), 0, None, None), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
 line 65, in scatter_min
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)

My torch version is 2.2.0, torch_geometric is 2.5.2, and torch_scatter is 2.1.2.

gardiens avatar May 02 '24 09:05 gardiens

This is currently expected: the custom ops shipped by torch-scatter don't provide the meta/fake implementations that torch.compile needs for tracing (hence the FakeTensor error above). There exist two options:

  • Disallow the use of compile for certain ops
  • Fall back to https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_. This is also what we are doing on the PyG side (see the sketch after this list).
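
A minimal sketch of that fallback for the min case, assuming a 2-D src and a 1-D index (scatter_min_fallback is a hypothetical name, not an actual torch-scatter or PyG function):

import torch

def scatter_min_fallback(src, index, dim_size):
    # Min-reduce rows of a 2-D `src` into `dim_size` buckets along dim 0,
    # using only core ATen ops so torch.compile can trace it.
    idx = index.view(-1, 1).expand_as(src)  # scatter_reduce_ wants index shaped like src
    out = src.new_zeros(dim_size, src.size(1))
    # include_self=False makes this a pure min over the scattered values;
    # note there is no argmin output, unlike torch_scatter.scatter_min.
    return out.scatter_reduce_(0, idx, src, reduce="amin", include_self=False)

x = torch.rand(10, 3)
se = torch.randint(0, 10, (10,))
out = torch.compile(scatter_min_fallback)(x, se, dim_size=10)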

rusty1s avatar May 07 '24 13:05 rusty1s

For this, we added torch_geometric.utils.scatter, which will pick the best computation path depending on your input arguments. It also works with torch.compile.
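
For example, applied to the reproduction script above (a sketch; scatter here is torch_geometric.utils.scatter):

import torch
from torch_geometric.utils import scatter

x = torch.rand(10, 3)
se = torch.randint(low=0, high=10, size=(10,))

compiled = torch.compile(scatter)
out = compiled(x, se, dim=0, reduce='min')

Note that unlike torch_scatter.scatter_min, utils.scatter returns only the reduced tensor, not the argmin indices.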

rusty1s avatar May 07 '24 13:05 rusty1s

> For this, we added torch_geometric.utils.scatter, which will pick the best computation path depending on your input arguments. It also works with torch.compile.

If I understand correctly, you suggest that instead of calling scatter_min, scatter_max, or other torch_scatter ops directly, we should use utils.scatter by default?

gardiens avatar May 10 '24 09:05 gardiens

Yes, if you want torch.compile support, then this is the recommended way.

rusty1s avatar May 10 '24 13:05 rusty1s