pytorch_scatter
scatter or scatter_min fails when using torch.compile
Hello,
I can't compile any model that includes scatter or scatter_min from torch_scatter. For example, with this minimal script:

```python
import torch
import torch_geometric
from torch_scatter import scatter_min
print("the version of torch", torch.__version__)
print("torch_geometric version", torch_geometric.__version__)
def get_x(n_points=100):
import torch
x_min = [0, 10]
y_min = [0, 10]
z_min = [0, 10]
x = torch.rand((n_points, 3))
x[:, 0] = x[:, 0] * (x_min[1] - x_min[0]) + x_min[0]
x[:, 1] = x[:, 1] * (y_min[1] - y_min[0]) + y_min[0]
x[:, 2] = x[:, 2] * (z_min[1] - z_min[0]) + z_min[0]
return x
device = "cuda"
x = get_x(n_points=10)
se = torch.randint(low=0, high=10, size=(10,))
model = scatter_min
compiled_model = torch.compile(model)
expected, expected_argmin = model(x, se, dim=0)
out, out_argmin = compiled_model(x, se, dim=0)
assert torch.allclose(out, expected, atol=1e-6)
```
The code fails with:

```
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_scatter.scatter_min(*(FakeTensor(..., size=(10, 3)), FakeTensor(..., size=(10,), dtype=torch.int64), 0, None, None), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

from user code:
  line 65, in scatter_min
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
```
My torch version is 2.2.0, torch_geometric is 2.5.2, and torch_scatter is 2.1.2.
This is currently expected, since the custom ops by torch-scatter are not supported in torch.compile. There exist two options:
- Disallow the use of `compile` for certain ops
- Fall back to https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_. This is also what we are doing on the PyG side.
For this, we added `utils.scatter`, which will pick the best computation path depending on your input arguments. It also works with `torch.compile`.
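For reference, here is a rough sketch of what both options could look like for the script above. The wrapper names `scatter_min_eager` and `scatter_min_fallback` are just placeholders, and the fallback only covers the 2-D, `dim=0`, floating-point case from the example; unlike `scatter_min`, it also does not return the argmin indices.

```python
import torch
from torch_scatter import scatter_min

# Option 1: keep torch_scatter, but exclude the custom op from compilation.
# torch.compile graph-breaks around the wrapped function and runs it eagerly.
scatter_min_eager = torch.compiler.disable(scatter_min)

# Option 2: fall back to the built-in scatter_reduce_, which torch.compile
# can trace. Rows of `out` that receive no element keep the initial fill value.
def scatter_min_fallback(src, index, dim_size=None):
    if dim_size is None:
        dim_size = int(index.max()) + 1
    out = src.new_full((dim_size, src.size(1)), float("inf"))
    idx = index.unsqueeze(-1).expand_as(src)  # broadcast index to src's shape
    return out.scatter_reduce_(0, idx, src, reduce="amin", include_self=False)
```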
If I understand correctly, you suggest that, instead of calling torch_scatter functions such as scatter_min or scatter_max directly, we should use utils.scatter by default?
Yes, if you want torch.compile support, then this is the recommended way.
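For completeness, a sketch of the original example rewritten with PyG's `utils.scatter` (assuming a torch_geometric version that supports `reduce='min'`, such as the 2.5.2 used above; it returns only the reduced values, without the argmin indices that `scatter_min` also returns):

```python
import torch
from torch_geometric.utils import scatter

x = torch.rand((10, 3))
index = torch.randint(low=0, high=10, size=(10,))

# utils.scatter is traceable, so the compiled version matches eager execution.
compiled_scatter = torch.compile(scatter)

expected = scatter(x, index, dim=0, reduce='min')
out = compiled_scatter(x, index, dim=0, reduce='min')
assert torch.allclose(out, expected, atol=1e-6)
```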