[Bug] AttributeError: 'DeepseekVLV2ForCausalLM' object has no attribute 'config'
Checklist
- [x] 1. I have searched related issues but cannot get the expected help.
- [x] 2. The bug has not been fixed in the latest version.
- [x] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
Describe the bug
The argument trust_remote_code is to be used with Auto classes. It has no effect here and is ignored.
You are using a model of type deepseek_vl_v2 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
The argument trust_remote_code is to be used with Auto classes. It has no effect here and is ignored.
You are using a model of type deepseek_vl_v2 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
2025-04-09 15:26:34,083 - lmdeploy - WARNING - supported_models.py:119 - AutoConfig.from_pretrained failed for /data2/jwq/grec_dataset_construct/experiment/batch_model_test/models/deepseek-ai/deepseek-vl2-tiny. Exception: The checkpoint you are trying to load has model type deepseek_vl_v2 but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.
2025-04-09 15:26:34,083 - lmdeploy - WARNING - archs.py:55 - Try to run with pytorch engine because /data2/jwq/grec_dataset_construct/experiment/batch_model_test/models/deepseek-ai/deepseek-vl2-tiny is not explicitly supported by lmdeploy.
The argument trust_remote_code is to be used with Auto classes. It has no effect here and is ignored.
You are using a model of type deepseek_vl_v2 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the legacy (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set legacy=False. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Some kwargs in processor config are unused and will not have any effect: image_token, image_mean, normalize, pad_token, sft_format, candidate_resolutions, patch_size, add_special_token, image_std, mask_prompt, downsample_ratio, ignore_id.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
2025-04-09 15:26:36,463 - lmdeploy - WARNING - transformers.py:22 - LMDeploy requires transformers version: [4.33.0 ~ 4.46.1], but found version: 4.46.3
Loading weights from safetensors: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00, 1.68s/it]
2025-04-09 15:26:55,337 - lmdeploy - WARNING - async_engine.py:643 - GenerationConfig: GenerationConfig(n=1, max_new_tokens=512, do_sample=False, top_p=1.0, top_k=50, min_p=0.0, temperature=0.8, repetition_penalty=1.0, ignore_eos=False, random_seed=None, stop_words=None, bad_words=None, stop_token_ids=[1], bad_token_ids=None, min_new_tokens=None, skip_special_tokens=True, spaces_between_special_tokens=True, logprobs=None, response_format=None, logits_processors=None, output_logits=None, output_last_hidden_state=None)
2025-04-09 15:26:55,337 - lmdeploy - WARNING - async_engine.py:644 - Since v0.6.0, lmdeploy add do_sample in GenerationConfig. It defaults to False, meaning greedy decoding. Please set do_sample=True if sampling decoding is needed
2025-04-09 15:26:56,074 - lmdeploy - ERROR - model_agent.py:391 - Task <ModelAgentLoop> failed
Traceback (most recent call last):
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 386, in _on_finish_callback
task.result()
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 374, in _async_loop_background
await self._async_step_background(
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 322, in _async_step_background
output = await self._async_model_forward(inputs,
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 243, in _async_model_forward
ret = await __forward(inputs)
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 220, in __forward
return await self.async_forward(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 538, in async_forward
output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 521, in _forward_impl
output = model_forward(
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 75, in model_forward
output = model(**input_dict)
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 155, in call
runner.capture(**kwargs)
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 69, in capture
self.meta.input_buffers = self.model.make_buffers_cudagraph(self.meta, **kwargs)
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/lmdeploy/pytorch/models/utils/cudagraph.py", line 63, in make_buffers_cudagraph
if getattr(self.config, 'use_flash_mla', False) is True:
File "/data2/jwq/anaconda3/envs/lmdeploy/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1729, in getattr
raise AttributeError(f"'{type(self).name}' object has no attribute '{name}'")
AttributeError: 'DeepseekVLV2ForCausalLM' object has no attribute 'config'
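For context (my own note, not part of the log): `DeepseekVLV2ForCausalLM` is a `torch.nn.Module` subclass, and `nn.Module.__getattr__` only resolves registered parameters, buffers, and submodules, so reading an attribute that `__init__` never assigned raises exactly this error. A minimal sketch of the mechanism, using a hypothetical `Demo` module:

```python
import torch.nn as nn


class Demo(nn.Module):
    """Hypothetical module that, like the failing model class, never stores its config."""

    def __init__(self):
        super().__init__()
        # note: no `self.config = ...` assignment here


m = Demo()
# The getattr default only guards the inner lookup of 'use_flash_mla';
# `m.config` itself fails first and raises:
# AttributeError: 'Demo' object has no attribute 'config'
getattr(m.config, 'use_flash_mla', False)
```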
Reproduction
import os

from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
from lmdeploy.vl import load_image

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

if __name__ == "__main__":
    # load the model
    pipe = pipeline('deepseek-ai/deepseek-vl2-tiny')
    # load the image
    image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
    # run the model
    response = pipe(('<IMAGE_TOKEN>describe this image', image))
    print(response)
Environment
sys.platform: linux
Python: 3.8.20 | packaged by conda-forge | (default, Sep 30 2024, 17:52:49) [GCC 13.3.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: NVIDIA RTX A6000
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.6, V12.6.85
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.3) 9.4.0
PyTorch: 2.4.1+cu121
PyTorch compiling details: PyTorch built with:
- GCC 9.3
- C++ Version: 201703
- Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v3.4.2 (Git Hash 1137e04ec0b5251ca2b4400a4fd3c667ce843d67)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX512
- CUDA Runtime 12.1
- NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
- CuDNN 90.1 (built against CUDA 12.4)
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=9.1.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.4.1, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,
TorchVision: 0.19.1+cu121
LMDeploy: 0.7.2.post1+
transformers: 4.46.3
gradio: Not Found
fastapi: 0.115.12
pydantic: 2.10.6
triton: 3.0.0
NVIDIA Topology:
GPU0 GPU1 GPU2 GPU3 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X SYS SYS SYS 0-47 0 N/A
GPU1 SYS X SYS SYS 0-47 0 N/A
GPU2 SYS SYS X SYS 0-47 0 N/A
GPU3 SYS SYS SYS X 0-47 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
Error traceback
I have the same issue on v0.7.2.
At https://github.com/InternLM/lmdeploy/blob/f50a1f44ea41c2dab2252a5468e755a2d7a3348c/lmdeploy/pytorch/models/deepseek_vl2.py#L113, add: `self.config = config`
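A rough sketch of where that assignment would land (the class signature and surrounding code are illustrative, not copied from the repo):

```python
import torch.nn as nn


class DeepseekVLV2ForCausalLM(nn.Module):
    """Skeleton only; the real class lives in lmdeploy/pytorch/models/deepseek_vl2.py."""

    def __init__(self, config, ctx_mgr=None, dtype=None, device=None):
        super().__init__()
        # Proposed fix: keep a reference to the config so that
        # make_buffers_cudagraph can read getattr(self.config, 'use_flash_mla', False).
        self.config = config
        # ... rest of the original __init__ (vision model, language model, etc.) ...
```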
Thanks @hufangjian, this resolved the issue!
Thanks for pointing this out and providing a solution. This should be fixed in the PR above.
cc @zhulinJulia24
We may want to add deepseek-vl2 to the test cases.