OpenChatKit
Bug when running inference with retrieval-augmented model
Describe the bug
Using retrieval-augmented models, a sequence of prompts leads to a runtime error (a size mismatch between two tensors).
To Reproduce
Steps to reproduce the behavior:
- After downloading the Wikipedia index, run inference using
python inference/bot.py --retrieval
- In the OpenChatKit Shell, run the following set of queries:
>>> Where is Bern?
...
>>> Where is Switzerland?
...
>>> Is Switzerland in Europe or in America?
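For context, the prompt that reaches the model on the third turn is most likely much longer than the visible queries, because each turn also prepends retrieved Wikipedia passages and the accumulated conversation history. A minimal sketch to check the prompt length, assuming the default GPT-NeoXT-Chat-Base-20B tokenizer and an illustrative prompt layout (not copied from inference/bot.py):

```python
# Sketch: count how many tokens the retrieval-augmented prompt uses.
# The model name and prompt layout below are assumptions based on OpenChatKit
# defaults, not taken verbatim from inference/bot.py.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("togethercomputer/GPT-NeoXT-Chat-Base-20B")

retrieved_passages = "..."  # Wikipedia snippets prepended by the retrieval plugin
dialogue = (
    "<human>: Where is Bern?\n<bot>: ...\n"
    "<human>: Where is Switzerland?\n<bot>: ...\n"
    "<human>: Is Switzerland in Europe or in America?\n<bot>:"
)

prompt = retrieved_passages + "\n" + dialogue
n_tokens = len(tokenizer(prompt).input_ids)
print(n_tokens)  # if this exceeds 2048, generation overruns the causal-mask buffer
```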
Traceback
The queries lead to the following error:
Traceback (most recent call last):
File "/home/fsuser/OpenChatKit/inference/bot.py", line 185, in <module>
main()
File "/home/fsuser/OpenChatKit/inference/bot.py", line 181, in main
).cmdloop()
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/cmd.py", line 138, in cmdloop
stop = self.onecmd(line)
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/cmd.py", line 217, in onecmd
return func(arg)
File "/home/fsuser/OpenChatKit/inference/bot.py", line 87, in do_say
output = self._model.do_inference(
File "/home/fsuser/OpenChatKit/inference/bot.py", line 32, in do_inference
outputs = self._model.generate(
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/generation_utils.py", line 1326, in generate
return self.sample(
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/generation_utils.py", line 1944, in sample
outputs = self(
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 619, in forward
outputs = self.gpt_neox(
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 511, in forward
outputs = layer(
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 319, in forward
attention_layer_outputs = self.attention(
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 153, in forward
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 220, in _attn
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
RuntimeError: The size of tensor a (2048) must match the size of tensor b (2247) at non-singleton dimension 3
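If I read the error right, the input sequence has grown to 2247 tokens while GPT-NeoX's precomputed causal mask only covers max_position_embeddings = 2048 positions, so the comparison in _attn fails once the retrieval context plus dialogue history overflows the context window. A minimal workaround sketch, assuming the model and tokenizer are already loaded as in inference/bot.py (the variable names and generation budget are illustrative, not a patch to the repo):

```python
# Sketch: clamp the prompt to the model's context window before generate(),
# keeping the most recent tokens (latest turns plus retrieved context).
# `model`, `tokenizer`, and `prompt` are assumed to exist; max_new_tokens is an
# assumed generation budget.
max_positions = model.config.max_position_embeddings  # 2048 for GPT-NeoX
max_new_tokens = 256

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
budget = max_positions - max_new_tokens
if inputs.input_ids.shape[1] > budget:
    # drop the oldest tokens so prompt + generated tokens fit in 2048 positions
    inputs = {k: v[:, -budget:] for k, v in inputs.items()}

outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True)
```

A proper fix would presumably do this truncation (or summarize older turns) inside the retrieval-augmented inference path itself, but the sketch above at least avoids the crash.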
Environment
Environment created with mamba from the repo root: mamba env create -f environment.yml
Hardware:
- OS: Ubuntu 20.04.5 LTS
- 1x A100 80G GPU
- 8 vCPU with 128GB RAM
Thank you for the detailed bug report. Let me try to reproduce this.