lmdeploy

[Bug] Qwen2-VL on NVIDIA L20 fails with Triton shared memory OutOfResources error

Open • Kai-dev7 opened this issue 3 months ago • 4 comments

Checklist

  • [x] 1. I have searched related issues but cannot get the expected help.
  • [x] 2. The bug has not been fixed in the latest version.
  • [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

Qwen2-VL-2B on an NVIDIA L20 fails with a Triton shared memory `OutOfResources` error.

Reproduction

POST /v1/chat/completions HTTP/1.0
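The report does not include the request body. Below is a minimal sketch of a comparable multimodal request against the OpenAI-compatible server. The host, the port (23333, assumed to be the `lmdeploy serve api_server` default), the model name and the image URL are hypothetical placeholders. The message layout follows the OpenAI chat-completions `image_url` content format and is not taken from this report. Only the standard library is used, and the script builds the request without sending it.

```python
import json
import urllib.request

# Hypothetical placeholders: the report gives no host, port, model name or
# image. 23333 is assumed to be the api_server's default port.
BASE_URL = "http://localhost:23333"
MODEL = "Qwen2-VL-2B-Instruct"
IMAGE_URL = "https://example.com/sample.jpg"


def build_chat_request(base_url: str, model: str, image_url: str,
                       prompt: str) -> urllib.request.Request:
    """Build (without sending) an OpenAI-style chat completion request
    containing one text part and one image part."""
    payload = {
        "model": model,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        }],
        "max_tokens": 64,
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_chat_request(req: urllib.request.Request) -> str:
    """Send the request to a running server and return the reply text."""
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["choices"][0]["message"]["content"]


if __name__ == "__main__":
    req = build_chat_request(BASE_URL, MODEL, IMAGE_URL, "Describe this image.")
    print(req.get_method(), req.full_url)
```

With a server running, `send_chat_request(req)` returns the reply. The image part is what matters here: the traceback fails inside the vision encoder (`self.visual(pixel_values, ...)`), so a text-only request would likely not reach the failing kernel.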

Environment

sys.platform: linux

Python: 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0]

CUDA available: True

MUSA available: False

numpy_random_seed: 2147483648

GPU 0: NVIDIA L20

CUDA_HOME: /usr/local/cuda

NVCC: Cuda compilation tools, release 12.4, V12.4.131

GCC: x86_64-linux-gnu-gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0

PyTorch: 2.6.0+cu124

PyTorch compiling details: PyTorch built with:

  - GCC 9.3

  - C++ Version: 201703

  - Intel(R) oneAPI Math Kernel Library Version 2024.2-Product Build 20240605 for Intel(R) 64 architecture applications

  - Intel(R) MKL-DNN v3.5.3 (Git Hash 66f0cb9eb66affd2da3bf5f8d897376f04aae6af)

  - OpenMP 201511 (a.k.a. OpenMP 4.5)

  - LAPACK is enabled (usually provided by MKL)

  - NNPACK is enabled

  - CPU capability usage: AVX512

  - CUDA Runtime 12.4

  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90

  - CuDNN 90.1

  - Magma 2.6.1

  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, COMMIT_SHA=2236df1770800ffea5697b11b0bb0d910b2e59e1, CUDA_VERSION=12.4, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.6.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

TorchVision: 0.21.0+cu124

LMDeploy: 0.8.0+

transformers: 4.50.0

gradio: 5.29.0

fastapi: 0.115.12

pydantic: 2.11.4

triton: 3.2.0

NVIDIA Topology:

        GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID

GPU0     X      0-47,96-143     0               N/A

Legend:

  X    = Self

  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)

  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node

  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)

  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)

  PIX  = Connection traversing at most a single PCIe bridge

  NV#  = Connection traversing a bonded set of # NVLinks

Error traceback

ret = await __forward(inputs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 249, in __forward
    return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 660, in async_forward
    output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 644, in _forward_impl
    output = model_forward(
  File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 73, in model_forward
    output = model(**input_dict)
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 161, in __call__
    return self.model(**kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 666, in forward
    image_embeds = self.visual(pixel_values, cu_seqlens=vis_cu_seqlens, rotary_pos_emb=vis_pos_emb)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 593, in forward
    hidden_states, residual = blk(hidden_states,
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 497, in forward
    hidden_states = self.attn(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/models/qwen2_vl.py", line 414, in forward
    attn_output = self.attention(
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/pytorch/nn/attention.py", line 154, in forward
    return self.impl.forward(
  File "/opt/lmdeploy/lmdeploy/pytorch/backends/cuda/flash_attention.py", line 56, in forward
    self.flash_attention_fwd(
  File "/opt/lmdeploy/lmdeploy/pytorch/kernels/cuda/flashattention.py", line 449, in flash_attention_fwd
    _flash_prefill_fwd_kernel[grid](
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 330, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 653, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 395, in __getattribute__
    self._init_handles()
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 388, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 126976, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

2025-09-01 22:40:26,366 - lmdeploy - ERROR - async_engine.py:791 - session 1 finished, reason "error"
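The two numbers in the error are byte counts. The kernel asked for 126976 B (124 KiB) of shared memory per block, while the L20 allows 101376 B (99 KiB), an overshoot of 25 KiB. The sketch below is a rough, hypothetical model of where that memory goes in a tiled attention kernel: a resident Q tile plus K and V tiles that are buffered `num_stages` times. It is not Triton's actual allocator and will not reproduce the exact figure, but it shows why shrinking `BLOCK_M`, `BLOCK_N` or `num_stages` helps. The fp16 element size, the head dimension and the example tile sizes are all assumptions.

```python
# Rough, hypothetical estimate of per-block shared memory for a tiled
# attention kernel. Assumptions: fp16 elements (2 bytes), one Q tile kept
# resident, and K/V tiles multi-buffered num_stages times. Triton's real
# allocation also depends on layouts, padding and scratch space.
HW_LIMIT_BYTES = 101376  # "Hardware limit" from the traceback (99 KiB)


def estimate_smem_bytes(block_m: int, block_n: int, head_dim: int,
                        num_stages: int, dtype_bytes: int = 2) -> int:
    q_tile = block_m * head_dim * dtype_bytes        # Q: BLOCK_M x head_dim
    kv_tiles = 2 * block_n * head_dim * dtype_bytes  # K and V: BLOCK_N x head_dim each
    return q_tile + num_stages * kv_tiles


def fits(block_m: int, block_n: int, head_dim: int, num_stages: int,
         limit: int = HW_LIMIT_BYTES) -> bool:
    return estimate_smem_bytes(block_m, block_n, head_dim, num_stages) <= limit


if __name__ == "__main__":
    # Hypothetical (BLOCK_M, BLOCK_N, head_dim, num_stages) configurations.
    for cfg in [(128, 64, 128, 3), (128, 64, 128, 2), (64, 64, 128, 3), (64, 32, 128, 3)]:
        need = estimate_smem_bytes(*cfg)
        print(cfg, f"{need / 1024:.0f} KiB", "fits" if fits(*cfg) else "too big")
```

Under these assumptions, `(128, 64, 128, 3)` needs 128 KiB and would not fit. The same tiles with `num_stages=2` need 96 KiB and would fit.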

Kai-dev7, Sep 02 '25 06:09

"triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 126976, Hardware limit: 101376. Reducing block sizes or num_stages may help." This means the kernel compiled with `triton.jit` needs more shared memory than the GPU provides. You could try uncommenting the code here and tuning the configuration at runtime.

@grimoire any suggestions?

lvhan028, Sep 04 '25 12:09

https://github.com/InternLM/lmdeploy/blob/967df47f574056740cb45b52338563373730c144/lmdeploy/pytorch/kernels/cuda/flashattention.py#L498 Try manually tuning these arguments.

grimoire, Sep 04 '25 13:09

I'm facing the same problem with Qwen2.5-VL 72B on 4× L20. I'm not familiar with these parameters; are there any suggestions or resources on how to tune them? Much appreciated.

zodiacg, Sep 11 '25 09:09

@cuikaiGitHub 0.8.0 is quite an old version; try switching to our latest release.

If the latest release still does not work, manually tuning the values above might help. `num_stages` is an integer in [1, ∞), and `BLOCK_M`/`BLOCK_N` should be powers of 2 no smaller than 16. Smaller values mean less shared memory (smem) usage.

grimoire, Sep 11 '25 10:09
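The constraints grimoire gives above can be turned into a short list of configurations to try, ordered from largest to smallest expected shared memory use. This is a hypothetical helper, not part of lmdeploy. The candidate sizes (16 to 128), the `num_stages` range (1 to 4) and the footprint proxy `BLOCK_M + 2 * num_stages * BLOCK_N` are assumptions chosen for illustration.

```python
from itertools import product
from typing import List, Tuple

Config = Tuple[int, int, int]  # hypothetical: (BLOCK_M, BLOCK_N, num_stages)


def _pow2_at_least_16(x: int) -> bool:
    return x >= 16 and (x & (x - 1)) == 0


def is_valid(block_m: int, block_n: int, num_stages: int) -> bool:
    """Constraints from the thread: num_stages >= 1, and BLOCK_M / BLOCK_N
    are powers of 2 no smaller than 16."""
    return num_stages >= 1 and _pow2_at_least_16(block_m) and _pow2_at_least_16(block_n)


def footprint_proxy(cfg: Config) -> int:
    """Assumed proxy for shared memory use: one Q tile (BLOCK_M rows) plus
    K and V tiles (BLOCK_N rows each) buffered num_stages times."""
    block_m, block_n, num_stages = cfg
    return block_m + 2 * num_stages * block_n


def candidate_configs(sizes: Tuple[int, ...] = (16, 32, 64, 128),
                      max_stages: int = 4) -> List[Config]:
    """Valid configs ordered from largest to smallest assumed footprint."""
    configs = [c for c in product(sizes, sizes, range(1, max_stages + 1)) if is_valid(*c)]
    return sorted(configs, key=lambda c: (footprint_proxy(c), c), reverse=True)


if __name__ == "__main__":
    for cfg in candidate_configs()[:5]:
        print(cfg, footprint_proxy(cfg))
```

One would then set the values at the line linked earlier in the thread, restart the server, and move down the list until the kernel launches. Entries further down use less shared memory but may run slower.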