
Polishing T5 inference

Open monatis opened this issue 1 year ago • 1 comment

Hi, as reported in several issues (#1271, #1413), T5 still lacks some workflows. In particular, I'm trying to optimize T5 conditional generation. I started by porting code from BartSeq2SeqLM, but one thing that immediately caught my attention is that T5 uses its own MHA implementation, which lacks the KV cache functionality implemented in CachedMultiHeadAttention. This could be addressed in two ways:

  1. Add rel_attn_bias support to CachedMultiHeadAttention, or
  2. Add KV cache support to T5MultiHeadAttention.

I'm also planning to upstream what I come up with. The question is: which one would you prefer, and which do you think would be easier to hack? I lean toward option 2, but is there anything I'm missing?
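For context, here is a minimal NumPy sketch of the interaction between the two pieces: a (key, value) cache updated one token at a time, plus T5's relative attention bias, which has to be sliced to the current query position during incremental decoding. The cache layout (axis 1 holding key and value) loosely mirrors the convention used by CachedMultiHeadAttention, but all names and shapes here are illustrative assumptions, not keras-nlp API; a real implementation would work on per-head projections inside the attention layer.

```python
import numpy as np

def cached_attention_step(q_new, k_new, v_new, cache, rel_bias, index):
    """One decode step with a KV cache and a relative attention bias.

    cache:    [batch, 2, max_len, dim]  (axis 1 holds key, value -- assumed layout)
    q_new, k_new, v_new: [batch, 1, dim], projections for the newest token only
    rel_bias: [max_len, max_len], precomputed relative position bias
    index:    position of the current token in the sequence
    """
    # Write the new key/value into the cache at the current position.
    cache[:, 0, index] = k_new[:, 0]
    cache[:, 1, index] = v_new[:, 0]
    # Attend over everything cached so far (causal by construction).
    k = cache[:, 0, :index + 1]  # [batch, index+1, dim]
    v = cache[:, 1, :index + 1]
    # T5 omits the 1/sqrt(d) scaling; the bias row for the current query
    # position is sliced to the visible key positions.
    scores = q_new @ k.transpose(0, 2, 1) + rel_bias[index, :index + 1]
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, cache
```

Stepping this function over a sequence reproduces full causal attention while recomputing only the newest token's projections, which is the whole point of the cache; the bias slicing is the extra wrinkle T5 adds on top of what CachedMultiHeadAttention already handles.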

monatis avatar Jul 18 '24 12:07 monatis