albert
albert copied to clipboard
How do I run inference for a SQuAD-style QA task using ALBERT?
I am developing a question-answering system using ALBERT, but it is hard to find code for the inference part. Can someone help? What I need is a function that predicts the answer to a question, taking the question and the context as inputs.
Hey, I'm actually looking for the same thing. Did you find a solution?
This code worked for me:
import tensorflow as tf
import numpy as np
import squad_utils
from transformers import AlbertTokenizer
# Sequence length the model inputs are reshaped to; assumed to match the length
# the TFLite model was exported with (TODO confirm for your model).
max_seq_length = 384
# Document-window stride and question-length cap handed to squad_utils (the
# values the original script passed positionally).
doc_stride = 128
max_query_length = 64


def collect_features(examples, tokenizer, seq_length=max_seq_length):
    """Convert SQuAD examples into model features using ``squad_utils``.

    ``convert_examples_to_features`` hands each feature to its ``output_fn``
    callback; in the ALBERT reference implementation it does not return them
    (NOTE(review): confirm for the squad_utils copy in use). The callback here
    therefore collects them. A non-``None`` return value is used only when the
    callback received nothing, so locally modified copies that return features
    keep working.

    Returns:
        A list of features in the order squad_utils emitted them.
    """
    collected = []

    def _append(feature, *_args, **_kwargs):
        # Extra arguments are tolerated in case squad_utils passes more than
        # the feature (e.g. a padding flag) -- TODO confirm exact signature.
        collected.append(feature)

    returned = squad_utils.convert_examples_to_features(
        examples, tokenizer, seq_length, doc_stride, max_query_length,
        False, _append, True)  # presumably is_training=False, do_lower_case=True
    if collected:
        return collected
    if returned is None:
        return []
    if isinstance(returned, (list, tuple)):
        return list(returned)
    return [returned]


def match_tensor_details(details, names):
    """Order TFLite tensor details so they follow ``names``.

    If every entry of ``names`` is a substring of exactly one detail's
    ``"name"`` and those details are distinct tensors, return them in
    ``names`` order. Otherwise fall back to the first ``len(names)`` details
    in the interpreter's own order, which is the positional convention the
    script relied on originally.

    Raises:
        ValueError: if fewer than ``len(names)`` details are available.
    """
    details = list(details)
    by_name = []
    for wanted in names:
        hits = [d for d in details if wanted in str(d.get("name", ""))]
        if len(hits) == 1:
            by_name.append(hits[0])
    if (len(by_name) == len(names)
            and len({d["index"] for d in by_name}) == len(names)):
        return by_name
    if len(details) < len(names):
        raise ValueError(
            f"model exposes {len(details)} tensors, expected at least {len(names)}")
    return details[:len(names)]


def best_span(start_logits, end_logits, max_answer_length=30):
    """Return the highest-scoring valid answer span as ``(start, end)``.

    A span is scored as ``start_logits[start] + end_logits[end]``. Only spans
    with ``start <= end < start + max_answer_length`` are considered, so the
    result is never an empty (``end < start``) slice. Ties resolve to the
    first span in row-major ``(start, end)`` order.

    Raises:
        ValueError: if the logits are empty or differ in length, or if
            ``max_answer_length < 1``.
    """
    starts = np.asarray(start_logits, dtype=np.float64).reshape(-1)
    ends = np.asarray(end_logits, dtype=np.float64).reshape(-1)
    if starts.size == 0 or starts.shape != ends.shape:
        raise ValueError("start and end logits must be non-empty and equal length")
    if max_answer_length < 1:
        raise ValueError("max_answer_length must be at least 1")
    n = starts.size
    positions = np.arange(n)
    span_len = positions[None, :] - positions[:, None]  # end - start
    allowed = (span_len >= 0) & (span_len < max_answer_length)
    scores = np.where(allowed, starts[:, None] + ends[None, :], -np.inf)
    start, end = divmod(int(np.argmax(scores)), n)
    return start, end


def run_model(interpreter, feature, seq_length=max_seq_length):
    """Run one feature through an allocated TFLite QA interpreter.

    Returns:
        ``(start_logits, end_logits)`` for the single batch row.
    """
    input_details = match_tensor_details(
        interpreter.get_input_details(), ("input_ids", "input_mask", "segment_ids"))
    values = (feature.input_ids, feature.input_mask, feature.segment_ids)
    for detail, value in zip(input_details, values):
        interpreter.set_tensor(
            detail["index"],
            np.asarray(value, dtype=np.int32).reshape((1, seq_length)))
    interpreter.invoke()
    # Without recognisable tensor names, the outputs are taken as (end, start)
    # in that order, as the original script assumed -- verify for your model.
    end_detail, start_detail = match_tensor_details(
        interpreter.get_output_details(), ("end_logits", "start_logits"))
    end_logits = interpreter.get_tensor(end_detail["index"])[0]
    start_logits = interpreter.get_tensor(start_detail["index"])[0]
    return start_logits, end_logits


def main(squad_path="path-to-file-that-contains-SQuaD-example",
         model_path="path-to-model", tokenizer_name="albert-base-v2"):
    """Predict, print and return the answer for the first feature of ``squad_path``.

    Only the first feature emitted by squad_utils is scored, as in the
    original script.

    Raises:
        ValueError: if no features were produced from ``squad_path``.
    """
    examples = squad_utils.read_squad_examples(squad_path, False)
    tokenizer = AlbertTokenizer.from_pretrained(tokenizer_name)
    features = collect_features(examples, tokenizer)
    if not features:
        raise ValueError(f"no features were produced from {squad_path!r}")
    feature = features[0]

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    start, end = best_span(*run_model(interpreter, feature))

    # Decode from the ids the model actually saw: the logits index into this
    # feature's input_ids, not into a separate tokenizer() encoding whose
    # layout (query truncation, windowing) can differ.
    answer = tokenizer.decode(feature.input_ids[start:end + 1],
                              skip_special_tokens=True)
    print("Predicted Answer:", answer)
    return answer


if __name__ == "__main__":
    main()