
Bug when running inference with retrieval augmented model


**Describe the bug**
Using retrieval-augmented models, a sequence of prompts leads to a runtime error (a size mismatch between two tensors).

**To Reproduce**
Steps to reproduce the behavior:

  1. After downloading the Wikipedia index, run inference with `python inference/bot.py --retrieval`
  2. In the OpenChatKit Shell, run the following set of queries:
>>> Where is Bern?
...
>>> Where is Switzerland?
...
>>> Is Switzerland in Europe or in America?
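The likely mechanism (my reading, not confirmed in the code): the bot concatenates the retrieved Wikipedia passages plus the full dialogue history into every new prompt, so the token count grows with each turn until it exceeds GPT-NeoX's 2048-token context window. A minimal sketch with made-up token counts:

```python
# Sketch of how the prompt can outgrow the context window over a dialogue.
# All token counts below are hypothetical illustrations, not measured values.

MAX_CONTEXT = 2048  # GPT-NeoX max_position_embeddings

def prompt_length(history_tokens, retrieved_tokens):
    """Total tokens fed to generate(): retrieval context + dialogue so far."""
    return history_tokens + retrieved_tokens

history = 0
for turn in range(3):          # the three queries above
    history += 350             # assumed tokens per question + answer
    total = prompt_length(history, retrieved_tokens=3 * 400)

print(total)  # 2250 > 2048, in the same range as the 2247 in the traceback
```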

**Traceback**
The queries lead to the following error:

Traceback (most recent call last):
  File "/home/fsuser/OpenChatKit/inference/bot.py", line 185, in <module>
    main()
  File "/home/fsuser/OpenChatKit/inference/bot.py", line 181, in main
    ).cmdloop()
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/cmd.py", line 138, in cmdloop
    stop = self.onecmd(line)
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/cmd.py", line 217, in onecmd
    return func(arg)
  File "/home/fsuser/OpenChatKit/inference/bot.py", line 87, in do_say
    output = self._model.do_inference(
  File "/home/fsuser/OpenChatKit/inference/bot.py", line 32, in do_inference
    outputs = self._model.generate(
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/generation_utils.py", line 1326, in generate
    return self.sample(
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/generation_utils.py", line 1944, in sample
    outputs = self(
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 619, in forward
    outputs = self.gpt_neox(
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 511, in forward
    outputs = layer(
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 319, in forward
    attention_layer_outputs = self.attention(
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 153, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/home/fsuser/miniconda3/envs/OpenChatKit/lib/python3.10/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 220, in _attn
    attn_scores = torch.where(causal_mask, attn_scores, mask_value)
RuntimeError: The size of tensor a (2048) must match the size of tensor b (2247) at non-singleton dimension 3
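A possible workaround (an assumption on my part, not the project's fix): truncate the prompt to the model's context window before calling `generate`, keeping only the most recent tokens so the latest question and its retrieved passages survive. Sketched here on plain token-id lists:

```python
# Left-truncation sketch: drop the oldest tokens so prompt + reply fit
# the 2048-token window. MAX_NEW_TOKENS is a hypothetical generation budget.

MAX_CONTEXT = 2048     # GPT-NeoX context window
MAX_NEW_TOKENS = 128   # leave room for the model's reply

def truncate_left(input_ids, max_context=MAX_CONTEXT, max_new=MAX_NEW_TOKENS):
    """Keep only the most recent tokens within the prompt budget."""
    budget = max_context - max_new
    return input_ids[-budget:] if len(input_ids) > budget else input_ids

prompt = list(range(2247))   # an over-long prompt like the one in the traceback
trimmed = truncate_left(prompt)
print(len(trimmed))          # 1920, within the 2048-token window
```

With `tokenizer.decode` round-tripping this would lose the oldest dialogue turns, which seems preferable to crashing mid-conversation.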

**Environment**
Setup using mamba in the root dir: `mamba env create -f environment.yml`

Hardware:

  • OS: Ubuntu 20.04.5 LTS
  • 1x A100 80G GPU
  • 8 vCPU with 128GB RAM

mauriceweber avatar Mar 13 '23 17:03 mauriceweber

Thank you for the detailed bug report. Let me try to reproduce this.

csris avatar Mar 18 '23 05:03 csris