pytorch_scatter
Warning: graph break when compiling a model with scatter_add
Hi everyone,
I'm new to torch.compile, and I get the following warning:
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] Graph break from `Tensor.item()`, consider setting:
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] torch._dynamo.config.capture_scalar_outputs = True
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] or:
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] to include these operations in the captured graph.
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0]
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] Graph break: from user code at:
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] File "/workspaces/molcrafts/molpot/src/molpot/potential/base.py", line 32, in forward
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] inputs = module(inputs)
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] File "/workspaces/molcrafts/molpot/src/molpot/potential/nnp/pinet.py", line 343, in forward
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] inputs = self.gc_blocks[i](inputs)
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] File "/workspaces/molcrafts/molpot/src/molpot/potential/nnp/pinet.py", line 207, in forward
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] p1, i1 = self.p1_layer(pair_i, pair_j, p1, basis)
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] File "/workspaces/molcrafts/molpot/src/molpot/potential/nnp/pinet.py", line 141, in forward
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] p1 = self.ip_layer(idx_i, i1)
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] File "/workspaces/molcrafts/molpot/src/molpot/potential/nnp/pinet.py", line 118, in forward
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] return scatter_add(inter, idx_i, dim=0)
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] File "/opt/conda/lib/python3.11/site-packages/torch_scatter/scatter.py", line 29, in scatter_add
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] return scatter_sum(src, index, dim, out, dim_size)
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] File "/opt/conda/lib/python3.11/site-packages/torch_scatter/scatter.py", line 19, in scatter_sum
W0811 21:16:49.580000 11928 torch/_dynamo/variables/tensor.py:776] [0/0] size[dim] = int(index.max()) + 1
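For context, the last frame of the log shows where the break comes from. When no output size is given, scatter_sum computes it from the data with `int(index.max()) + 1`. The output shape then depends on the values inside `index`, so Dynamo has to call `Tensor.item()` and break the graph. The same frame shows that scatter_sum also takes a `dim_size` argument. The sketch below is a pure-Python illustration of that behaviour, not torch_scatter's implementation. `scatter_sum_1d` is a hypothetical helper, and it assumes 1-D inputs. It shows that when `dim_size` is supplied, the output size no longer depends on the index values.

```python
from typing import List, Optional


def scatter_sum_1d(src: List[float], index: List[int],
                   dim_size: Optional[int] = None) -> List[float]:
    """Hypothetical 1-D sketch of scatter-sum semantics: out[index[k]] += src[k]."""
    if dim_size is None:
        # Output length is derived from the *values* in index, mirroring
        # `size[dim] = int(index.max()) + 1`. In a compiled PyTorch graph this
        # is the data-dependent Tensor.item() call that triggers the break.
        dim_size = max(index) + 1 if index else 0
    # With dim_size given, the output size is known without inspecting index.
    out = [0.0] * dim_size
    for value, target in zip(src, index):
        out[target] += value
    return out


print(scatter_sum_1d([1.0, 2.0, 3.0, 4.0], [0, 1, 0, 2]))  # size inferred: 3
print(scatter_sum_1d([1.0, 2.0], [0, 0], dim_size=3))       # size given explicitly
```

In the model itself, this would mean passing a known size, for example the number of atoms, as `scatter_add(inter, idx_i, dim=0, dim_size=...)`. This is an assumption about the model's data layout, and it only helps if that size is available without reading tensor values. It may not remove every graph break in the model.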
Does this warning matter, and if so, how can I fix it?