Assisted generation errors due to use_cache
Following the instructions in the blog post on assisted generation, I ran into some issues. (FYI, both longform_model and assistant_model are fine-tuned versions of OPT, which is the same model used in the blog post.)
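For context, here is roughly how I load the tokenizer and both models. The checkpoint paths below are placeholders for my own fine-tuned OPT checkpoints, not something from the blog post:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned-opt-large")
longform_model = AutoModelForCausalLM.from_pretrained(
    "path/to/finetuned-opt-large", torch_dtype=torch.float16
).to("cuda")
assistant_model = AutoModelForCausalLM.from_pretrained(
    "path/to/finetuned-opt-small", torch_dtype=torch.float16
).to("cuda")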
First, when I do exactly what's in the post:
prompt = prompt + "\nAnswer:"
inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
outputs = longform_model.generate(**inputs, assistant_model=assistant_model)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
I get an error telling me that assisted generation requires use_cache=True. Hmm... weird, the blog post didn't seem to need that argument, but okay, let's try it!
prompt = prompt + "\nAnswer:"
inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
outputs = longform_model.generate(**inputs, assistant_model=assistant_model, use_cache=True)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
Then this happens:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-10-e9645bbc79d4> in <module>
----> 1 generate_from_prompt("Which is a species of fish? Tope or rope?")
<ipython-input-9-14fc80d284ea> in generate_from_prompt(prompt)
2 prompt = prompt + "\nAnswer:"
3 inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
----> 4 outputs = longform_model.generate(**inputs, assistant_model=assistant_model, use_cache=True)
5 print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
/usr/lib/python3/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
28 return cast(F, decorate_context)
29
~/.local/lib/python3.8/site-packages/transformers/generation/utils.py in generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1493
1494 # 12. run assisted generate
-> 1495 return self.assisted_decoding(
1496 input_ids,
1497 assistant_model=assistant_model,
~/.local/lib/python3.8/site-packages/transformers/generation/utils.py in assisted_decoding(self, input_ids, assistant_model, do_sample, logits_processor, logits_warper, stopping_criteria, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
4253 # 1.1. use the assistant model to obtain the next candidate logits
4254 if "assistant_past_key_values" in model_kwargs:
-> 4255 prev_seq_len = model_kwargs["assistant_past_key_values"][0][assistant_kv_indexing].shape[-2]
4256 # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
4257 new_token_len = candidate_input_ids.shape[1] - prev_seq_len
TypeError: 'NoneType' object is not subscriptable
I'm using the bleeding-edge version of Transformers, so I'm curious what I'm doing wrong here, or whether this is just a bug.
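In case it helps with debugging, this is the sanity check I'm running on my end, plus the workaround I plan to try next. The idea of forcing use_cache on both model configs is just my guess from reading the error, not something the blog post suggests:

import transformers
print(transformers.__version__)
print(longform_model.config.use_cache, assistant_model.config.use_cache)

# Guess: explicitly enable the KV cache on both configs before generating
longform_model.config.use_cache = True
assistant_model.config.use_cache = True
outputs = longform_model.generate(**inputs, assistant_model=assistant_model)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))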