
`CUDA error: unspecified launch failure`, similar to #3802

Open samvanstroud opened this issue 3 years ago • 7 comments

🐛 Bug

I am seeing the same issue that was reported as fixed in #3841 in the latest release, 0.9.0 (and in every release down to below 0.8.0). See #3802 for more context. As previously reported by @wsjeon, I am seeing this issue when using DGL with PyTorch Lightning, though I haven't tried to reproduce the problem without that package.

Tagging @BarclayII and @nv-dlasalle who previously investigated this.

To Reproduce

Steps to reproduce the behavior:

  1. Set up the environment:
conda config --env --add channels dglteam 
conda config --env --add channels pytorch
conda install dgl-cuda11.3 pytorch-lightning cudatoolkit=11.3 pytorch=1.12.1
  2. Run this:
import torch
import dgl
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

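    def training_step(self, batch, batch_nb):
        # Dummy loss that requires grad so backward() succeeds; the reported
        # crash happens earlier, when the batch is moved to the GPU.
        return self.layer(torch.zeros(10, device=self.device)).sum()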

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        g = dgl.graph(data=([0,1],[1,0]), num_nodes=2)
        return g, torch.tensor([0])

def collate_graphs(samples):
    graphs = [x[0] for x in samples]
    batched_graph = dgl.batch(graphs)
    targets = torch.cat([x[1] for x in samples])
    return batched_graph, targets

loader = torch.utils.data.DataLoader(dataset=MyDataset(), batch_size=2, num_workers=2, collate_fn=collate_graphs)
model = MyModel()

trainer = pl.Trainer(
    strategy='ddp',
    accelerator='gpu',
    devices=[0],
    fast_dev_run=True,
)

trainer.fit(model, loader)

Stack trace:

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1353, in _run_train
    self.fit_loop.run()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 266, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 171, in advance
    batch = next(data_fetcher)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
    return self.fetching_function()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 269, in fetching_function
    return self.move_to_device(batch)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 284, in move_to_device
    batch = self.batch_to_device(batch)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1765, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 230, in batch_to_device
    return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/core/lightning.py", line 291, in _apply_batch_transfer_handler
    batch = hook(batch, device, dataloader_idx)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/core/hooks.py", line 713, in transfer_batch_to_device
    return move_data_to_device(batch, device)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/utilities/apply_func.py", line 354, in move_data_to_device
    return apply_to_collection(batch, dtype=dtype, function=batch_to)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/utilities/apply_func.py", line 121, in apply_to_collection
    v = apply_to_collection(
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/utilities/apply_func.py", line 99, in apply_to_collection
    return function(data, *args, **kwargs)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/utilities/apply_func.py", line 347, in batch_to
    data_output = data.to(device, **kwargs)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/heterograph.py", line 5448, in to
    ret._graph = self._graph.copy_to(utils.to_dgl_context(device))
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/heterograph_index.py", line 236, in copy_to
    return _CAPI_DGLHeteroCopyTo(self, ctx.device_type, ctx.device_id)
  File "dgl/_ffi/_cython/./function.pxi", line 293, in dgl._ffi._cy3.core.FunctionBase.__call__
  File "dgl/_ffi/_cython/./function.pxi", line 225, in dgl._ffi._cy3.core.FuncCall
  File "dgl/_ffi/_cython/./function.pxi", line 215, in dgl._ffi._cy3.core.FuncCall3
dgl._ffi.base.DGLError: [12:47:27] /opt/dgl/src/runtime/cuda/cuda_device_api.cc:114: Check failed: e == cudaSuccess || e == cudaErrorCudartUnloading: CUDA: unspecified launch failure
Stack trace:
  [bt] (0) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x4f) [0x7f51d135fd6f]
  [bt] (1) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(dgl::runtime::CUDADeviceAPI::AllocDataSpace(DLContext, unsigned long, unsigned long, DLDataType)+0x108) [0x7f51d183a4a8]
  [bt] (2) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::Empty(std::vector<long, std::allocator<long> >, DLDataType, DLContext)+0x361) [0x7f51d16ac5d1]
  [bt] (3) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::CopyTo(DLContext const&, void* const&) const+0xc7) [0x7f51d16e8bb7]
  [bt] (4) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(dgl::UnitGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DLContext const&, void* const&)+0x317) [0x7f51d17f9db7]
  [bt] (5) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(dgl::HeteroGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DLContext const&, void* const&)+0x109) [0x7f51d16fa939]
  [bt] (6) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(+0x73b9c9) [0x7f51d17079c9]
  [bt] (7) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/libdgl.so(DGLFuncCall+0x48) [0x7f51d168a928]
  [bt] (8) miniconda3/envs/dgl-test/lib/python3.10/site-packages/dgl/_ffi/_cy3/core.cpython-310-x86_64-linux-gnu.so(+0x16143) [0x7f51f4995143]



During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "gnn-tagger/GNNJetTagger/gnn_tagger/training/minimal.py", line 43, in <module>
    trainer.fit(model, loader)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 738, in _call_and_handle_interrupt
    self._teardown()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1300, in _teardown
    self.strategy.teardown()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 482, in teardown
    self.lightning_module.cpu()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/pytorch_lightning/core/mixins/device_dtype_mixin.py", line 147, in cpu
    return super().cpu()
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 738, in cpu
    return self._apply(lambda t: t.cpu())
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 579, in _apply
    module._apply(fn)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 602, in _apply
    param_applied = fn(param)
  File "miniconda3/envs/dgl-test/lib/python3.10/site-packages/torch/nn/modules/module.py", line 738, in <lambda>
    return self._apply(lambda t: t.cpu())
RuntimeError: CUDA error: unspecified launch failure

Expected behavior

Environment

  • DGL Version (e.g., 1.0): dgl-cuda11.3 0.9.0 py310_0
  • Backend Library & Version (e.g., PyTorch 0.4.1, MXNet/Gluon 1.3): pytorch 1.12.1 py3.10_cuda11.3_cudnn8.3.2_0
  • OS (e.g., Linux): Linux
  • How you installed DGL (conda, pip, source): conda
  • Build command you used (if compiling from source): NA
  • Python version: 3.10
  • CUDA/cuDNN version (if applicable): 11.3
  • GPU models and configuration (e.g. V100): GeForce RTX 2080
  • Any other relevant information:
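Several of the fields above can be read programmatically. A minimal sketch using only the standard library, assuming the distribution names `dgl`, `torch` and `pytorch-lightning` (a conda build such as `dgl-cuda11.3` may register under a different name, in which case it is reported as not installed):

```python
import importlib.metadata
import platform


def package_version(name: str) -> str:
    """Return the installed version of a distribution, or 'not installed'."""
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return "not installed"


def collect_environment(packages=("dgl", "torch", "pytorch-lightning")) -> dict:
    """Gather the template fields that are visible from Python itself."""
    info = {
        "Python version": platform.python_version(),
        "OS": platform.system(),
    }
    # Distribution names are assumptions; adjust them to match your install.
    for name in packages:
        info[name] = package_version(name)
    return info


if __name__ == "__main__":
    for key, value in collect_environment().items():
        print(f"{key}: {value}")
```

CUDA, cuDNN and GPU details still have to be filled in by hand (or from `nvidia-smi`), since this sketch deliberately avoids importing `torch`.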

samvanstroud, Aug 08 '22 11:08

Confirmed the issue in the official GraphSAGE example with multi-worker CPU sampling and the DataLoader device set to CUDA:

diff --git a/examples/pytorch/graphsage/node_classification.py b/examples/pytorch/graphsage/node_classification.py
index 72054094..5e073a6d 100644
--- a/examples/pytorch/graphsage/node_classification.py
+++ b/examples/pytorch/graphsage/node_classification.py
@@ -71,11 +71,11 @@ split_idx = dataset.get_idx_split()
 train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

 device = 'cuda'
-train_idx = train_idx.to(device)
-valid_idx = valid_idx.to(device)
-test_idx = test_idx.to(device)
+#train_idx = train_idx.to(device)
+#valid_idx = valid_idx.to(device)
+#test_idx = test_idx.to(device)

-graph = graph.to('cuda' if args.pure_gpu else 'cpu')
+#graph = graph.to('cuda' if args.pure_gpu else 'cpu')

 model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
 opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
@@ -84,10 +84,10 @@ sampler = dgl.dataloading.NeighborSampler(
         [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
 train_dataloader = dgl.dataloading.DataLoader(
         graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
-        drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
+        drop_last=False, num_workers=12, use_uva=False)
 valid_dataloader = dgl.dataloading.DataLoader(
         graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
-        drop_last=False, num_workers=0, use_uva=not args.pure_gpu)
+        drop_last=False, num_workers=12, use_uva=False)

 durations = []
 for _ in range(10):

The code worked in 0.8.2post1.

BarclayII, Aug 08 '22 13:08

@BarclayII I cannot reproduce this with the GraphSAGE example and DGL 0.9.0. Multi-worker CPU sampling with a CUDA dataloader device should now be covered by the unit tests. https://github.com/dmlc/dgl/blob/5ba5106acab6a642e9b790e5331ee519112a5623/tests/pytorch/test_dataloader.py#L185-L187

@samvanstroud Are you using PyTorch 1.12.1? I don't think DGL has released support for PyTorch 1.12.1 yet. Can you try PyTorch 1.12.0?
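Because the problem went away once the PyTorch version matched one the DGL binary was built for, a quick check is to compare the installed version, minus its local `+cu113` suffix, against the versions a given build supports. The sketch below assumes you supply that set yourself; the value used here is a hypothetical placeholder, not DGL's official compatibility list.

```python
def base_version(version: str) -> str:
    """Drop a local build suffix such as '+cu113' from a PyTorch version string."""
    return version.split("+", maxsplit=1)[0]


def has_prebuilt_support(torch_version: str, supported: set) -> bool:
    """True if the base version is one the DGL binary is assumed to be built for."""
    return base_version(torch_version) in supported


# Hypothetical placeholder: fill in the versions your DGL build actually supports.
SUPPORTED = {"1.12.0"}

for v in ("1.12.0+cu113", "1.12.1+cu113"):
    verdict = "matches" if has_prebuilt_support(v, SUPPORTED) else "no prebuilt support"
    print(f"PyTorch {v}: {verdict}")
```

In practice you would pass `torch.__version__` instead of the hard-coded strings.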

yaox12, Aug 09 '22 07:08

Thanks @yaox12, downgrading to PyTorch 1.12.0 resolves the issue.

samvanstroud, Aug 09 '22 10:08

@yaox12 Do you know why PyTorch 1.12.1 would cause this? This looks like an issue with a forked CUDA context, not something related to the TensorAdapter (which causes problems when DGL hasn't been compiled against a given PyTorch version).

nv-dlasalle, Aug 09 '22 17:08


@nv-dlasalle I encountered a core dump with PyTorch 1.12.1, but I haven't investigated what happened yet.

yaox12, Aug 10 '22 00:08

@yaox12 Are you investigating this?

mufeili, Aug 15 '22 06:08

@mufeili I can reproduce this issue with PyTorch 1.12.1, but haven't found the root cause. Judging from the error message, it doesn't seem related to the TensorAdapter, so I'm not sure what change in PyTorch 1.12.1 breaks it. I'll try building DGL from source against PyTorch 1.12.1 and see if the error goes away.

Update: The error disappears when building DGL from source with PyTorch 1.12.1.

yaox12, Aug 15 '22 06:08

We can build a new minor release with PyTorch 1.12.1 support. But I think we still need to investigate why turning off the TensorAdapter crashes the code.

BarclayII, Aug 22 '22 06:08

I can reproduce the issue with PyTorch 1.9.0 if I delete the TensorAdapter shared library.
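To check whether a TensorAdapter library matching the installed PyTorch is actually present, a sketch like the following could be used. It assumes the library is named `libtensoradapter_pytorch_<version>.so` (`.dylib` on macOS, `tensoradapter_pytorch_<version>.dll` on Windows) and lives under `dgl/tensoradapter/pytorch/`; both the naming scheme and the directory layout are assumptions to verify against your installed DGL package.

```python
import os
import sys


def tensoradapter_basename(torch_version: str, sys_platform: str = sys.platform) -> str:
    """Assumed TensorAdapter library file name for a PyTorch version string."""
    # Strip a local build suffix such as '+cu113' before building the name.
    version = torch_version.split("+", maxsplit=1)[0]
    if sys_platform.startswith("darwin"):
        return f"libtensoradapter_pytorch_{version}.dylib"
    if sys_platform.startswith("win"):
        return f"tensoradapter_pytorch_{version}.dll"
    return f"libtensoradapter_pytorch_{version}.so"


def tensoradapter_present(adapter_dir: str, torch_version: str) -> bool:
    """True if the assumed adapter file for torch_version exists in adapter_dir."""
    path = os.path.join(adapter_dir, tensoradapter_basename(torch_version))
    return os.path.isfile(path)


# Usage (requires dgl and torch; the directory layout is an assumption):
# import dgl, torch
# adapter_dir = os.path.join(os.path.dirname(dgl.__file__), "tensoradapter", "pytorch")
# print(tensoradapter_present(adapter_dir, torch.__version__))
```

A `False` result would suggest DGL is running without a matching adapter, which is the configuration that reproduced the crash here.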

I wonder if this is related to #4135?

nv-dlasalle, Aug 22 '22 18:08

@nv-dlasalle Good catch! This is my fault. Should be fixed in #4450.

yaox12, Aug 23 '22 01:08