transformer-lm
Pytorch: Speed up get_log_probs function
Hi lopuhin,
Thanks for sharing this.
I used your model to predict the next word, but I found the prediction speed relatively slow, probably because the lm.inference.get_log_probs function computes the probabilities for all words in the sentence. However, predicting the next word only requires the probabilities at the last position.
I found the prediction speed relatively slow
Thanks for the feedback. Did you find it slow compared to other similar models / implementations?
However, predicting the next word only requires the probabilities at the last position.
Right. But we still need to process all the previous words. I see that we could avoid doing the softmax for all but the last word; I'm not sure how much difference it will bring.
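For illustration, here is a minimal sketch of that idea; the function name `next_token_log_probs` and the assumed `(seq_len, vocab_size)` logits shape are mine, not part of the library's API. The point is to apply the log-softmax only to the last position instead of normalising every position:

```python
import torch
import torch.nn.functional as F

def next_token_log_probs(logits: torch.Tensor) -> torch.Tensor:
    """Log-probabilities for the next token only.

    ``logits`` is assumed to be the raw model output with shape
    (seq_len, vocab_size).  A full get_log_probs-style call would apply
    log_softmax to every row; next-word prediction only needs the last one.
    """
    return F.log_softmax(logits[-1], dim=-1)  # (vocab_size,)

# Hypothetical usage, assuming `model` maps token ids to per-position logits:
# logits = model(context_ids)                           # (seq_len, vocab_size)
# top_ids = next_token_log_probs(logits).topk(5).indices
```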
FWIW there is a big speedup in text generation here 4c18649391826c2b1590e722b9559fa5f72ced8e - this speeds up generation of multiple tokens, while single-token generation keeps the same speed.
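As a general illustration only (not necessarily how that commit is implemented), multi-token generation in transformer decoders is commonly sped up by caching the attention keys and values of positions that were already processed, so each newly generated token only needs a forward pass over one position. A toy single-head sketch:

```python
import torch
import torch.nn.functional as F

class CachedSelfAttention(torch.nn.Module):
    """Toy single-head self-attention with a key/value cache."""

    def __init__(self, dim: int):
        super().__init__()
        self.qkv = torch.nn.Linear(dim, 3 * dim)
        self.cache_k = None  # (seq_so_far, dim) keys of already-seen positions
        self.cache_v = None  # (seq_so_far, dim) values of already-seen positions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (1, dim) -- one new position; earlier positions live in the cache,
        # so each generation step costs a forward pass over a single token.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if self.cache_k is None:
            self.cache_k, self.cache_v = k, v
        else:
            self.cache_k = torch.cat([self.cache_k, k], dim=0)
            self.cache_v = torch.cat([self.cache_v, v], dim=0)
        scale = self.cache_k.shape[-1] ** 0.5
        scores = q @ self.cache_k.t() / scale            # (1, seq_so_far)
        return F.softmax(scores, dim=-1) @ self.cache_v  # (1, dim)
```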