rust-bert
Empty response
I think this is separate to #85 so I have created a new issue. The little chatbot I am writing always gives empty responses after about 7-12 rounds of conversation.
- Fine-tuned transcript
You: Hello Holly
Holly: My name is holly. I love talking about magic and the goddess of nature, it's so special to me as a fairy.
You: So what can you tell me about magic. I am very interested in it.
Holly: I am really interested in magic.
You: That makes sense because you have magic
Holly: Magic can be very powerful. Don't you think?
You: I think it can be very powerful
Holly: Magic can do many wonderful things.
You: Can you tell me what magic can do?
Holly: I want to learn.
You: I think you already know
Holly:
You: So all quite again
Holly:
(P.S. I was surprised it knows its name, and the first reply to "Hello" was quite a good description of what the training-data conversations were all about.)
This happens both on my model and the default DialoGPT one, even when I set `min_length: 2` in the configuration.
- DialoGPT transcript
You: Hello again
Holly: Hey man
You: What make you think I am a man
Holly: I meant as in what do you think the human body is like. I don't mean physical appearance.
You: Well that is an odd way to generalise, but anywyas you were right
Holly: You know what, I guess it would be a bit odd, but I'm sure it'll do well.
You: Most odd things do well
Holly: I don't know what that means but I'm sure they are good.
You: Its just the funny way nature works
Holly: Yeah... It's the good kind of weird
You: Right you get it
Holly: Aye
You: So are you going to go quite too?
Holly:
You: Yep quite
Holly:
You:
Any ideas? Or should I start tracing down the rabbit hole again?
This also happens with your example if you add more lines to it.
```rust
extern crate anyhow;

use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};

fn main() -> anyhow::Result<()> {
    let conversation_model = ConversationModel::new(Default::default())?;
    let mut conversation_manager = ConversationManager::new();
    let conversation_1_id =
        conversation_manager.create("Going to the movies tonight - any suggestions?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("Is it an action movie?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("Is it a love movie?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("What is it about?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("Would you recommend it?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("If not what would you recommend?");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("I think you need to think about it more.");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("After all action is the best.");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("But maybe not");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("What really matters is quality.");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("Quality over all other things");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("But not at the expense of tradition");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("For advancement for advancments sake must be curtailed");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("Unethical practises must be trimmed");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("In truth nothing is of any good");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("Unless it is traditional");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    let _ = conversation_manager
        .get(&conversation_1_id)
        .unwrap()
        .add_user_input("And sometimes not even then");
    let output = conversation_model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    Ok(())
}
```
Output:
```
{7d07290f-e99b-4d42-b21c-c1486c368d84: "The Departed is pretty good!"}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "It\'s a movie with action and suspense. Watch it!"}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "I don\'t think so..."}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "It\'s a love story. It\'s called The Departed because the characters have all of the emotion in their eyes throughout the whole movie."}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "Nope, I watched it. Worth it."}
{}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "I\'d say yes!"}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "It\'s aight"}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "I\'ll"}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "So there\'s!"}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "Yee"}
{7d07290f-e99b-4d42-b21c-c1486c368d84: "Yes..."}
{7d07290f-e99b-4d42-b21c-c1486c368d84: ""}
{7d07290f-e99b-4d42-b21c-c1486c368d84: ""}
{7d07290f-e99b-4d42-b21c-c1486c368d84: ""}
{7d07290f-e99b-4d42-b21c-c1486c368d84: ""}
{7d07290f-e99b-4d42-b21c-c1486c368d84: ""}
{7d07290f-e99b-4d42-b21c-c1486c368d84: ""}
```
I am wondering if this is to do with how the model was trained, with a max context size of 7 previous sentences. I have been trying to find a way to limit the amount of context that goes into the model by reading through `concat_input_history`, but I am not finding it. Am I barking up the wrong tree?
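What I had in mind was something like the following standalone sketch (no rust-bert involved; `keep_last_n` is a hypothetical helper, not part of the crate's API), which would cap the history at the last N turns before generating:

```rust
// Sketch: keep only the last `n` conversation turns before generation.
// Purely illustrative; the real pipeline works on tokenized history
// inside `concat_input_history`, not on plain strings like this.
fn keep_last_n<T>(turns: &mut Vec<T>, n: usize) {
    if turns.len() > n {
        // Drop the oldest turns so at most `n` remain.
        turns.drain(..turns.len() - n);
    }
}

fn main() {
    let mut turns: Vec<&str> = vec!["a", "b", "c", "d", "e", "f", "g", "h", "i"];
    keep_last_n(&mut turns, 7); // mimic the model's 7-sentence training context
    assert_eq!(turns.len(), 7);
    assert_eq!(turns[0], "c"); // the two oldest turns were dropped
    println!("{:?}", turns);
}
```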
So here is the part of the code related to maintaining the minimum length:
```rust
// Do not allow eos token if min length is not reached
trace!("gen_opt.min_length: {:?}", gen_opt.min_length);
trace!("current_length: {:?}", current_length);
trace!("gen_opt.eos_token_ids.is_some(): {:?}", gen_opt.eos_token_ids.is_some());
if (gen_opt.eos_token_ids.is_some()) & (current_length < gen_opt.min_length) {
    trace!("Min length not reached: {:?}", gen_opt.eos_token_ids.is_some());
    let _ = next_token_logits.index_fill_(
        1,
        &Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
            .to(next_token_logits.device()),
        std::f64::NEG_INFINITY,
    );
}
```
I added some traces in front and got this:
```
TRACE rust_bert::pipelines::generation::private_generation_utils > gen_opt.min_length: 2
TRACE rust_bert::pipelines::generation::private_generation_utils > current_length: 46
TRACE rust_bert::pipelines::generation::private_generation_utils > gen_opt.eos_token_ids.is_some(): true
```
The current length surprised me, as this was at the beginning of the loop, before any tokens had been added. Surely this should be zero? However, tracking back, the value of `current_length` is initialised with `cur_len`, which comes from the function input and is initialised with:
```rust
let cur_len = if !self.is_encoder_decoder() {
    *input_ids.size().last().unwrap()
} else {
    1
};
```
I think the minimum length should pertain to the current output, not the output + input length. So should the `if` block be:
```rust
if (gen_opt.eos_token_ids.is_some()) & ((current_length - cur_len) < gen_opt.min_length) {
    trace!("Min length not reached: {:?}", gen_opt.eos_token_ids.is_some());
    let _ = next_token_logits.index_fill_(
        1,
        &Tensor::of_slice(gen_opt.eos_token_ids.as_ref().unwrap())
            .to(next_token_logits.device()),
        std::f64::NEG_INFINITY,
    );
}
```
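A quick way to sanity-check the two conditions is to plug in the numbers from the trace above (46 input tokens, `min_length` of 2). This is a standalone sketch, with no rust-bert involved, and the function names are mine, not the crate's:

```rust
// Old condition: compares total length (input + generated) to min_length.
fn eos_banned_old(current_length: i64, min_length: i64) -> bool {
    // Mirrors: (eos_token_ids.is_some()) & (current_length < min_length)
    current_length < min_length
}

// Proposed condition: compares only the number of *generated* tokens.
fn eos_banned_new(current_length: i64, cur_len: i64, min_length: i64) -> bool {
    (current_length - cur_len) < min_length
}

fn main() {
    let input_len = 46; // cur_len from the trace above
    let min_length = 2;
    // First generation step: nothing generated yet, current_length == input_len.
    // Old check: EOS is already allowed, so an empty response is possible.
    assert!(!eos_banned_old(input_len, min_length));
    // New check: EOS is still banned until 2 tokens have been generated.
    assert!(eos_banned_new(input_len, input_len, min_length));
    // After 2 generated tokens, the new condition releases the ban.
    assert!(!eos_banned_new(input_len + 2, input_len, min_length));
    println!("ok");
}
```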
So if I do the `min_length` check as shown above, the `min_length` in the config is honored during chatting and there are no more empty responses, yay. I do wonder if I am breaking something else though. Please let me know your thoughts.
Also, what happens when the length of the inputs exceeds `max_length`?

```rust
while current_length < gen_opt.max_length {
```

Wouldn't this `while` loop not even run, and therefore produce no output at all?
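The loop bound can be illustrated with a standalone sketch (made-up numbers, no rust-bert): if the prompt alone already reaches `max_length`, the generation loop body never executes.

```rust
// Stand-in for the generation loop: count how many tokens would be
// produced given an input length and a max_length bound.
fn generated_token_count(input_len: usize, max_length: usize) -> usize {
    let mut current_length = input_len; // starts at the input length
    let mut generated = 0;
    while current_length < max_length {
        generated += 1; // stand-in for sampling one token
        current_length += 1;
    }
    generated
}

fn main() {
    // Input longer than max_length: the loop never runs, no output at all.
    assert_eq!(generated_token_count(1005, 1000), 0);
    // Normal case: there is room for max_length - input_len tokens.
    assert_eq!(generated_token_count(46, 50), 4);
    println!("ok");
}
```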
Hi @QuantumEntangledAndy ,
Thank you for raising this behaviour: I noticed the same while troubleshooting the other issue. I tried to run a dialogue with Python's Transformers implementation and the behaviour is identical: the model tends to stop generating after 6 to 7 rounds rather consistently.
> The current length surprised me, as this was at the beginning of the loop, before any tokens had been added. Surely this should be zero? However, tracking back, the value of `current_length` is initialised with `cur_len`, which comes from the function input and is initialised with:
>
> ```rust
> let cur_len = if !self.is_encoder_decoder() {
>     *input_ids.size().last().unwrap()
> } else {
>     1
> };
> ```
You are right that this behaviour is rather surprising. Your suggested changes sound reasonable and would reflect more accurately the expected behaviour of `min_length` for non-encoder-decoder models. So far, the crate stays rather close to the Python implementation (https://github.com/huggingface/transformers), and I would love to hear Hugging Face's perspective on this issue (their handling of `min_length` is identical as far as I can tell). Could you please raise an issue over there and ask for their view on this?
> Also, what happens when the length of the inputs exceeds `max_length`?
>
> ```rust
> while current_length < gen_opt.max_length {
> ```
>
> Wouldn't this `while` loop not even run, and therefore produce no output at all?
Yes - if the input provided exceeds the maximum length, no text is generated. For the conversation pipeline, this is avoided via the method: https://github.com/guillaume-be/rust-bert/blob/f7da9dcee42c6842499917b9d4921f380c2731a9/src/pipelines/conversation.rs#L773
This clears the first elements of the history to make space for a new input. I realize now that this is not perfect, as we want to make sure we leave enough space for at least `min_length` tokens, and ideally a bit more. I will work on updating this.
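One possible shape for that update, as a standalone sketch with made-up token counts (the real pipeline works on tokenized inputs, and `trim_history` is a hypothetical helper, not the crate's method): trim the oldest turns until the prompt leaves headroom of at least `min_length` tokens for generation.

```rust
// Sketch: drop the oldest history turns until the new input plus the
// remaining history fits within max_length minus a min_length headroom.
fn trim_history(
    history: &mut Vec<Vec<i64>>, // token ids of past turns, oldest first
    new_input_len: usize,
    max_length: usize,
    min_length: usize,
) {
    // Leave room for at least min_length generated tokens.
    let budget = max_length.saturating_sub(min_length);
    let mut total = new_input_len + history.iter().map(|t| t.len()).sum::<usize>();
    while total > budget && !history.is_empty() {
        let removed = history.remove(0); // drop the oldest turn first
        total -= removed.len();
    }
}

fn main() {
    // Three past turns of 300, 400 and 350 tokens, plus a 100-token input.
    let mut history = vec![vec![0; 300], vec![0; 400], vec![0; 350]];
    trim_history(&mut history, 100, 1000, 50);
    let total: usize = 100 + history.iter().map(|t| t.len()).sum::<usize>();
    assert!(total <= 1000 - 50); // prompt now fits with headroom
    println!("turns kept: {}, total tokens: {}", history.len(), total);
}
```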
Is there a method to reliably find the length of the input that would go into the model? If so, the solution in Python might be to update the min and max length before calling `generate_responses` from the client end.
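The arithmetic for that client-side workaround would be simple, assuming the prompt's token count is known. A sketch (names and numbers are mine, not any library's API):

```rust
// Sketch: shift min/max length by the prompt's token count, so they bound
// the number of *generated* tokens rather than input + output combined.
fn adjusted_bounds(input_tokens: i64, min_new: i64, max_new: i64) -> (i64, i64) {
    (input_tokens + min_new, input_tokens + max_new)
}

fn main() {
    // With the 46-token prompt from the trace, asking for 2-100 new tokens:
    let (min_length, max_length) = adjusted_bounds(46, 2, 100);
    assert_eq!((min_length, max_length), (48, 146));
    println!("min_length={}, max_length={}", min_length, max_length);
}
```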
Opened upstream issue in huggingface/transformers#7800
Fixed by https://github.com/guillaume-be/rust-bert/pull/296