End-to-End Question Generation directly from /transformers library
I'm trying to generate questions using "valhalla/t5-base-e2e-qg" End-to-End model directly from the /transformers library.
My code:
# Item No. 1
text1 = "Chicken Samosa costs $6.95"
text2 = "Chicken Samosa has Crispy Wrappers filled with Spiced Chicken."
text3 = "Chicken Samosa is served with Cilantro Dipping Sauce."
from transformers import AutoTokenizer, AutoModelWithLMHead

t5_e2es_tokenizer = AutoTokenizer.from_pretrained("valhalla/t5-small-e2e-qg")
t5_e2es_model = AutoModelWithLMHead.from_pretrained("valhalla/t5-small-e2e-qg")
enc = t5_e2es_tokenizer([text2], return_tensors="pt")
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]  # necessary if padding is enabled, so the model won't attend to pad tokens
t5e2es_tokens = t5_e2es_model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    num_beams=10,
    num_return_sequences=3,
)
out_ = []
for token in t5e2es_tokens:
    decoded = t5_e2es_tokenizer.decode(token)
    print(decoded)
    out_.append(decoded)
This is the output I get:
Chicken Samosa has Crisp Wrappers filled with Spiced Chicken Samosa has Chicken Samosa has Crisp Wrappers filled with Spice Chicken Samosa has Crisp Chicken Samosa has Crisp Wrappers filled with Spiced Chicken. Crisp Wrappers filled
Where did I go wrong? I thought I just had to follow the same procedure I used for the single-task QG model.
Hi. Any comments on this? @patil-suraj
Hi @vidyap-xgboost
For e2e models, we need different generate parameters, and the text needs to be prefixed with "generate questions: ".
Also, </s> should be added at the end of the input for all of these T5 models.
You can try this:
text2 = "generate questions: Chicken Samosa has Crispy Wrappers filled with Spiced Chicken. </s>"
enc = t5_e2es_tokenizer([text2], return_tensors="pt")
input_ids = enc["input_ids"]
attention_mask = enc["attention_mask"]  # necessary if padding is enabled, so the model won't attend to pad tokens
t5e2es_tokens = t5_e2es_model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    num_beams=4,
    max_length=256,
    no_repeat_ngram_size=3,
    length_penalty=1.5,
    early_stopping=True,
)
out_ = []
for token in t5e2es_tokens:
    decoded = t5_e2es_tokenizer.decode(token)
    print(decoded)
    out_.append(decoded)
=> Chicken Samosa has Crispy Wrappers filled with what? <sep> What spiced chicken is filled with Spiced Chicken? <sep>
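The e2e model returns all questions in a single string separated by <sep>, so the input formatting and the output splitting can be wrapped in two small helpers. This is only a sketch: `build_e2e_input` and `split_questions` are hypothetical names, not part of /transformers, and stripping `<pad>` and `</s>` assumes those special tokens may still appear in the decoded string.

```python
PREFIX = "generate questions: "
EOS = "</s>"
SEP = "<sep>"


def build_e2e_input(text, add_eos=True):
    """Format raw text the way the e2e QG model expects (hypothetical helper)."""
    source = PREFIX + text.strip()
    # Append the end-of-sequence marker unless it is already there.
    if add_eos and not source.endswith(EOS):
        source = source + " " + EOS
    return source


def split_questions(decoded):
    """Split one decoded string into individual questions (hypothetical helper)."""
    # Assumption: special tokens may survive decoding, so drop them first.
    for special in ("<pad>", EOS):
        decoded = decoded.replace(special, "")
    # Drop the empty piece left by the trailing <sep>.
    return [q.strip() for q in decoded.split(SEP) if q.strip()]


source = build_e2e_input("Chicken Samosa has Crispy Wrappers filled with Spiced Chicken.")
decoded = (
    "Chicken Samosa has Crispy Wrappers filled with what? <sep> "
    "What spiced chicken is filled with Spiced Chicken? <sep>"
)
questions = split_questions(decoded)
for q in questions:
    print(q)
```

`source` can be passed to the tokenizer in place of the hand-written `text2`, and `questions` holds one entry per generated question.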