
Add prefill step for causal LM

pass-lin opened this issue 6 months ago · 3 comments

Currently, CausalLM in keras_hub differs from typical LLM implementations in that it has no explicit prefill step. This wastes a lot of computation for models with long prompts, so I think keras_hub should consider introducing a prefill step for CausalLM. This would be a major update, though, and would require different modifications across a large number of existing models. Could the Keras team provide a roadmap for this work, so that the community can contribute more easily?
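
To illustrate the difference, here is a toy sketch of decoding with and without a prefill step (illustrative only; `toy_forward` is a stand-in, not keras_hub code):

```python
import numpy as np

# Toy stand-in for a causal LM forward pass (illustrative only, not keras_hub code).
# Returns last-position logits and the "updated cache" (here, just the tokens seen so far).
def toy_forward(new_tokens, cache=None):
    seen = (cache or []) + list(new_tokens)
    logits = np.random.default_rng(len(seen)).standard_normal(32)  # toy vocab of 32
    return logits, seen

def generate_token_by_token(prompt, max_new_tokens):
    # No cache: every step re-runs the model over the full, growing sequence,
    # so the prompt is recomputed on every generated token.
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        logits, _ = toy_forward(tokens)
        tokens.append(int(logits.argmax()))
    return tokens

def generate_with_prefill(prompt, max_new_tokens):
    # Prefill: one pass over the whole prompt to build the cache,
    # then decode one token at a time against that cache.
    logits, cache = toy_forward(prompt)
    tokens = list(prompt) + [int(logits.argmax())]
    for _ in range(max_new_tokens - 1):
        logits, cache = toy_forward([tokens[-1]], cache=cache)
        tokens.append(int(logits.argmax()))
    return tokens
```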

pass-lin · Jun 14 '25

We do have a prefill step; it's just not explicitly called out as such. See, for example, https://github.com/keras-team/keras-hub/blob/25c9062c5f9a25fade16094cd21d2545125168ec/keras_hub/src/models/gemma/gemma_causal_lm.py#L250. We prefill for the entire batch and sequence in one go: fast, but high memory.

We are missing a configurable prefill chunk size, which would be quite useful for longer generations (turn the prefill chunk size down, or turn prefill off entirely, to save VRAM). It could also be useful to expose prefill and decode steps that are individually callable, to allow for more custom generation flows.
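
As a rough sketch of what a configurable chunk size might look like (nothing here is existing keras_hub API; `call_with_cache` and `prefill_chunk_size` are hypothetical names):

```python
def chunked_prefill(call_with_cache, prompt_ids, cache, prefill_chunk_size=512):
    # Process the prompt `prefill_chunk_size` tokens at a time, carrying the KV
    # cache forward between chunks, instead of one big pass over the whole prompt.
    # `call_with_cache(chunk, cache, start_index)` is an assumed interface here,
    # loosely modeled on the cache-based call in the causal LM generate step.
    index = 0
    hidden_states = None
    while index < len(prompt_ids):
        chunk = prompt_ids[index : index + prefill_chunk_size]
        hidden_states, cache = call_with_cache(chunk, cache, start_index=index)
        index += len(chunk)
    return hidden_states, cache
```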

For both of these (or other generation improvements) it would probably make sense to prototype something for an individual model, but we would need an implementation that is relatively clean and works on all backends (notably, it would need to work on JAX, with its stateless and static-shape requirements). If we have a design we think is good for, say, Gemma, we could then think about how to generalize it.
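
As a toy illustration of that JAX constraint (not keras_hub code): the compiled step has to be a pure function over fixed-shape arrays, so a cache is preallocated at maximum length and written at an index rather than grown step by step:

```python
import jax
import jax.numpy as jnp

@jax.jit
def decode_step(cache, index, new_value):
    # The cache keeps its static shape; we write one slot instead of appending.
    cache = cache.at[index].set(new_value)
    return cache, index + 1

cache = jnp.zeros((128, 8))  # preallocated at max length, static shape
cache, index = decode_step(cache, 0, jnp.ones((8,)))
```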

However, given that generation hits compilation paths that vary across torch, jax, and tensorflow, this will be a technically complex area even to prototype.

mattdangerw · Jun 26 '25

I think we can add an option for left padding. That would reduce some of the performance loss in batched inference.

In any case, mainstream LLMs such as Qwen already use left padding in the training phase, so we can ignore older models such as GPT-2 for now.
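
For example, a rough sketch of what left padding a batch of prompts looks like (illustrative helper only, not keras_hub API):

```python
import numpy as np

# With left padding, the last real token of every sequence lands at the same
# position, so a single cache index works for the whole batch during decoding.
def left_pad(batch_token_ids, pad_id=0):
    max_len = max(len(ids) for ids in batch_token_ids)
    token_ids = np.full((len(batch_token_ids), max_len), pad_id, dtype=np.int32)
    padding_mask = np.zeros((len(batch_token_ids), max_len), dtype=bool)
    for row, ids in enumerate(batch_token_ids):
        token_ids[row, max_len - len(ids):] = ids
        padding_mask[row, max_len - len(ids):] = True
    return token_ids, padding_mask

# left_pad([[5, 6, 7], [8, 9]]) -> token_ids    [[5, 6, 7], [0, 8, 9]]
#                                  padding_mask [[T, T, T], [F, T, T]]
```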

pass-lin · Jun 26 '25

If you agree that we can add a left-padding option for models that use left padding in the pre-training phase, I will submit a new PR in the near future.

pass-lin · Jun 26 '25