transformer-lm
Stand-alone text generation and scoring scripts
Hi lopuhin! I'm using your code and it works very nicely for training and for generating samples during training. Could you please write code to load the model and weights and generate text from another Python file? Thank you!
Thank you @binhvq 👍 This is something I would love to have. It's possible to do it by extracting some code from the training script, since it generates samples during training, but I would love to have this as a stand-alone feature.
But right now I'm moving to TensorFlow 2.0 to support multi-GPU training more easily; I plan to get back to easy text scoring and generation once that's done.
@binhvq keeping it open as this is something to be done, I hope you don't mind the title change.
Hi lopuhin! I wrote code to load the model and generate text. It's very simple.
Thanks
Scoring is implemented in https://github.com/lopuhin/transformer-lm/blob/master/lm/inference.py
@binhvq @lopuhin Can you please share the text generation script?
@virgulvirgul I don't have it, but if you have log probabilities for the next token, you can take the exponent to get real probabilities, then select, say, the top 40 of them, and then sample the next token according to those probabilities (e.g. by passing them as the p parameter of np.random.choice).
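The exponentiate-renormalize-sample step described above can be sketched with plain NumPy. The log-probability values and candidate tokens below are made up for illustration; in practice they would come from the model:

```python
import numpy as np

# hypothetical log probabilities for the top candidate tokens,
# most likely first (illustrative values, not model output)
log_probs = np.array([-0.5, -1.2, -2.0, -3.5])
candidates = ['der', 'die', 'das', 'ein']

# take the exponent to get real probabilities, then renormalize
# so the truncated top-k distribution sums to 1
probs = np.exp(log_probs)
probs /= probs.sum()

# sample the index of the next token according to this distribution
idx = np.random.choice(len(candidates), p=probs)
next_token = candidates[idx]
print(next_token)
```

Renormalizing is important: after truncating to the top k candidates, the raw probabilities no longer sum to 1, and np.random.choice rejects a p argument that is not a valid distribution.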
I implemented this approach; it works great for my model so far:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from pathlib import Path

import numpy as np

from lm import inference

MODEL_PATH = Path('run-root')
TOKENS_TO_GENERATE = 32
TOP_K = 8

mw = inference.ModelWrapper.load(MODEL_PATH)

txt = "Die Forschung an der künstlichen Intelligenz"
tokens = mw.tokenize(txt)

for i in range(TOKENS_TO_GENERATE):
    # generate TOP_K potential next tokens
    ntk = mw.get_next_top_k(tokens, TOP_K)
    # convert log probs to real probs
    log_probs = np.array([a[0] for a in ntk])
    probs = np.exp(log_probs) / np.exp(log_probs).sum()
    # pick the next token randomly according to the probs distribution
    next_token_n = np.random.choice(TOP_K, p=probs)
    next_token = ntk[next_token_n][1]
    print(next_token)
    tokens.append(next_token)

print(mw.sp_model.DecodePieces(tokens))
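A general numerical note on the probs line in the script above (not specific to this model): np.exp can overflow for large-magnitude log probabilities. Subtracting the maximum log probability before exponentiating is the standard numerically stable way to compute the same distribution, since the shift cancels in the normalization:

```python
import numpy as np

def stable_probs(log_probs):
    """Convert log probabilities to probabilities without overflow.

    Subtracting the max shifts all values into a safe range for
    np.exp; the normalization cancels the shift, so the resulting
    distribution is mathematically identical.
    """
    log_probs = np.asarray(log_probs, dtype=float)
    shifted = log_probs - log_probs.max()
    exp = np.exp(shifted)
    return exp / exp.sum()

# same result as np.exp(lp) / np.exp(lp).sum(), but safe for
# values like [1000.0, 999.0] that would overflow the naive form
print(stable_probs([1000.0, 999.0]))
```

For the log probabilities a trained language model actually returns (all non-positive and usually of small magnitude), the naive version in the script works fine; this is just defensive hygiene.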