RuntimeError: shape '[1, 1, 1, 32, 128]' is invalid for input of size 16384
System Info
flagai 1.7.1, flash-attn 1.0.4, torch 1.13.1, Python 3.8, Linux
Information
- [X] The official example scripts
- [ ] My own modified scripts
Tasks
- [X] An officially supported task in the `examples` folder (such as T5/AltCLIP, ...)
- [ ] My own task or dataset (give details below)
Reproduction
Running the example from the README:
```python
import os
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
from flagai.data.tokenizer import Tokenizer
import bminf

state_dict = "./checkpoints_in/"
model_name = 'aquila-7b'  # 'aquila-33b'

loader = AutoLoader(
    "lm",
    model_dir=state_dict,
    model_name=model_name,
    use_cache=True)
model = loader.get_model()
tokenizer = loader.get_tokenizer()

model.eval()
model.half()
model.cuda()

predictor = Predictor(model, tokenizer)

text = "北京在哪儿?"
text = f'{text}'
print(f"text is {text}")
with torch.no_grad():
    out = predictor.predict_generate_randomsample(text, out_max_length=200, temperature=0)
    print(f"pred is {out}")
```
It throws this error:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "./lib/python3.8/site-packages/flagai/model/predictor/predictor.py", line 352, in predict_generate_randomsample
    return aquila_generate(self.tokenizer, self.model,
  File "./lib/python3.8/site-packages/flagai/model/predictor/aquila.py", line 37, in aquila_generate
    logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)["logits"]
  File "./lib/python3.8/site-packages/flagai/model/aquila_model.py", line 225, in forward
    h = layer(h, freqs_cis, mask)
  File "./lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "./lib/python3.8/site-packages/flagai/model/blocks/aquila_block.py", line 55, in forward
    h = x + self.attention.forward(self.attention_norm(x), self.start_pos, freqs_cis, mask, self.use_cache)
  File "./lib/python3.8/site-packages/flagai/model/layers/attentions.py", line 189, in forward
    keys = keys.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
RuntimeError: shape '[1, 1, 1, 32, 128]' is invalid for input of size 16384
```
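The arithmetic in the error hints at the cause (my reading, not confirmed): 16384 = 1 × 4 × 32 × 128, so `keys` appears to already contain 4 cached tokens, while the `view` call uses the current step's `seqlen` of 1. A minimal sketch that reproduces just the failing `view`, with the cache length of 4 inferred from the error:

```python
import torch

# Assumed shapes: bsz=1, current seqlen=1, 32 local heads, head_dim=128.
# cached_len=4 is inferred from the reported input size (1*4*32*128 = 16384).
bsz, seqlen, n_local_heads, head_dim = 1, 1, 32, 128
cached_len = 4

keys = torch.randn(bsz, cached_len, n_local_heads, head_dim)  # 16384 elements
keys.view(bsz, seqlen, 1, n_local_heads, head_dim)
# RuntimeError: shape '[1, 1, 1, 32, 128]' is invalid for input of size 16384
```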
Expected behavior
Could I have installed one of the packages incorrectly?
Same problem here...
flagai 1.7.1, flash-attn 1.0.7, torch 2.0.0+cu118, Python 3.8.10, Linux
Running the example from the README:
```python
import os
import torch
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
from flagai.model.predictor.aquila import aquila_generate
from flagai.data.tokenizer import Tokenizer
import bminf

state_dict = "./checkpoints_in/"
model_name = 'aquilachat-7b'

loader = AutoLoader(
    "lm",
    model_dir=state_dict,
    model_name=model_name,
    use_cache=True)
model = loader.get_model()
tokenizer = loader.get_tokenizer()
cache_dir = os.path.join(state_dict, model_name)

model.eval()
model.half()
model.cuda()

predictor = Predictor(model, tokenizer)

text = "北京为什么是中国的首都?"
print('-' * 80)
print(f"text is {text}")

from cyg_conversation import default_conversation

conv = default_conversation.copy()
conv.append_message(conv.roles[0], text)
conv.append_message(conv.roles[1], None)

tokens = tokenizer.encode_plus(f"{conv.get_prompt()}", None, max_length=None)['input_ids']
tokens = tokens[1:-1]

with torch.no_grad():
    out = aquila_generate(tokenizer, model, [text], max_gen_len:=200, top_p=0.95, prompts_tokens=[tokens])
    print(f"pred is {out}")
```
```
File ~/miniconda3/lib/python3.8/site-packages/flagai/model/layers/attentions.py:189, in AQUILAAttention.forward(self, x, start_pos, freqs_cis, mask, use_cache, **kwargs)
    187 if self.config.flash_atten or (self.config.flash_atten_aquila_style and not self.training):
    188     xq = xq.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
--> 189     keys = keys.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
    190     values = values.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
    191     qkv = torch.concat([xq, keys, values], dim=2)

RuntimeError: shape '[1, 1, 1, 32, 128]' is invalid for input of size 176128
```
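Same arithmetic here, if I read it right: assuming 32 heads of dim 128, the reported size corresponds to 43 tokens in `keys` versus the single current token assumed by the `view`:

```python
# Tokens implied by the reported input size, assuming 32 heads and head_dim 128.
print(176128 // (32 * 128))  # -> 43
```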
You can edit checkpoints_in/aquilachat-7b/config.json and change `"flash_atten": true` to `"flash_atten": false`.
Alternatively, delete checkpoints_in/aquilachat-7b/config.json and retry.
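A minimal sketch of that workaround done programmatically (the path assumes the checkpoint layout from the reproduction above; `flash_atten` is the key checked in the traceback's config branch):

```python
import json

config_path = "./checkpoints_in/aquilachat-7b/config.json"

with open(config_path) as f:
    config = json.load(f)

# Disable the flash-attention code path that triggers the failing view().
config["flash_atten"] = False

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)
```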