CogVLM 推理时,只要把 num_beams 设置为大于 1 的值就会报错。设置为 2 时的报错如下:
...within a limit of 100 words. Answer:[OUTPUT]
Traceback (most recent call last):
  File "/root/swift/swift/cli/infer.py", line 5, in <module>
    infer_main()
  File "/root/swift/swift/utils/run_utils.py", line 31, in x_main
    result = llm_x(args, **kwargs)
  File "/root/swift/swift/llm/infer.py", line 447, in llm_infer
    response, _ = inference(
  File "/root/swift/swift/llm/utils/utils.py", line 709, in inference
    generate_ids = model.generate(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 1665, in generate
    return self.beam_sample(
  File "/root/miniconda3/lib/python3.10/site-packages/transformers/generation/utils.py", line 3411, in beam_sample
    outputs = self(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/cogvlm-chat/modeling_cogvlm.py", line 660, in forward
    outputs = self.model(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/cogvlm-chat/modeling_cogvlm.py", line 426, in forward
    assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
AssertionError: 2 1
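如果将 num_beams 设置为 4,则报 AssertionError: 4 1。看起来在 beam search 时,input_ids 被扩展成了 num_beams 份,而 images 仍然只有 1 份,于是 modeling_cogvlm.py 第 426 行的断言失败。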
运行的 shell 命令如下:
CUDA_VISIBLE_DEVICES=0 swift infer \
    --ckpt_dir /root/autodl-tmp/output/uniCate2k-prune-cogvlm-lora-rk8-b1-2e5/checkpoint-6000 \
    --load_dataset_config false \
    --custom_val_dataset_path /root/autodl-tmp/dataset/arch_50.jsonl \
    --save_result \
    --verbose true \
    --temperature 0.1 \
    --top_k 3 \
    --top_p 0.9 \
    --repetition_penalty 1.1 \
    --val_dataset_sample -1 \
    --num_beam 4