
SetTransformerAggregation raises errors in eval mode on CUDA (but not CPU)

Open delkind-dnsf opened this issue 8 months ago • 2 comments

🐛 Describe the bug

This is a weird one. I've made a simple modification of the GraphSAGE PPI example that replaces the aggregation method with SetTransformerAggregation. Only two changes to the example script are needed.

The first change is the model definition:

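from torch import nn
from torch_geometric.nn import GraphSAGE
from torch_geometric.nn.aggr import SetTransformerAggregation

# `device` comes from the unmodified part of the example script.
num_node_features = 50
model = GraphSAGE(
    num_node_features,
    hidden_channels=num_node_features,
    out_channels=5,
    num_layers=2,
    act=nn.GELU(),
    # Replaces the default aggregation of the example.
    aggr=SetTransformerAggregation(
        channels=num_node_features,
        num_encoder_blocks=2,
        num_decoder_blocks=2,
        heads=2,
    ),
    norm=None,
).to(device)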

The second change adds sorting to the encode function (data.sort(sort_by_row=False)); without it, SetTransformerAggregation complains.

@torch.no_grad()
def encode(loader):
    model.eval()

    xs, ys = [], []
    for data in loader:
        data = data.sort(sort_by_row=False)
        data.to(device)
        new_x = model(data.x, data.edge_index)
        xs.append(new_x.cpu())
        ys.append(data.y.cpu())
    return torch.cat(xs, dim=0), torch.cat(ys, dim=0)
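
To make the sorting requirement concrete: sort_by_row=False orders the edges by the second row of edge_index (the destination node), so the messages arriving at each node sit next to each other. The sketch below shows the same idea in plain Python on a made-up edge list. The helper names are hypothetical and this is not PyG's implementation.

```python
# Pure-Python illustration; the edge list is made up for the example.
def sort_edges_by_destination(edges):
    """Return edges ordered by destination node, keeping ties in input order."""
    return sorted(edges, key=lambda edge: edge[1])  # sorted() is stable


def is_sorted_by_destination(edges):
    """Check that destination nodes appear in non-decreasing order."""
    dsts = [dst for _, dst in edges]
    return all(a <= b for a, b in zip(dsts, dsts[1:]))


edges = [(0, 2), (1, 0), (2, 1), (3, 0), (1, 2)]  # (source, destination) pairs
sorted_edges = sort_edges_by_destination(edges)
```

After sorting, all edges pointing at node 0 come first, then those pointing at node 1, and so on, which is the layout an aggregation that groups messages per destination node needs.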

The script fails at new_x = model(data.x, data.edge_index) on CUDA with the following error message:

Traceback (most recent call last):
  File "/home/ec2-user/graph/scripts/toy_graphsage_set_xform.py", line 140, in <module>
    train_f1, val_f1, test_f1 = test()
                                ^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/scripts/toy_graphsage_set_xform.py", line 117, in test
    x, y = encode(train_loader)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/scripts/toy_graphsage_set_xform.py", line 108, in encode
    new_x = model(data.x, data.edge_index)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/models/basic_gnn.py", line 256, in forward
    x = conv(x, edge_index)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/conv/sage_conv.py", line 134, in forward
    out = self.propagate(edge_index, x=x, size=size)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torch_geometric.nn.conv.sage_conv_SAGEConv_propagate_kx3tf7d5.py", line 200, in propagate
    out = self.aggregate(
          ^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/conv/message_passing.py", line 594, in aggregate
    return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/experimental.py", line 117, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/aggr/base.py", line 135, in __call__
    if index.numel() > 0 and dim_size <= int(index.max()):
                                             ^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This error does not occur if we use CPU instead of CUDA. It does not occur when training the model, only when evaluating the model.

Versions

Collecting environment information...
PyTorch version: 2.5.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Amazon Linux 2023.7.20250331 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-5)
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.34

Python version: 3.12.6 (main, Apr 6 2025, 12:21:39) [GCC 11.5.0 20240719 (Red Hat 11.5.0-5)] (64-bit runtime)
Python platform: Linux-6.1.131-143.221.amzn2023.x86_64-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.6.85
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L40S
Nvidia driver version: 570.86.15
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R13 Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 1
BogoMIPS: 5299.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext invpcid_single ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 16 MiB (32 instances)
L3 cache: 128 MiB (4 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==2.2.3
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.0+cu124
[pip3] torch-geometric==2.6.1
[pip3] torch_scatter==2.1.2+pt25cu124
[pip3] torch_sparse==0.6.18+pt25cu124
[pip3] torchaudio==2.5.0+cu124
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.20.0+cu124
[pip3] triton==3.1.0
[conda] Could not collect

delkind-dnsf, Apr 16 '25 14:04

Hi @delkind-dnsf, would you mind running it with CUDA_LAUNCH_BLOCKING=1 python ... and sharing the error message here?

akihironitta, Apr 18 '25 08:04
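
For anyone who cannot change the command line (for example in a notebook or a job launcher), the same variable can be set from inside the script. This is a minimal sketch, assuming it runs before the process makes its first CUDA call; the helper name is hypothetical.

```python
import os


def enable_blocking_cuda_launches() -> None:
    """Ask CUDA to run kernels synchronously so errors surface at the failing call."""
    # Assumption: nothing has touched CUDA yet in this process.
    # Setting the variable before `import torch` is the safest place.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


enable_blocking_cuda_launches()
# import torch  # import only after the variable is set
```

With blocking launches the traceback should point at the operation that actually failed rather than at a later, unrelated call, at the cost of slower execution.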

@akihironitta This is the error message that I receive when running with the command you suggested.

(venv) [ec2-user@ip-10-0-0-119 graph]$ CUDA_LAUNCH_BLOCKING=1 python scripts/toy_graphsage_set_xform.py 

Epoch: 01, Loss: 0.6960
Traceback (most recent call last):
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/aggr/base.py", line 131, in __call__
    return super().__call__(x, index=index, ptr=ptr, dim_size=dim_size,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/experimental.py", line 117, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/aggr/set_transformer.py", line 101, in forward
    x = encoder(x, mask)
        ^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/aggr/utils.py", line 114, in forward
    return self.mab(x, x, mask, mask)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/aggr/utils.py", line 62, in forward
    out, _ = self.attn(x, y, y, y_mask, need_weights=False)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/activation.py", line 1308, in forward
    return torch._native_multi_head_attention(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ec2-user/graph/scripts/toy_graphsage_set_xform.py", line 140, in <module>
    train_f1, val_f1, test_f1 = test()
                                ^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/scripts/toy_graphsage_set_xform.py", line 117, in test
    x, y = encode(train_loader)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/scripts/toy_graphsage_set_xform.py", line 108, in encode
    new_x = model(data.x, data.edge_index)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/models/basic_gnn.py", line 256, in forward
    x = conv(x, edge_index)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/conv/sage_conv.py", line 134, in forward
    out = self.propagate(edge_index, x=x, size=size)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torch_geometric.nn.conv.sage_conv_SAGEConv_propagate_tuaxkzt6.py", line 200, in propagate
    out = self.aggregate(
          ^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/conv/message_passing.py", line 594, in aggregate
    return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/experimental.py", line 117, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/graph/venv/lib/python3.12/site-packages/torch_geometric/nn/aggr/base.py", line 135, in __call__
    if index.numel() > 0 and dim_size <= int(index.max()):
                                             ^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

delkind-dnsf, Apr 21 '25 13:04
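
The first traceback above ends in torch._native_multi_head_attention, which nn.MultiheadAttention only reaches through its fused inference path. As far as I can tell that path is skipped while the module is in training mode, which would be consistent with the failure showing up only in eval mode. As a diagnostic, not a confirmed fix, one could switch the fast path off and rerun the script. The sketch assumes a PyTorch version that provides torch.backends.mha.set_fastpath_enabled; the helper name is hypothetical.

```python
import importlib


def disable_mha_fastpath() -> bool:
    """Try to turn off the fused attention inference path; report whether it worked."""
    try:
        mha = importlib.import_module("torch.backends.mha")
    except ImportError:
        return False  # torch is missing or too old to ship this module
    setter = getattr(mha, "set_fastpath_enabled", None)
    if setter is None:
        return False  # module exists but lacks the switch
    setter(False)
    return True


# Call once near the top of the script, before evaluating the model.
fastpath_disabled = disable_mha_fastpath()
```

If the error disappears with the fast path disabled, that would narrow the problem down to the fused attention kernel and the mask it receives; if it persists, the cause lies elsewhere.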