maxtext
JAX implementation of creating unified_lora_params and computing xAB for a batch of decode requests.
Description
- Created a cache large enough to hold the LoRA adapter params for every decode slot. Each cached tensor has an extra leading dimension (axis 0) whose size equals the number of slots. JetStream can insert a request's LoRA adapter params (A and B) into this cache.
- Because each slot holds its own parameters, requests that use different adapters can be decoded in the same batch.
- The Attention module reads this cache to compute xAB for the query, key, and value projections of every decode request currently being processed.
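A minimal JAX sketch of a slot-indexed adapter cache like the one described above. It assumes a single projection, a hypothetical `MAX_RANK` to which lower-rank adapters are zero-padded, and hypothetical helpers (`init_lora_cache`, `pad_to_rank`, `insert_adapter`). It is not the PR's `unified_lora_params` code and not a JetStream API.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration only; not taken from the PR.
NUM_SLOTS, EMBED_DIM, NUM_HEADS, HEAD_DIM, MAX_RANK = 4, 16, 2, 8, 4


def init_lora_cache(num_slots, embed_dim, num_heads, head_dim, max_rank,
                    dtype=jnp.float32):
    """Zero-filled cache for one projection; axis 0 indexes the decode slot."""
    return {
        "lora_a": jnp.zeros((num_slots, embed_dim, max_rank), dtype),
        "lora_b": jnp.zeros((num_slots, max_rank, num_heads, head_dim), dtype),
    }


def pad_to_rank(lora_a, lora_b, max_rank):
    """Zero-pad a rank-r adapter (r <= max_rank).

    The padded columns of A and rows of B are zero, so they add nothing to xAB.
    """
    r = lora_a.shape[-1]
    a = jnp.pad(lora_a, ((0, 0), (0, max_rank - r)))
    b = jnp.pad(lora_b, ((0, max_rank - r), (0, 0), (0, 0)))
    return a, b


@jax.jit
def insert_adapter(cache, slot, lora_a, lora_b):
    """Write one adapter's A and B into `slot`.

    JAX arrays are immutable, so this returns a new cache instead of mutating.
    """
    return {
        "lora_a": cache["lora_a"].at[slot].set(lora_a),
        "lora_b": cache["lora_b"].at[slot].set(lora_b),
    }


# Usage: load a rank-2 adapter into slot 1; the other slots stay zero.
cache = init_lora_cache(NUM_SLOTS, EMBED_DIM, NUM_HEADS, HEAD_DIM, MAX_RANK)
a, b = pad_to_rank(jnp.ones((EMBED_DIM, 2)),
                   jnp.ones((2, NUM_HEADS, HEAD_DIM)), MAX_RANK)
cache = insert_adapter(cache, 1, a, b)
```

Padding to a fixed maximum rank keeps every slot the same shape, so the whole cache stays one array per tensor and the update compiles once. A serving loop would likely also donate the old cache buffer (`jax.jit(..., donate_argnums=...)`) to avoid a copy; whether the PR does this is not stated.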
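The per-slot xAB computation can be sketched with two batched `jnp.einsum` calls. This sketch assumes that the decode batch index equals the slot index, that slots without an adapter hold zero A/B (so their delta is zero), and a `[embed, heads, head_dim]` projection layout. `batched_lora_delta`, `qkv_with_lora`, and `scale` are hypothetical names, not MaxText's Attention module API. `scale` stands in for the common LoRA `alpha / rank` factor, which the PR does not mention.

```python
import jax
import jax.numpy as jnp

PROJECTIONS = ("query", "key", "value")


def batched_lora_delta(x, lora_a, lora_b, scale=1.0):
    """Per-slot LoRA delta: batch row s is multiplied by A[s], then by B[s].

    x:      [slots, seq, embed]       (seq == 1 for a decode step)
    lora_a: [slots, embed, rank]
    lora_b: [slots, rank, heads, head_dim]
    returns [slots, seq, heads, head_dim]
    """
    xa = jnp.einsum("bse,ber->bsr", x, lora_a)               # down to rank r per slot
    return scale * jnp.einsum("bsr,brnd->bsnd", xa, lora_b)  # up to heads x head_dim


def qkv_with_lora(x, base_w, lora_cache, scale=1.0):
    """Shared base q/k/v projections plus each slot's own LoRA delta.

    base_w[name]:     [embed, heads, head_dim], shared by every slot
    lora_cache[name]: (A, B) with a leading slot axis
    """
    out = {}
    for name in PROJECTIONS:
        base = jnp.einsum("bse,end->bsnd", x, base_w[name])
        a, b = lora_cache[name]
        out[name] = base + batched_lora_delta(x, a, b, scale)
    return out


# Usage: 3 slots, one decode token each; slot 1 has no adapter loaded.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (3, 1, 8))
a = jax.random.normal(k2, (3, 8, 2)).at[1].set(0.0)
b = jax.random.normal(k3, (3, 2, 2, 4))
delta = batched_lora_delta(x, a, b)
```

Contracting over the slot axis in one einsum computes every slot's delta in a single pass, so a batch that mixes adapters costs the same as one that uses a single adapter. The trade-off is that every slot pays for the cache's full rank even when its own adapter is smaller.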
Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests and provided workload links above if applicable.