lmdeploy

[Bug] RuntimeError: CUDA error: operation not permitted when stream is capturing

Open • LinJianping opened this issue 1 year ago • 15 comments

Checklist

  • [ ] 1. I have searched related issues but cannot get the expected help.
  • [ ] 2. The bug has not been fixed in the latest version.
  • [ ] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

Using lmdeploy v0.6.0 to load InternVL2-1B and run inference in a loop raises "RuntimeError: CUDA error: operation not permitted when stream is capturing". I suspect this is related to the CUDA graph support added in v0.6.0.

Reproduction

```python
import os
import time

import torch
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

device = "cuda"
pwd = os.path.abspath(os.path.dirname(__file__))
model_path = os.path.join(pwd, 'InternVL2-1B')
# Note: the traceback below shows the PyTorch engine was used; per this thread,
# TurboMind does not support InternVL2-1B.
pipe = pipeline(model_path,
                backend_config=TurbomindEngineConfig(cache_max_entry_count=0.6))

BATCH_SIZE = 8

# '图片中有海吗' = "Is there a sea in the picture?"
querys = ['图片中有海吗', ] * BATCH_SIZE

image_paths = [os.path.join(pwd, "warmup/flag.jpg")] * BATCH_SIZE

# Warm-up with a single request
image = load_image(image_paths[1])
response = pipe((querys[1], image))

# Batched request
prompts = [(query, load_image(img_url)) for img_url, query in zip(image_paths, querys)]
response = pipe(prompts)
print(response)

# Timed loop; the error shows up intermittently here
_REPEAT = 100
tic = time.time()
torch.cuda.synchronize()
for _ in range(_REPEAT):
    response = pipe(prompts)
torch.cuda.synchronize()
toc = time.time()
print(response)
print(f'seconds per image:{(toc-tic)/BATCH_SIZE/_REPEAT}')
```

Environment

sys.platform: linux
Python: 3.9.16 (main, Apr  2 2024, 20:40:25) [GCC 10.2.1 20210130 (Red Hat 10.2.1-11)]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0: NVIDIA L40
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.2, V12.2.91
GCC: gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
PyTorch: 2.2.1
PyTorch compiling details: PyTorch built with:
  - GCC 10.2
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2024.1-Product Build 20240215 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.2
  - NVCC architecture flags: -gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_89,code=sm_89;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.6
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.2, CUDNN_VERSION=8.9.6, CXX_COMPILER=/opt/rh/devtoolset-10/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

TorchVision: 0.17.1
LMDeploy: 0.6.0+
transformers: 4.41.2
gradio: Not Found
fastapi: 0.111.0
pydantic: 2.7.4
triton: 2.2.0
NVIDIA Topology:
	GPU0	NIC0	NIC1	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	SYS	SYS				N/A
NIC0	SYS	 X 	SYS
NIC1	SYS	SYS	 X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_bond_0
  NIC1: mlx5_bond_1

Error traceback

Exception in callback _raise_exception_on_finish(<Future finis...sertions.\n')>) at /home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/vl/engine.py:20
handle: <Handle _raise_exception_on_finish(<Future finis...sertions.\n')>) at /home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/vl/engine.py:20>
Traceback (most recent call last):
  File "/usr/local/lib/python3.9/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/vl/engine.py", line 27, in _raise_exception_on_finish
    raise e
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/vl/engine.py", line 23, in _raise_exception_on_finish
    task.result()
  File "/usr/local/lib/python3.9/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/vl/engine.py", line 169, in forward
    outputs = self.model.forward(*func_inputs)
  File "/usr/local/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/vl/model/internvl.py", line 186, in forward
    return self._forward_func(images, params)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/vl/model/internvl.py", line 165, in _forward_v1_5
    outputs = self.model.extract_feature(outputs)
  File "/home/hadoop-platcv/.cache/huggingface/modules/transformers_modules/InternVL2-1B/modeling_internvl_chat.py", line 183, in extract_feature
    vit_embeds = self.vision_model(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hadoop-platcv/.cache/huggingface/modules/transformers_modules/InternVL2-1B/modeling_intern_vit.py", line 413, in forward
    encoder_outputs = self.encoder(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hadoop-platcv/.cache/huggingface/modules/transformers_modules/InternVL2-1B/modeling_intern_vit.py", line 348, in forward
    layer_outputs = encoder_layer(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hadoop-platcv/.cache/huggingface/modules/transformers_modules/InternVL2-1B/modeling_intern_vit.py", line 292, in forward
    hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hadoop-platcv/.cache/huggingface/modules/transformers_modules/InternVL2-1B/modeling_intern_vit.py", line 259, in forward
    hidden_states = self.fc1(hidden_states)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

2024-10-02 10:40:24,380 - lmdeploy - ERROR - Engine loop failed with error: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Traceback (most recent call last):
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 181, in capture
    output = self.model(**padded_kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/models/internvl.py", line 39, in forward
    return self.language_model.forward(input_ids=input_ids,
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/models/qwen2.py", line 340, in forward
    hidden_states = self.model(
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/models/qwen2.py", line 270, in forward
    cos, sin = self.rotary_emb(hidden_states, position_ids)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/backends/default/rotary_embedding.py", line 67, in forward
    return _rotary_embedding_fwd(position_ids,
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/backends/default/rotary_embedding.py", line 33, in _rotary_embedding_fwd
    freqs = (inv_freq_expanded.float()
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 941, in async_loop
    await self._async_loop()
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 935, in _async_loop
    await __step(False)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 917, in __step
    raise e
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 909, in __step
    raise out
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 853, in _async_loop_background
    await self._async_step_background(
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 732, in _async_step_background
    output = await self._async_model_forward(
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/utils.py", line 237, in __tmp
    return (await func(*args, **kwargs))
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 630, in _async_model_forward
    ret = await __forward(inputs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/engine.py", line 608, in __forward
    return await self.model_agent.async_forward(
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 332, in async_forward
    output = self._forward_impl(inputs,
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 299, in _forward_impl
    output = model_forward(
  File "/usr/local/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 154, in model_forward
    output = model(**input_dict)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 257, in __call__
    runner.capture(**kwargs)
  File "/home/hadoop-platcv/.local/lib/python3.9/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 181, in capture
    output = self.model(**padded_kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/cuda/graphs.py", line 183, in __exit__
    self.cuda_graph.capture_end()
  File "/usr/local/lib/python3.9/site-packages/torch/cuda/graphs.py", line 81, in capture_end
    super().capture_end()
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

LinJianping • Oct 02 '24 04:10

It cannot be reproduced with the latest main branch.

grimoire • Oct 08 '24 02:10

I started getting this error with the PyTorch engine in the latest release for the Qwen2-VL model. I get the error with batch size >= 6; with batch size 1, everything runs fine.

visheratin-pai • Oct 08 '24 04:10

I still cannot reproduce the error.

Since the log reports a cuBLAS SGEMM error, try replacing

https://github.com/InternLM/lmdeploy/blob/2e49fc33916dc4a9feb63d4cd57b6be862000f93/lmdeploy/pytorch/backends/default/rotary_embedding.py#L33-L34

with

        freqs = (inv_freq_expanded.float()
                 * position_ids_expanded.float()).transpose(1, 2)

grimoire • Oct 08 '24 06:10
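A minimal sketch of why the suggested change sidesteps the failing cuBLAS call, assuming the original lines 33-34 compute `freqs` with a batched matmul (`@`), as Hugging Face's Qwen2 rotary embedding does. The function names are hypothetical. Because the inner dimension of that matmul is 1, it is just an outer product per batch item, so an element-wise broadcast multiply produces the same `[batch, seq_len, dim/2]` tensor without going through `cublasSgemm`.

```python
import torch


def rope_freqs_matmul(inv_freq: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    """Assumed original form: a batched matmul, which on CUDA is routed through cuBLAS."""
    # [dim/2] -> [batch, dim/2, 1]
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    # [batch, seq_len] -> [batch, 1, seq_len]
    position_ids_expanded = position_ids[:, None, :].float()
    # [batch, dim/2, 1] @ [batch, 1, seq_len] -> [batch, dim/2, seq_len] -> [batch, seq_len, dim/2]
    return (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)


def rope_freqs_broadcast(inv_freq: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    """Suggested form: an element-wise broadcast multiply, no GEMM involved."""
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # Broadcasting [batch, dim/2, 1] * [batch, 1, seq_len] -> [batch, dim/2, seq_len]
    return (inv_freq_expanded * position_ids_expanded).transpose(1, 2)


if __name__ == "__main__":
    head_dim = 64  # e.g. Qwen2-0.5B, the language model of InternVL2-1B
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)  # batch of 2, 16 positions
    a = rope_freqs_matmul(inv_freq, position_ids)
    b = rope_freqs_broadcast(inv_freq, position_ids)
    print(a.shape, torch.allclose(a, b))
```

Both forms compute `inv_freq[i] * position[j]` for every pair; the broadcast version simply launches a pointwise kernel instead of a GEMM.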

> It cannot be reproduced with the latest main branch.

When batch size is 1, everything runs fine. When batch_size is set to 8, the error above occurs intermittently during the loop. I think it may be related to the CUDA version; mine is 12.2. With the default lmdeploy-0.6.1-cp39-cp39-manylinux2014_x86_64.whl the error occurs very frequently, but with lmdeploy-0.6.1+cu118-cp39-cp39-manylinux2014_x86_64.whl it occurs less often.

LinJianping • Oct 08 '24 06:10

I asked an expert: the error might come from the vision model running on the default stream, which would corrupt the capture of the language model on the other stream. I will try to fix it ASAP.

grimoire • Oct 08 '24 06:10
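The stream diagnosis above can be illustrated with PyTorch's `capture_error_mode` argument to `torch.cuda.graph` (available since PyTorch 2.1; the reporter uses 2.2.1). Per PyTorch's documentation, `"global"` (the default) errors on capture-unsafe actions issued from any thread, while `"thread_local"` only errors on actions from the capturing thread. The `concurrent/futures/thread.py` frame in the first traceback shows the vision model running in a worker thread, which is the case this setting addresses, and grimoire's fix branch later in the thread switches to `thread_local`. This is a minimal sketch, not lmdeploy's graph runner: the function name and the toy `x @ w + 1` computation are hypothetical, and it falls back to eager execution when no GPU is present.

```python
import torch


def capture_and_replay(x: torch.Tensor, mode: str = "thread_local") -> torch.Tensor:
    """Capture y = x @ w + 1 into a CUDA graph, then replay it (hypothetical toy model)."""
    w = torch.ones(x.shape[-1], 4, device=x.device)
    if x.device.type != "cuda":
        # No GPU available: run eagerly so the sketch still executes.
        return x @ w + 1

    static_x = torch.empty_like(x)
    static_x.copy_(x)

    # Warm up on a side stream before capture, as the PyTorch CUDA graph docs recommend.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        static_x @ w + 1
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    # "global" (default): capture-unsafe calls from ANY thread abort the capture.
    # "thread_local": only calls made from this capturing thread are checked.
    with torch.cuda.graph(graph, capture_error_mode=mode):
        static_y = static_x @ w + 1

    # Replay: refill the static input, then re-run the recorded kernels.
    static_x.copy_(x)
    graph.replay()
    return static_y.clone()


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.arange(6.0, device=device).reshape(2, 3)
    print(capture_and_replay(x).cpu())
```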

> I asked an expert: the error might come from the vision model running on the default stream, which would corrupt the capture of the language model on the other stream. I will try to fix it ASAP.

Another question: when I deploy with the Triton Python backend and enable dynamic batching, is it also prone to exceptions because CUDA graphs are captured for different batch sizes?

LinJianping • Oct 08 '24 07:10

We capture multiple graphs for different input sizes, and the input is padded to the capture size before the forward pass. It is safe to use dynamic batching.

grimoire • Oct 08 '24 07:10
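A minimal sketch of the padding scheme described above, assuming the captured sizes are the powers of two from 1 to 256 that grimoire lists further down the thread. The helper names and the eager fallback for oversized batches are hypothetical; lmdeploy's actual graph runner may choose sizes differently.

```python
import bisect
from typing import List, Optional, Sequence

# Assumed capture sizes: [1, 2, 4, ..., 256].
CAPTURE_SIZES: List[int] = [2 ** i for i in range(9)]


def pick_graph_size(num_tokens: int, sizes: Sequence[int] = CAPTURE_SIZES) -> Optional[int]:
    """Smallest captured size that fits num_tokens, or None (hypothetical: run eagerly)."""
    idx = bisect.bisect_left(sizes, num_tokens)
    return sizes[idx] if idx < len(sizes) else None


def pad_batch(token_ids: List[int], size: int, pad_id: int = 0) -> List[int]:
    """Pad a decode batch (one token per sequence) up to the selected graph size."""
    if len(token_ids) > size:
        raise ValueError("batch larger than the selected graph size")
    return token_ids + [pad_id] * (size - len(token_ids))


if __name__ == "__main__":
    for n in (1, 6, 8, 200, 300):
        print(n, "->", pick_graph_size(n))
```

Under this scheme a batch of 6 would replay the 8-token graph with two padded slots, so any batch size up to the largest capture is covered without capturing one graph per batch size.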

> We capture multiple graphs for different input sizes, and the input is padded to the capture size before the forward pass. It is safe to use dynamic batching.

What is the specific capture strategy like? For example, the default capture batch sizes might be 1, 2, 4, 8, etc. That way, I can set the matching preferred batch size to get the best inference performance.

LinJianping • Oct 08 '24 07:10

Another thing I'm curious about: why does TurboMind support the 2B-76B InternVL2 models but not the 1B model? Are there any plans to support it in the future? @grimoire

LinJianping • Oct 08 '24 08:10

https://github.com/grimoire/lmdeploy/tree/fix-vl-graphcapture I have set the capture mode to thread_local, which might fix the bug.

> What is the specific capture strategy like?

https://github.com/grimoire/lmdeploy/blob/e16c49170f1413f23c03cac2d3549ca7b7f711c4/lmdeploy/pytorch/backends/cuda/graph_runner.py#L133

The engine generates graphs for token numbers [1, 2, 4, ..., 256]. You don't need to care much about this, since the PyTorch engine schedules requests to the best batch size.

> why TurboMind supports the 2B-76B InternVL2 model but not the 1B model

InternVL2-1B uses Qwen2-0.5B as its language model, which has head_size=64. TurboMind does not support head_size < 128.

grimoire • Oct 08 '24 09:10
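The head size mentioned above is the hidden size divided by the number of attention heads in a standard multi-head attention layer. A quick check, assuming the values from the models' `config.json` files: Qwen2-0.5B (`hidden_size=896`, `num_attention_heads=14`) and InternLM2-1.8B (`hidden_size=2048`, `num_attention_heads=16`), the latter being the language model behind InternVL2-2B. The helper name is illustrative only.

```python
def head_size(hidden_size: int, num_attention_heads: int) -> int:
    """Per-head dimension of a standard multi-head attention layer."""
    if hidden_size % num_attention_heads:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    return hidden_size // num_attention_heads


if __name__ == "__main__":
    print("Qwen2-0.5B (InternVL2-1B):", head_size(896, 14))
    print("InternLM2-1.8B (InternVL2-2B):", head_size(2048, 16))
```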

> https://github.com/grimoire/lmdeploy/tree/fix-vl-graphcapture I have set the capture mode to thread_local, which might fix the bug.

It seems to be branched not from the latest version but from 0.4.2. Installing it downgraded PyTorch, and I get this error for Qwen2-VL: "Unrecognized configuration class <class 'transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig'> for this kind of AutoModel".

visheratin-pai • Oct 08 '24 11:10

Are you using the main branch of my repo? I have created a draft PR, https://github.com/InternLM/lmdeploy/pull/2560. Please try it.

grimoire • Oct 08 '24 11:10

Sorry, forgot to switch branches! Yes, the issue doesn't occur when using the correct branch.

visheratin-pai • Oct 08 '24 12:10

> https://github.com/grimoire/lmdeploy/tree/fix-vl-graphcapture I have set the capture mode to thread_local, which might fix the bug.

This method works well for me. Also, I'm curious to know if there's any plan to bring TurboMind support to smaller models like InternVL2-1B, or whether the work is too heavy to make it happen in the near future.

LinJianping • Oct 09 '24 02:10

> I'm curious to know if there's any plan to bring TurboMind support to smaller models like InternVL2-1B

@lvhan028 @lzhangzz

grimoire • Oct 09 '24 02:10

> I asked an expert: the error might come from the vision model running on the default stream, which would corrupt the capture of the language model on the other stream. I will try to fix it ASAP.

Hi! I ran into the same problem. With Qwen2-VL-7B, it works fine when batch_size is 1, but the same error is raised when it is set to 4. I am using the latest version of lmdeploy. Has this problem been fixed?

sudanl • Oct 19 '24 09:10