Create a user-friendly inference demo
This is a feature request.
I like maxtext because it is very customizable and efficient for training. The main issue I'm having is hacking together an inference function: the code is quite complex, so this is not straightforward to do.
The simple `decode.py` works, but it seems aimed mainly at experimental streaming development.
I think streaming will be really cool, but we would also benefit from an easy `model.generate(input_ids, attention_mask, params)` function (a rough sketch of what I have in mind follows the list):
- it should allow prefill based on the length of `input_ids` (it is the user's responsibility to supply only a few distinct shapes to avoid recompilation)
- it should allow batched input, with left padding to support different input lengths
- it should be compilable with `jit`/`pjit`
- it should allow a few common sampling strategies: greedy, sampling (with temperature, top-k, top-p), and beam search
- it should be usable without a separate engine/service, in case we want to make it part of a larger function that involves multiple models
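To make the request concrete, here is a minimal sketch of the kind of interface I have in mind. It is not tied to maxtext internals: `apply_fn` is a hypothetical stand-in for the model forward pass, and `EOS_ID`, `greedy_generate`, and all shapes are illustrative assumptions, not existing maxtext APIs. For brevity it recomputes the full sequence each step instead of using a KV cache.

```python
import jax
import jax.numpy as jnp

EOS_ID = 2  # assumption: placeholder EOS id, would come from the tokenizer


def greedy_generate(params, apply_fn, input_ids, attention_mask, max_new_tokens):
    """Greedy decoding over a left-padded batch.

    `apply_fn(params, ids, mask) -> logits [batch, seq, vocab]` is a
    hypothetical stand-in for the model forward pass. Prompts are assumed
    left-padded, so the newest real token is always at the highest index.
    """
    batch, prompt_len = input_ids.shape
    total_len = prompt_len + max_new_tokens

    # Pre-allocate full-length buffers so all shapes stay static under jit.
    ids = jnp.zeros((batch, total_len), input_ids.dtype).at[:, :prompt_len].set(input_ids)
    mask = jnp.zeros((batch, total_len), attention_mask.dtype).at[:, :prompt_len].set(attention_mask)
    done = jnp.zeros((batch,), dtype=bool)

    def step(state, pos):
        ids, mask, done = state
        # No KV cache here: the whole (padded) sequence is re-run each step.
        logits = apply_fn(params, ids, mask)
        next_token = jnp.argmax(logits[:, pos - 1], axis=-1)
        # Finished sequences keep emitting EOS as padding.
        next_token = jnp.where(done, EOS_ID, next_token)
        done = done | (next_token == EOS_ID)
        ids = ids.at[:, pos].set(next_token)
        mask = mask.at[:, pos].set(1)
        return (ids, mask, done), None

    # lax.scan runs a fixed number of steps; stopping early once the whole
    # batch is done would need a lax.while_loop on `done.all()` instead.
    (ids, _, done), _ = jax.lax.scan(
        step, (ids, mask, done), jnp.arange(prompt_len, total_len)
    )
    return ids, done


# Usage sketch: jit once per (batch, prompt_len, max_new_tokens) shape bucket;
# apply_fn is marked static because it is a Python callable, not an array.
generate = jax.jit(greedy_generate, static_argnames=("apply_fn", "max_new_tokens"))
```

Temperature/top-k/top-p sampling would just replace the `argmax` with a draw from `jax.random.categorical` over the adjusted logits; beam search would need a wider state but could follow the same shape-static pattern.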
This PR looked interesting: https://github.com/google/maxtext/pull/402. I think it was mainly intended for benchmarking, though, since it didn't stop once the entire batch had reached EOS, but it did have nice prefill functionality.