Add prefill step for causal LM
Currently, CausalLM in KerasHub differs from ordinary LLM implementations in that it doesn't have a prefill step. This wastes a lot of computation for models run with long prompts, so I think KerasHub should consider introducing a prefill step for CausalLM. This would be a major update, though, and it would require modifying a large number of existing models. Can the Keras team provide a plan for moving this forward, so that the community can contribute more easily?
We do have a prefill step; it's just not explicitly called out as such. See, for example, https://github.com/keras-team/keras-hub/blob/25c9062c5f9a25fade16094cd21d2545125168ec/keras_hub/src/models/gemma/gemma_causal_lm.py#L250 (the line hidden_states, cache = self._build_cache(token_ids)). We prefill the entire batch and sequence in one go, which is fast but uses a lot of memory.
What we are missing is a configurable prefill chunk size, which would be quite useful for longer generations (turn the prefill size down, or turn prefill off entirely, to save VRAM). It could also be useful to expose prefill and decode steps that are individually callable, to allow for more custom generation flows.
For both of these (or other generation improvements), it would probably make sense to prototype something for an individual model, but we would need an implementation that is relatively clean and works on all backends (notably, it would need to work on JAX with its stateless and static-shape requirements). If we have a design we think is good for, say, Gemma, we could then think about how to generalize it.
However, given that generation goes through compilation, which differs across torch, jax, and tensorflow, this will be a technically complex area even to prototype.
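To make the chunk-size idea above concrete, here is a minimal, backend-free sketch in plain Python. It is not KerasHub code: toy_forward, chunked_prefill, and the list-based cache are hypothetical stand-ins chosen only to show the control flow, in which the prompt is fed to the model a slice at a time while the cache grows.

```python
from typing import Callable, List, Optional, Tuple

# Toy stand-in for a key/value cache: one entry per token already processed.
Cache = List[int]


def toy_forward(chunk: List[int], cache: Cache) -> Tuple[int, Cache]:
    """Hypothetical model call (not a KerasHub function).

    Appends the chunk to the cache and returns a fake "next token" that
    depends on everything cached so far.
    """
    new_cache = cache + chunk
    next_token = sum(new_cache) % 100
    return next_token, new_cache


def chunked_prefill(
    prompt: List[int],
    chunk_size: int,
    forward: Callable[[List[int], Cache], Tuple[int, Cache]] = toy_forward,
) -> Tuple[Optional[int], Cache, int]:
    """Fill the cache by feeding the prompt in slices of `chunk_size` tokens.

    Returns the prediction after the last chunk, the filled cache, and the
    number of forward calls made.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    cache: Cache = []
    next_token: Optional[int] = None
    calls = 0
    for start in range(0, len(prompt), chunk_size):
        chunk = prompt[start:start + chunk_size]
        # Each call only handles `chunk_size` new tokens, which bounds the
        # per-call activation memory; the cache still grows to prompt length.
        next_token, cache = forward(chunk, cache)
        calls += 1
    return next_token, cache, calls
```

With a 10-token prompt, chunk_size=10 does the whole prefill in one call, while chunk_size=3 takes four calls and ends with the same cache. Note that the last chunk here is shorter than the others; under static-shape compilation it would presumably need padding to a fixed chunk length, which this sketch does not attempt.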
I think we can add an option for left padding. This could recover some of the performance lost in batch inference.
In any case, mainstream LLMs such as Qwen use left padding during training, so we can set aside older models such as GPT2 for now.
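As a rough illustration of what a left-padding option changes, here is a small plain-Python sketch. The pad_batch helper and its side argument are hypothetical and are not part of the KerasHub API; the sketch only shows where the pad tokens and the padding mask end up for each choice.

```python
from typing import List, Sequence, Tuple


def pad_batch(
    sequences: Sequence[Sequence[int]],
    max_length: int,
    pad_id: int = 0,
    side: str = "right",
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Pad every sequence to `max_length` on the chosen side.

    Returns (token_ids, padding_mask), where the mask is True for real
    tokens and False for padding.
    """
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    token_ids: List[List[int]] = []
    padding_mask: List[List[bool]] = []
    for seq in sequences:
        seq = list(seq)
        if len(seq) > max_length:
            raise ValueError("sequence is longer than max_length")
        n_pad = max_length - len(seq)
        pads = [pad_id] * n_pad
        if side == "left":
            # Real tokens are right-aligned: every prompt ends at the
            # last column of the batch.
            token_ids.append(pads + seq)
            padding_mask.append([False] * n_pad + [True] * len(seq))
        else:
            # Real tokens are left-aligned: prompts end at different columns.
            token_ids.append(seq + pads)
            padding_mask.append([True] * len(seq) + [False] * n_pad)
    return token_ids, padding_mask
```

With side="left", every prompt in the batch ends at the same column, so all rows can start generating from a shared position instead of tracking a separate start index per row.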
If you agree that we can add a left-padding option for models that use left padding during pre-training, I will submit a new PR in the near future.