lmdeploy
[Bug] Qwen/Qwen2-1.5B error: floating point exception
Checklist
- [X] 1. I have searched related issues but cannot get the expected help.
- [X] 2. The bug has not been fixed in the latest version.
- [X] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
Running on a 2080 Ti fails with: floating point exception (core dumped)
Reproduction
lmdeploy serve api_server Qwen/Qwen2-1.5B --server-port=8000 --tp=4 --model-name=default-model --max-batch-size=32 --session-len=32768 --log-level INFO
Environment
sys.platform: linux
Python: 3.11.9 (main, Jun 23 2024, 16:24:59) [GCC 14.1.1 20240522]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: NVIDIA GeForce RTX 2080 Ti
CUDA_HOME: /opt/cuda
NVCC: Cuda compilation tools, release 12.1, V12.1.66
GCC: gcc (GCC) 14.1.1 20240522
PyTorch: 2.2.2+cu121
PyTorch compiling details: PyTorch built with:
- GCC 9.3
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 12.1
- NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
- CuDNN 8.9.2
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
TorchVision: 0.17.2+cu121
LMDeploy: 0.5.2.post1+1fcdc42
transformers: 4.43.4
gradio: Not Found
fastapi: 0.111.1
pydantic: 2.8.2
triton: 2.2.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X PHB SYS SYS 0-11,24-35 0 N/A
GPU1 PHB X SYS SYS 0-11,24-35 0 N/A
GPU2 SYS SYS X PHB 12-23,36-47 1 N/A
GPU3 SYS SYS PHB X 12-23,36-47 1 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
Error traceback
Fetching 10 files: 100%|████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 51087.75it/s]
2024-08-08 20:02:51,328 - lmdeploy - INFO - input backend=turbomind, backend_config=TurbomindEngineConfig(model_name='default-model', model_format=None, tp=4, session_len=32768, max_batch_size=32, cache_max_entry_count=0.8, cache_block_seq_len=64, enable_prefix_caching=False, quant_policy=0, rope_scaling_factor=0.0, use_logn_attn=False, download_dir=None, revision=None, max_prefill_token_num=8192, num_tokens_per_iter=0, max_prefill_iters=1)
2024-08-08 20:02:51,328 - lmdeploy - INFO - input chat_template_config=None
2024-08-08 20:02:51,622 - lmdeploy - INFO - updated chat_template_onfig=ChatTemplateConfig(model_name='qwen', system=None, meta_instruction=None, eosys=None, user=None, eoh=None, assistant=None, eoa=None, separator=None, capability=None, stop_words=None)
2024-08-08 20:02:51,623 - lmdeploy - INFO - model_source: ModelSource.HF_MODEL
2024-08-08 20:02:51,623 - lmdeploy - WARNING - model_name is deprecated in TurbomindEngineConfig and has no effect
Device does not support bfloat16. Set float16 forcefully
2024-08-08 20:02:51,911 - lmdeploy - INFO - model_config:
[llama]
model_name = qwen
model_arch = Qwen2ForCausalLM
tensor_para_size = 4
head_num = 12
kv_head_num = 2
vocab_size = 151936
num_layer = 28
inter_size = 8960
norm_eps = 1e-06
attn_bias = 1
start_id = 151643
end_id = 151645
session_len = 32768
weight_type = fp16
rotary_embedding = 128
rope_theta = 1000000.0
size_per_head = 128
group_size = 0
max_batch_size = 32
max_context_token_num = 1
step_length = 1
cache_max_entry_count = 0.8
cache_block_seq_len = 64
cache_chunk_size = -1
enable_prefix_caching = False
num_tokens_per_iter = 8192
max_prefill_iters = 4
extra_tokens_per_iter = 0
use_context_fmha = 1
quant_policy = 0
max_position_embeddings = 131072
original_max_position_embeddings = 0
rope_scaling_type =
rope_scaling_factor = 0.0
use_dynamic_ntk = 0
low_freq_factor = 1.0
high_freq_factor = 1.0
use_logn_attn = 0
lora_policy =
lora_r = 0
lora_scale = 0.0
lora_max_wo_r = 0
lora_rank_pattern =
lora_scale_pattern =
[TM][WARNING] [LlamaTritonModel] `max_context_token_num` = 32768.
2024-08-08 20:02:53,824 - lmdeploy - WARNING - get 843 model params
2024-08-08 20:02:55,738 - lmdeploy - INFO - updated backend_config=TurbomindEngineConfig(model_name='default-model', model_format=None, tp=4, session_len=32768, max_batch_size=32, cache_max_entry_count=0.8, cache_block_seq_len=64, enable_prefix_caching=False, quant_policy=0, rope_scaling_factor=0.0, use_logn_attn=False, download_dir=None, revision=None, max_prefill_token_num=8192, num_tokens_per_iter=0, max_prefill_iters=1)
[WARNING] gemm_config.in is not found; using default GEMM algo
[TM][INFO] NCCL group_id = 0
[WARNING] gemm_config.in is not found; using default GEMM algo
[TM][INFO] NCCL group_id = 0
[WARNING] gemm_config.in is not found; using default GEMM algo
[TM][INFO] NCCL group_id = 0
[WARNING] gemm_config.in is not found; using default GEMM algo
[TM][INFO] NCCL group_id = 0
zsh: floating point exception (core dumped) lmdeploy serve api_server Qwen/Qwen2-1.5B --server-port=8000 --tp=4 INFO
Earlier reports said that updating cuBLAS fixed this. You can try pip3 install nvidia-cublas-cu12==12.3.4.1
@irexyc I upgraded torch from 2.2.2 to 2.3.1 and also tried pip3 install nvidia-cublas-cu12==12.3.4.1, but I still get the same error. Here is the backtrace from gdb:
(gdb) backtrace
#0 0x00007fc857b6aa58 in turbomind::BlockManager::GetBlockCount(unsigned long, double, std::function<unsigned long ()>) (
block_size=block_size@entry=0, ratio=ratio@entry=0.80000001192092896, get_free_size=...)
at /lmdeploy/src/turbomind/models/llama/BlockManager.cc:104
#1 0x00007fc857b6bf7b in turbomind::BlockManager::BlockManager(unsigned long, double, int, turbomind::IAllocator*, std::function<unsigned long ()>) (this=0x7fc7d3fa1740, block_size=0, block_count=0.80000001192092896, chunk_size=-1, allocator=<optimized out>, get_free_size=...)
at /lmdeploy/src/turbomind/models/llama/BlockManager.cc:35
#2 0x00007fc857b6f9c7 in __gnu_cxx::new_allocator<turbomind::BlockManager>::construct<turbomind::BlockManager, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(turbomind::BlockManager*, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__p=0x7fc7d3fa1740, this=<optimized out>) at /opt/rh/devtoolset-9/root/usr/include/c++/9/new:174
#3 std::allocator_traits<std::allocator<turbomind::BlockManager> >::construct<turbomind::BlockManager, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::allocator<turbomind::BlockManager>&, turbomind::BlockManager*, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__p=0x7fc7d3fa1740, __a=...)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/alloc_traits.h:484
#4 std::_Sp_counted_ptr_inplace<turbomind::BlockManager, std::allocator<turbomind::BlockManager>, (__gnu_cxx::_Lock_policy)2>::_Sp_counted_ptr_inplace<unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__a=..., this=0x7fc7d3fa1730)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h:548
#5 std::__shared_count<(__gnu_cxx::_Lock_policy)2>::__shared_count<turbomind::BlockManager, std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(turbomind::BlockManager*&, std::_Sp_alloc_shared_tag<std::allocator<turbomind::BlockManager> >, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__a=...,
__p=<optimized out>, this=<optimized out>) at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h:679
#6 std::__shared_ptr<turbomind::BlockManager, (__gnu_cxx::_Lock_policy)2>::__shared_ptr<std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::_Sp_alloc_shared_tag<std::allocator<turbomind::BlockManager> >, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__tag=..., this=<optimized out>)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h:1344
#7 std::shared_ptr<turbomind::BlockManager>::shared_ptr<std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::_Sp_alloc_shared_tag<std::allocator<turbomind::BlockManager> >, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__tag=..., this=<optimized out>)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h:359
#8 std::allocate_shared<turbomind::BlockManager, std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::allocator<turbomind::BlockManager> const&, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__a=...) at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h:702
#9 std::make_shared<turbomind::BlockManager, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) ()
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h:718
#10 turbomind::SequenceManager::SequenceManager(unsigned long, turbomind::SequenceManager::BlockConfig const&, double, int, bool, int, turbomind::IAllocator*, std::function<unsigned long ()>) (this=0x7fc7d3904900, layer_num=<optimized out>, block_config=..., block_count=0.80000001192092896,
chunk_size=-1, enable_prefix_caching=<optimized out>, rank=<optimized out>, allocator=0x7fc7d38d9540, get_free_size=...)
at /lmdeploy/src/turbomind/models/llama/SequenceManager.cc:32
#11 0x00007fc857b61019 in turbomind::LlamaBatch<__half>::LlamaBatch (this=0x7fc7d3a0cc30, params=..., cache_block_seq_len=64,
quant_policy=<optimized out>, model=0x7fc7d3a13a30) at /lmdeploy/src/turbomind/models/llama/LlamaBatch.cc:965
#12 0x00007fc857b2ec3c in std::make_unique<turbomind::LlamaBatch<__half>, turbomind::EngineParams const&, int&, int&, turbomind::LlamaV2<__half>*>
() at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/unique_ptr.h:856
#13 turbomind::LlamaV2<__half>::LlamaV2 (this=0x7fc7d3a13a30, head_num=32, kv_head_num=2, size_per_head=<optimized out>,
inter_size=<optimized out>, num_layer=40, vocab_size=151552, norm_eps=<optimized out>, attn_params=..., start_id=0, end_id=151329,
cache_block_seq_len=64, quant_policy=0, use_context_fmha=true, engine_params=..., lora_params=...,
shared_state=std::shared_ptr<turbomind::LlamaV2<__half>::SharedState> (use count 9, weak count 0) = {...}, weights=0x7fc7c029c520,
tensor_para=..., stream=0x7fc7d0000bd0, cublas_wrapper=0x7fc7d3d6df10, allocator=0x7fc7d38d9540, peer_alloctor=0x7fc7d021b5a0,
is_free_buffer_after_forward=false, cuda_device_prop=0x7fc7d3a13620) at /lmdeploy/src/turbomind/models/llama/LlamaV2.cc:103
#14 0x00007fc857af458a in std::make_unique<turbomind::LlamaV2<__half>, unsigned long&, unsigned long&, unsigned long&, unsigned long&, unsigned long&, unsigned long&, float&, turbomind::LlamaAttentionParams&, int&, int&, int&, int&, int&, turbomind::EngineParams&, turbomind::LoraParams&, std::shared_ptr<turbomind::LlamaV2<__half>::SharedState>&, turbomind::LlamaWeight<__half>*, turbomind::NcclParam&, CUstream_st*&, turbomind::cublasMMWrapper*, turbomind::Allocator<(turbomind::AllocatorType)0>*, turbomind::Allocator<(turbomind::AllocatorType)0>*, bool, cudaDeviceProp*> ()
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/unique_ptr.h:857
#15 LlamaTritonModel<__half>::createSharedModelInstance (this=this@entry=0x561075ae9710, device_id=device_id@entry=1, rank=rank@entry=1,
nccl_params={...}, custom_all_reduce_comm=std::shared_ptr<turbomind::AbstractCustomComm> (empty) = {...})
at /lmdeploy/src/turbomind/triton_backend/llama/LlamaTritonModel.cc:336
#16 0x00007fc857af8b32 in LlamaTritonModel<__half>::createModelInstance (this=this@entry=0x561075ae9710, device_id=device_id@entry=1,
rank=rank@entry=1, stream=stream@entry=0x0, nccl_params={...},
custom_all_reduce_comm=std::shared_ptr<turbomind::AbstractCustomComm> (empty) = {...})
at /lmdeploy/src/turbomind/triton_backend/llama/LlamaTritonModel.cc:389
#17 0x00007fc857ab8b3f in <lambda(AbstractTransformerModel*, int, int, long int, std::pair<std::vector<turbomind::NcclParam, std::allocator<turbomind::NcclParam> >, std::vector<turbomind::NcclParam, std::allocator<turbomind::NcclParam> > >, std::shared_ptr<turbomind::AbstractCustomComm>)>::operator() (__closure=<optimized out>, custom_all_reduce_comm=..., nccl_params={...}, stream_id=0, rank=1, deviceId=1, model=0x561075ae9710)
at /lmdeploy/src/turbomind/python/bind.cpp:431
#18 pybind11::detail::argument_loader<AbstractTransformerModel*, int, int, long, std::pair<std::vector<turbomind::NcclParam, std::allocator<turbomind::NcclParam> >, std::vector<turbomind::NcclParam, std::allocator<turbomind::NcclParam> > >, std::shared_ptr<turbomind::AbstractCustomComm> >::call_impl<std::unique_ptr<AbstractTransformerModelInstance>, pybind11_init__turbomind(pybind11::module_&)::<lambda(AbstractTransformerModel*, int, int, long int, std::pair<std::vector<turbomind::NcclParam>, std::vector<turbomind::NcclParam> >, std::shared_ptr<turbomind::AbstractCustomComm>)>&, 0, 1, 2, 3, 4, 5, pybind11::gil_scoped_release> (f=..., this=0x7fc7e67dc580)
at /opt/conda/envs/py311/lib/python3.11/site-packages/pybind11/include/pybind11/cast.h:1613
#19 pybind11::detail::argument_loader<AbstractTransformerModel*, int, int, long, std::pair<std::vector<turbomind::NcclParam, std::allocator<turbomind::NcclParam> >, std::vector<turbomind::NcclParam, std::allocator<turbomind::NcclParam> > >, std::shared_ptr<turbomind::AbstractCustomComm> >::call<std::unique_ptr<AbstractTransformerModelInstance>, pybind11::gil_scoped_release, pybind11_init__turbomind(pybind11::module_&)::<lambda(AbstractTransformerModel*, int, int, long int, std::pair<std::vector<turbomind::NcclParam>, std::vector<turbomind::NcclParam> >, std::shared_ptr<turbomind::AbstractCustomComm>)>&> (f=..., this=0x7fc7e67dc580) at /opt/conda/envs/py311/lib/python3.11/site-packages/pybind11/include/pybind11/cast.h:1582
#20 <lambda(pybind11::detail::function_call&)>::operator() (this=0x0, call=...)
at /opt/conda/envs/py311/lib/python3.11/site-packages/pybind11/include/pybind11/pybind11.h:296
#21 <lambda(pybind11::detail::function_call&)>::_FUN(pybind11::detail::function_call &) ()
at /opt/conda/envs/py311/lib/python3.11/site-packages/pybind11/include/pybind11/pybind11.h:267
#22 0x00007fc857ad66b2 in pybind11::cpp_function::dispatcher (self=<optimized out>, args_in=0x7fc861b521b0, kwargs_in=0x0)
at /opt/conda/envs/py311/lib/python3.11/site-packages/pybind11/include/pybind11/pybind11.h:987
#23 0x00007fc98f5d85c3 in ?? () from /usr/lib/libpython3.11.so.1.0
#24 0x00007fc98f5450f8 in _PyObject_MakeTpCall () from /usr/lib/libpython3.11.so.1.0
#25 0x00007fc98f6df81b in _PyEval_EvalFrameDefault () from /usr/lib/libpython3.11.so.1.0
#26 0x00007fc98f6e72f0 in ?? () from /usr/lib/libpython3.11.so.1.0
#27 0x00007fc98f663b76 in ?? () from /usr/lib/libpython3.11.so.1.0
#28 0x00007fc98f6e1fc9 in _PyEval_EvalFrameDefault () from /usr/lib/libpython3.11.so.1.0
#29 0x00007fc98f6e72f0 in ?? () from /usr/lib/libpython3.11.so.1.0
#30 0x00007fc98f6e1fc9 in _PyEval_EvalFrameDefault () from /usr/lib/libpython3.11.so.1.0
#31 0x00007fc98f6e72f0 in ?? () from /usr/lib/libpython3.11.so.1.0
#32 0x00007fc98f663cab in ?? () from /usr/lib/libpython3.11.so.1.0
#33 0x00007fc98f68f459 in ?? () from /usr/lib/libpython3.11.so.1.0
Is this the same problem as before? It looks like the issue is at /lmdeploy/src/turbomind/models/llama/BlockManager.cc:104.
(gdb) up 1
#1 0x00007fc857b6bf7b in turbomind::BlockManager::BlockManager(unsigned long, double, int, turbomind::IAllocator*, std::function<unsigned long ()>) (this=0x7fc7d3fa1740, block_size=0, block_count=0.80000001192092896, chunk_size=-1, allocator=<optimized out>, get_free_size=...)
at /lmdeploy/src/turbomind/models/llama/BlockManager.cc:35
35 in /lmdeploy/src/turbomind/models/llama/BlockManager.cc
(gdb) up 1
#2 0x00007fc857b6f9c7 in __gnu_cxx::new_allocator<turbomind::BlockManager>::construct<turbomind::BlockManager, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(turbomind::BlockManager*, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__p=0x7fc7d3fa1740, this=<optimized out>) at /opt/rh/devtoolset-9/root/usr/include/c++/9/new:174
warning: 174 /opt/rh/devtoolset-9/root/usr/include/c++/9/new: No such file or directory
(gdb) up 1
#3 std::allocator_traits<std::allocator<turbomind::BlockManager> >::construct<turbomind::BlockManager, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::allocator<turbomind::BlockManager>&, turbomind::BlockManager*, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__p=0x7fc7d3fa1740, __a=...)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/alloc_traits.h:484
warning: 484 /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/alloc_traits.h: No such file or directory
(gdb) up 1
#4 std::_Sp_counted_ptr_inplace<turbomind::BlockManager, std::allocator<turbomind::BlockManager>, (__gnu_cxx::_Lock_policy)2>::_Sp_counted_ptr_inplace<unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__a=..., this=0x7fc7d3fa1730)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h:548
warning: 548 /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h: No such file or directory
(gdb) up 1
#5 std::__shared_count<(__gnu_cxx::_Lock_policy)2>::__shared_count<turbomind::BlockManager, std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(turbomind::BlockManager*&, std::_Sp_alloc_shared_tag<std::allocator<turbomind::BlockManager> >, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__a=...,
__p=<optimized out>, this=<optimized out>) at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h:679
679 in /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h
(gdb) up 1
#6 std::__shared_ptr<turbomind::BlockManager, (__gnu_cxx::_Lock_policy)2>::__shared_ptr<std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::_Sp_alloc_shared_tag<std::allocator<turbomind::BlockManager> >, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__tag=..., this=<optimized out>)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h:1344
1344 in /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr_base.h
(gdb) up 1
#7 std::shared_ptr<turbomind::BlockManager>::shared_ptr<std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::_Sp_alloc_shared_tag<std::allocator<turbomind::BlockManager> >, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__tag=..., this=<optimized out>)
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h:359
warning: 359 /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h: No such file or directory
(gdb) up 1
#8 std::allocate_shared<turbomind::BlockManager, std::allocator<turbomind::BlockManager>, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(std::allocator<turbomind::BlockManager> const&, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) (__a=...) at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h:702
702 in /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h
(gdb) up 1
#9 std::make_shared<turbomind::BlockManager, unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&>(unsigned long&, double&, int&, turbomind::IAllocator*&, std::function<unsigned long ()>&) ()
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h:718
718 in /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/shared_ptr.h
(gdb) up 1
#10 turbomind::SequenceManager::SequenceManager(unsigned long, turbomind::SequenceManager::BlockConfig const&, double, int, bool, int, turbomind::IAllocator*, std::function<unsigned long ()>) (this=0x7fc7d3904900, layer_num=<optimized out>, block_config=..., block_count=0.80000001192092896,
chunk_size=-1, enable_prefix_caching=<optimized out>, rank=<optimized out>, allocator=0x7fc7d38d9540, get_free_size=...)
at /lmdeploy/src/turbomind/models/llama/SequenceManager.cc:32
warning: 32 /lmdeploy/src/turbomind/models/llama/SequenceManager.cc: No such file or directory
(gdb) Quit
(gdb) Quit
(gdb) p block_config
$1 = (const turbomind::SequenceManager::BlockConfig &) @0x7fc7e67dbfd0: {head_dim_ = 128, head_num_ = 0, block_len_ = 64, t_bits_ = 0,
q_bits_ = 16}
(gdb) up 1
#11 0x00007fc857b61019 in turbomind::LlamaBatch<__half>::LlamaBatch (this=0x7fc7d3a0cc30, params=..., cache_block_seq_len=64,
quant_policy=<optimized out>, model=0x7fc7d3a13a30) at /lmdeploy/src/turbomind/models/llama/LlamaBatch.cc:965
warning: 965 /lmdeploy/src/turbomind/models/llama/LlamaBatch.cc: No such file or directory
(gdb) Quit
(gdb) up 1
#12 0x00007fc857b2ec3c in std::make_unique<turbomind::LlamaBatch<__half>, turbomind::EngineParams const&, int&, int&, turbomind::LlamaV2<__half>*>
() at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/unique_ptr.h:856
warning: 856 /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/unique_ptr.h: No such file or directory
(gdb) down 1
#11 0x00007fc857b61019 in turbomind::LlamaBatch<__half>::LlamaBatch (this=0x7fc7d3a0cc30, params=..., cache_block_seq_len=64,
quant_policy=<optimized out>, model=0x7fc7d3a13a30) at /lmdeploy/src/turbomind/models/llama/LlamaBatch.cc:965
warning: 965 /lmdeploy/src/turbomind/models/llama/LlamaBatch.cc: No such file or directory
(gdb) p model
$2 = (turbomind::LlamaV2<__half> *) 0x7fc7d3a13a30
(gdb) p model->local_kv_head_num_
$3 = 0
Looking at the code, model->local_kv_head_num_ = 0 makes BlockConfig's head_num_ 0, so block_size ends up 0 and the final division by zero crashes the process.
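The failure chain can be sketched in Python. This is a simplified model, not TurboMind's code: the byte formula in the hypothetical `block_size_bytes` and the hypothetical `block_count` are assumptions, chosen only to show why `head_num_ = 0` forces `block_size = 0`, as in the gdb frame for `GetBlockCount`. In C++ an integer division by zero typically raises SIGFPE, which the shell reports as "floating point exception"; the sketch raises a `ValueError` instead.

```python
def block_size_bytes(layer_num: int, head_num: int, head_dim: int,
                     block_len: int, bits: int) -> int:
    """Hypothetical size of one KV-cache block: K and V for every layer, head and token."""
    return 2 * layer_num * head_num * head_dim * block_len * bits // 8


def block_count(free_bytes: int, ratio: float, block_size: int) -> int:
    """How many blocks fit into `ratio` of the free memory."""
    if block_size == 0:
        # Without this guard, an integer division like this in C++ typically ends in SIGFPE.
        raise ValueError("block_size is 0: does tp evenly divide kv_head_num?")
    return int(free_bytes * ratio) // block_size


# head_num_ = 0, head_dim_ = 128, block_len_ = 64, q_bits_ = 16 come from the gdb session;
# layer_num = 28 is taken from the Qwen2-1.5B model_config log.
size = block_size_bytes(layer_num=28, head_num=0, head_dim=128, block_len=64, bits=16)
print(size)  # 0, so block_count(...) would divide by zero
```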
I see: local_kv_head_num_ = 0 comes from the initializer local_kv_head_num_(kv_head_num / tensor_para.world_size_) in the LlamaV2 constructor. According to gdb:
#13 turbomind::LlamaV2<__half>::LlamaV2 (this=0x7fc7d3a13a30, head_num=32, kv_head_num=2, size_per_head=<optimized out>,
inter_size=<optimized out>, num_layer=40, vocab_size=151552, norm_eps=<optimized out>, attn_params=..., start_id=0, end_id=151329,
cache_block_seq_len=64, quant_policy=0, use_context_fmha=true, engine_params=..., lora_params=...,
shared_state=std::shared_ptr<turbomind::LlamaV2<__half>::SharedState> (use count 9, weak count 0) = {...}, weights=0x7fc7c029c520,
tensor_para=..., stream=0x7fc7d0000bd0, cublas_wrapper=0x7fc7d3d6df10, allocator=0x7fc7d38d9540, peer_alloctor=0x7fc7d021b5a0,
is_free_buffer_after_forward=false, cuda_device_prop=0x7fc7d3a13620) at /lmdeploy/src/turbomind/models/llama/LlamaV2.cc:103
#14 0x00007fc857af458a in std::make_unique<turbomind::LlamaV2<__half>, unsigned long&, unsigned long&, unsigned long&, unsigned long&, unsigned long&, unsigned long&, float&, turbomind::LlamaAttentionParams&, int&, int&, int&, int&, int&, turbomind::EngineParams&, turbomind::LoraParams&, std::shared_ptr<turbomind::LlamaV2<__half>::SharedState>&, turbomind::LlamaWeight<__half>*, turbomind::NcclParam&, CUstream_st*&, turbomind::cublasMMWrapper*, turbomind::Allocator<(turbomind::AllocatorType)0>*, turbomind::Allocator<(turbomind::AllocatorType)0>*, bool, cudaDeviceProp*> ()
at /opt/rh/devtoolset-9/root/usr/include/c++/9/bits/unique_ptr.h:857
kv_head_num = 2, and my command was lmdeploy serve api_server Qwen/Qwen2-1.5B --server-port=8000 --tp=4 --model-name=default-model --max-batch-size=32 --session-len=32768 --log-level INFO, so tensor_para.world_size_ should be 4 and the integer division 2 / 4 yields 0. This edge case should be handled. @irexyc
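The truncation is easy to reproduce. The hypothetical helper `local_kv_heads` below mimics the C++ initializer `kv_head_num / tensor_para.world_size_` with Python floor division (identical to C++ integer division for non-negative values), using `kv_head_num = 2` from the Qwen2-1.5B model_config log:

```python
def local_kv_heads(kv_head_num: int, tp: int) -> int:
    # Mimics `kv_head_num / tensor_para.world_size_` on unsigned integers:
    # the fractional part is discarded, so 2 / 4 becomes 0.
    return kv_head_num // tp


for tp in (1, 2, 4):
    print(f"tp={tp}: local_kv_head_num={local_kv_heads(2, tp)}")
# tp=1: local_kv_head_num=2
# tp=2: local_kv_head_num=1
# tp=4: local_kv_head_num=0
```

With `--tp=4` each rank is left with zero KV heads, which is the `local_kv_head_num_ = 0` value printed by gdb.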
I changed the command from
lmdeploy serve api_server Qwen/Qwen2-1.5B --server-port=8000 --tp=4 --model-name=default-model --max-batch-size=32 --session-len=32768 --log-level INFO
to
lmdeploy serve api_server Qwen/Qwen2-1.5B --server-port=8000 --tp=2 --model-name=default-model --max-batch-size=32 --session-len=32768 --log-level INFO
Reducing --tp=4 to --tp=2 lets the server start, but a new error occurs when it handles requests:
lmdeploy serve api_server Qwen/Qwen2-1.5B --server-port=8000 --tp=1 --model-name=default-model --max-batch-size=32
Fetching 10 files: 100%|████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 32313.59it/s]
[WARNING] gemm_config.in is not found; using default GEMM algo
HINT: Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
INFO: Started server process [579577]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO: 127.0.0.1:35988 - "POST /v1/chat/completions HTTP/1.1" 200 OK
terminate called after throwing an instance of 'std::runtime_error'
what(): [TM][ERROR] Assertion fail: /lmdeploy/src/turbomind/kernels/attention/attention.cu:35
zsh: IOT instruction (core dumped) lmdeploy serve api_server Qwen/Qwen2-1.5B --server-port=8000 --tp=1
Probably the 2080 Ti doesn't support bf16.
@lzhangzz Can bf16 be converted to fp16 with convert or some other method?
Also, running Qwen/Qwen2-7B works:
lmdeploy serve api_server Qwen/Qwen2-7B --server-port=8000 --tp=4 --model-name=default-model --max-batch-size=4 --session-len=32768
The log also says:
Fetching 14 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:00<00:00, 473.79it/s]
Device does not support bfloat16. Set float16 forcefully
[WARNING] gemm_config.in is not found; using default GEMM algo
[WARNING] gemm_config.in is not found; using default GEMM algo
[WARNING] gemm_config.in is not found; using default GEMM algo
[WARNING] gemm_config.in is not found; using default GEMM algo
HINT: Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
HINT: Please open http://0.0.0.0:8000 in a browser for detailed api usage!!!
INFO: Started server process [2802635]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
So the conversion does happen, but this line does not appear in the Qwen/Qwen2-1.5B log.
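The bf16 question comes down to GPU compute capability. The sketch below is an illustration, not lmdeploy's actual fallback code: `supports_bf16`, `pick_dtype` and `fits_fp16` are hypothetical helpers. It assumes the usual rule that hardware bf16 support on NVIDIA GPUs starts with Ampere (compute capability 8.0), while the RTX 2080 Ti is Turing (7.5). It also shows the main caveat of a bf16 to fp16 cast: fp16's largest finite value is 65504, whereas bf16 shares fp32's exponent range, so large magnitudes can overflow.

```python
FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def supports_bf16(major: int, minor: int) -> bool:
    """Hardware bf16 support on NVIDIA GPUs starts with Ampere (compute capability 8.0)."""
    return (major, minor) >= (8, 0)


def pick_dtype(config_dtype: str, major: int, minor: int) -> str:
    """Hypothetical fallback: keep bfloat16 only when the GPU supports it, else use float16."""
    if config_dtype == "bfloat16" and not supports_bf16(major, minor):
        return "float16"
    return config_dtype


def fits_fp16(values) -> bool:
    """Conservative check: every magnitude must lie within fp16's finite range."""
    return all(abs(v) <= FP16_MAX for v in values)


# With PyTorch, the (major, minor) pair can be read from torch.cuda.get_device_capability(0).
print(pick_dtype("bfloat16", 7, 5))   # float16 (RTX 2080 Ti, Turing)
print(pick_dtype("bfloat16", 8, 0))   # bfloat16
print(fits_fp16([1.0, -300.0, 7e4]))  # False
```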
Regarding the earlier suggestion to update cuBLAS with pip3 install nvidia-cublas-cu12==12.3.4.1: what version of torch are you using?
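Hi @lpf6, the 2080 Ti doesn't support bf16, but lmdeploy forces fp16 for inference. I ran qwen2-1.5b successfully on a single 2080. I don't have a multi-GPU environment, though, so I can't reproduce your problem. With TP, lmdeploy requires both attention_head_num and kv_head_num to be divisible by tp. This constraint is checked as of v0.6.0a0.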
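The divisibility constraint can be checked before launching the server. A minimal sketch, not lmdeploy's own validation: `valid_tp_sizes` and `valid_tp_from_config` are hypothetical helpers, and the second one assumes the Hugging Face `config.json` field names `num_attention_heads` and `num_key_value_heads`.

```python
def valid_tp_sizes(num_attention_heads: int, num_key_value_heads: int,
                   max_gpus: int) -> list[int]:
    """TP sizes that evenly divide both the attention heads and the KV heads."""
    return [tp for tp in range(1, max_gpus + 1)
            if num_attention_heads % tp == 0 and num_key_value_heads % tp == 0]


def valid_tp_from_config(config: dict, max_gpus: int) -> list[int]:
    """Same check on a parsed config.json (e.g. loaded with json.load)."""
    heads = config["num_attention_heads"]
    # Models without grouped-query attention may omit num_key_value_heads.
    kv_heads = config.get("num_key_value_heads", heads)
    return valid_tp_sizes(heads, kv_heads, max_gpus)


# Head counts from the Qwen2-1.5B model_config log: head_num = 12, kv_head_num = 2.
print(valid_tp_sizes(12, 2, max_gpus=4))  # [1, 2]
```

On four GPUs this rules out `--tp=4` for Qwen2-1.5B, consistent with the crash and with `--tp=2` starting. Assuming Qwen2-7B's config lists 28 attention heads and 4 KV heads, `valid_tp_sizes(28, 4, 4)` returns `[1, 2, 4]`, which would explain why the 7B model runs with `--tp=4`.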
@Alwin4Zhang From your screenshot, the error comes from the tokenizer. Please first use AutoTokenizer from transformers to verify that the model's encode and decode both work correctly.
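One way to run that verification is an encode/decode round trip. A sketch under these assumptions: `roundtrip_ok` and `load_tokenizer` are hypothetical helpers; `trust_remote_code=True` is only needed for models that ship custom tokenizer code; and some tokenizers legitimately normalize whitespace, so a mismatch is a hint to inspect rather than proof of a bug.

```python
def roundtrip_ok(tokenizer, text: str) -> bool:
    """True if decoding the encoded ids reproduces `text` exactly."""
    ids = tokenizer.encode(text, add_special_tokens=False)
    return tokenizer.decode(ids, skip_special_tokens=True) == text


def load_tokenizer(model_path: str):
    # Imported lazily so roundtrip_ok() can be exercised without transformers installed.
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)


# Example usage (needs transformers and the model files; the path is a placeholder):
#   tok = load_tokenizer("/path/to/model")
#   for sample in ["Hello, world!", "naïve café", "  leading spaces"]:
#       print(repr(sample), roundtrip_ok(tok, sample))
```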
OK, I'll try running it on a single GPU tonight.
@lvhan028
Environment: CentOS, V100 GPU, Driver Version: 550.54.14, lmdeploy==0.6.0, torch==2.3.1, transformers==4.44.2
I deployed InternVL2-8B with:
NVIDIA_VISIBLE_DEVICES=1 lmdeploy serve api_server /data/models/OpenGVLab/InternVL2-8B --tp 1 --server-port 11251 --model-name InternVL2-8B --cache-max-entry-count 0.25
During inference it fails with: Assertion fail: /lmdeploy/src/turbomind/kernels/attention/attention.cu:35
What I tried:
1. Downgrading lmdeploy to 0.5.3: not fixed
2. Upgrading nvidia-cublas-cu12 to 12.1.3.1: not fixed
3. Upgrading torch to 2.4.1: not fixed
4. Downgrading torch to 2.2.2: fixed
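When bisecting versions like this, it helps to record the exact installed versions for each attempt. A small sketch using the standard library's `importlib.metadata`; the package list is simply the set discussed in this thread, and `collect_versions` is a hypothetical helper.

```python
from importlib import metadata

PACKAGES = ["torch", "lmdeploy", "transformers", "nvidia-cublas-cu12"]


def collect_versions(packages):
    """Map each distribution name to its installed version, or 'Not Found'."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "Not Found"
    return versions


if __name__ == "__main__":
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name}: {ver}")
```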
So far it appears to be related to the torch version. Are there plans to support newer torch versions?
This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.
This issue is closed because it has been stale for 5 days. Please open a new issue if you have similar issues or you have any new updates now.