
Input length of input_ids is greater than 1024

Open · linuxonly801 opened this issue 1 year ago · 7 comments

I set up the local environment and used the WebGLM-2B model. Question asked: Is HER2 gene a good target for treating cancer?

The following error is raised:

Input length of input_ids is 1056, but max_length is set to 1024. This can lead to unexpected behavior. You should consider increasing max_new_tokens.
Traceback (most recent call last):
  File "cli_demo.py", line 21, in <module>
    for results in webglm.stream_query(question):
  File "/media/WebGLM/model/modeling_webglm.py", line 49, in stream_query
    outputs = self.model.generate(**inputs, max_length=1024, eos_token_id = self.tokenizer.eop_token_id, pad_token_id=self.tokenizer.eop_token_id)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/transformers/generation/utils.py", line 1515, in generate
    return self.greedy_search(
  File "/usr/local/miniconda3/lib/python3.8/site-packages/transformers/generation/utils.py", line 2332, in greedy_search
    outputs = self(
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 902, in forward
    model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 783, in forward
    transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 595, in forward
    hidden_states = layer(*args, mem=mem_i)
  File "/usr/local/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_glm.py", line 422, in forward
    layernorm_input = hidden_states + attention_output
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

linuxonly801 · Jun 18 '23

Increase the 1024 value in modeling_webglm.py.

ilovesouthpark · Jun 23 '23

outputs = self.model.generate(**inputs, max_length=2048, eos_token_id = self.tokenizer.eop_token_id, pad_token_id=self.tokenizer.eop_token_id)

After making this change it still errors out. Please help!

mikestut · Jun 23 '23

Running it on CPU reveals that the real cause of the error is IndexError: index out of range in self. Fixing this is beyond me, but it looks like a common problem; hopefully the dev team can expose an extra parameter so everyone can adjust it easily.
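A minimal way to surface that underlying exception instead of the asynchronous CUDA assert is to force synchronous kernel launches; this is generic PyTorch debugging, not WebGLM-specific code:

import os

# Generic PyTorch debugging sketch: make CUDA errors synchronous so the
# traceback points at the real failing op. Must be set before CUDA is
# initialized, i.e. before importing torch / launching the model.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Alternatively, running the model on CPU makes the same failure show up
# directly as "IndexError: index out of range in self", which is typically
# raised by an embedding lookup receiving an out-of-range index.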

ilovesouthpark · Jun 25 '23

Following this issue.

traveler-vee · Jun 27 '23

Same problem here. It happens when asking in Chinese; English questions have been fine so far. It looks related to the snippets cut out of Chinese search results not respecting the intended length.

hanjingsu · Jul 06 '23

IndexError: index out of range in self

TailyDuan · Jul 14 '23

I also run into this problem with English queries. I saw the author say that max_length cannot be changed, so truncation is the only option, but I did not find an explicit truncation API.
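If the WebGLM tokenizer follows the standard Hugging Face tokenizer interface (an assumption; it is a custom trust_remote_code tokenizer, so it may not honor these arguments), each reference text could in principle be cut to a token budget before it is concatenated into the prompt, roughly like this:

# Hedged sketch, assuming the standard Hugging Face truncation/decode API works
# on the WebGLM tokenizer. "ref" is one retrieved reference dict and 200 is an
# arbitrary per-reference token budget.
ids = self.tokenizer(ref["text"], truncation=True, max_length=200).input_ids
short_txt = self.tokenizer.decode(ids, skip_special_tokens=True)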

In my case, the reason input_ids exceeds 1024 is not that the user's prompt is too long, but that the code does not truncate the references returned by the search engine to a fixed length; they are appended to the prompt as-is, so input_ids ends up larger than 1024.

The workaround is to modify the query (or stream_query) function in modeling_webglm.py: count the tokens of each ref and cap the total prompt length:

def query(self, question):
    refs = self.ref_retriever.query(question)
    if not refs:
        return { "references": [], "answer": "" }
    prompt = ''
    question = f'Question: {question}\\Answer: [gMASK]'
    # token count of the question part (including the [gMASK] answer slot)
    total_token_num = self.tokenizer(question, return_tensors="pt").input_ids.shape[1]
    for ix, ref in enumerate(refs):
        txt = ref["text"]
        prompt_tmp = f'Reference [{ix+1}]: {txt}' '\\'
        prompt_tmp_token_num = self.tokenizer(prompt_tmp, return_tensors="pt").input_ids.shape[1]
        # 900 keeps the prompt well under generate()'s max_length of 1024,
        # leaving room for special tokens and the generated answer
        if total_token_num + prompt_tmp_token_num < 900:
            prompt += prompt_tmp
            total_token_num += prompt_tmp_token_num
        else:
            break
    prompt += question
    inputs = self.tokenizer(prompt, return_tensors="pt")
    # other code

Although this is simplistic (because of the break, an over-budget ref[i] also causes a later ref[i+1] that would still fit to be dropped) and a bit inefficient, it does work for me :)
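For reference, a variant of the same loop that avoids that limitation would skip an over-long reference and keep trying the remaining ones instead of breaking. This is an untested sketch under the same assumptions as the snippet above, with a hypothetical helper name:

def build_prompt(question, refs, tokenizer, budget=900):
    # Untested sketch: skip any single reference that would overflow the token
    # budget, but keep trying later (possibly shorter) references instead of
    # stopping at the first oversized one.
    question_part = f'Question: {question}\\Answer: [gMASK]'
    total = tokenizer(question_part, return_tensors="pt").input_ids.shape[1]
    prompt = ''
    for ix, ref in enumerate(refs):
        piece = f'Reference [{ix+1}]: {ref["text"]}' '\\'
        n = tokenizer(piece, return_tensors="pt").input_ids.shape[1]
        if total + n >= budget:
            continue  # this reference alone is too long; a later one may still fit
        prompt += piece
        total += n
    return prompt + question_part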


I dug into the code a bit more: the _pre_filter method of the Extractor class in model/retriever/extracting/__init__.py appears to constrain only the length of each individual paragraph of the page fetched for each retrieved URL.
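To make that concrete (a hypothetical illustration, not the actual Extractor._pre_filter code): a per-paragraph filter only bounds each snippet, not their sum, so the assembled prompt can still exceed the 1024-token limit:

def pre_filter_sketch(paragraphs, min_chars=50, max_chars=1500):
    # Hypothetical illustration only: each kept paragraph is individually
    # bounded, but the number of kept paragraphs (and hence the total prompt
    # length) is not.
    return [p for p in paragraphs if min_chars <= len(p) <= max_chars]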

WnQinm · Jun 07 '24