aitextgen
How do you specify the model with the CLI interface?
I can't seem to find any explanation of the flags for the CLI interface — in particular, how to specify the model.
I looked through aitextgen.py for some hints. Using "--tf_gpt2" kind of works:
aitextgen generate --tf_gpt2=124M --prompt "I believe in unicorns because" --to_file False
The model is downloaded, but then execution bombs out with the error:
ValueError: The following `model_kwargs` are not used by the model: ['tf_gpt2'] (note: typos in the generate arguments will also show up in this list)
I think what is happening is that aitextgen recognises the flag (for example, --tf_gpt2=355M downloads the 355M model as expected, so it's not just falling back to a default), but then passes all CLI params unchanged to transformers, which errors out on the unknown argument?
Is it as simple as the CLI interface removing kwargs entries like tf_gpt2 before calling transformers' generate(), or does more need to be done?
I'm not a Python programmer, so apologies for the clunky understanding. Thank you.
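For what it's worth, here is a minimal sketch of the failure mode I suspect. None of these function names are aitextgen's or transformers' real internals (fake_generate and cli_entry are made up for illustration); it only shows how a CLI-only key left in **kwargs trips a "not used by the model" check downstream:

```python
def fake_generate(**model_kwargs):
    # Stand-in for transformers' generate(): rejects kwargs it doesn't recognise,
    # mimicking the ValueError seen in the traceback above.
    known = {"prompt", "max_length"}
    unused = [k for k in model_kwargs if k not in known]
    if unused:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused}"
        )
    return "generated text"

def cli_entry(**kwargs):
    # The CLI layer consumes --tf_gpt2 to pick a model. If it forwarded the raw
    # kwargs unchanged, fake_generate() would see the leftover key and raise.
    model_name = kwargs.pop("tf_gpt2", "124M")  # consume the CLI-only key here
    return fake_generate(**kwargs)
```

With the pop in place, cli_entry(tf_gpt2="355M", prompt="hello") succeeds, while passing tf_gpt2 straight to fake_generate() reproduces the error.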
Edit: hacking in del kwargs['tf_gpt2'] just before the call to self.model.generate seems to have done the trick, but I'm unsure whether this is the best approach.
try:
    kwargs['tf_gpt2']
except KeyError:
    pass
else:
    del kwargs['tf_gpt2']
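If it helps, that try/except dance can be collapsed into a single dict.pop call, which removes the key if present and silently does nothing otherwise (same behaviour, assuming kwargs is a plain dict):

```python
kwargs = {"tf_gpt2": "124M", "prompt": "I believe in unicorns because"}

# dict.pop with a default never raises KeyError, so no try/except is needed.
kwargs.pop("tf_gpt2", None)  # removes the key if present; no-op otherwise
```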