(WIP) Support targeting the embedding layer for LoRA
What does this PR do?
- Re-organize the code in `BatchLoraWeights.load`. This function was a bit hard to follow because it had multiple list comprehensions with almost the same looping logic, so I merged them into two loops for clarity (a simplified sketch follows this list). @tgaddair Can you confirm this looks good? I can revert to the original code if this change could cause problems.
- (WIP) Support the embedding layer as a target module. This is mostly done except for multi-GPU inference.
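
A minimal sketch of the kind of consolidation described in the first bullet, assuming a single pass can collect everything the comprehensions were gathering separately. `AdapterWeights`, its fields, and the returned structures are placeholders for the pattern, not the actual LoRAX types:

```python
from dataclasses import dataclass
from typing import Dict

import torch


@dataclass
class AdapterWeights:
    # Placeholder container; the real BatchLoraWeights.load works on LoRAX's own types.
    lora_a: torch.Tensor
    lora_b: torch.Tensor
    rank: int


def load(adapter_weights: Dict[int, AdapterWeights]):
    # Before: several list/dict comprehensions, each re-walking adapter_weights
    # with nearly the same filter. After: one loop collects everything at once.
    lora_a, lora_b, ranks = {}, {}, []
    for idx, w in adapter_weights.items():
        lora_a[idx] = w.lora_a
        lora_b[idx] = w.lora_b
        ranks.append(w.rank)
    return lora_a, lora_b, ranks
```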
@tgaddair I am pushing a partially done commit that supports embedding-layer LoRAs.
- [x] Similar to the HF implementation, `lora_A` is used for the embedding lookup while `lora_B` is multiplied (see the sketch after this list)
- [x] Prevents the `lora_A` transpose when in BGMV mode
- [x] Contains two implementations to replace the two kernels: SGMV and BGMV
- [ ] Both are implemented with for loops. How can we optimize them?
- [ ] Cannot handle multi-GPU. I will need some help understanding sharding in LoRAX as I found it confusing. :(
- [ ] Tested crudely by comparing with generation from HF, but need to add a proper test case.
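
A minimal sketch of the embedding-LoRA forward path the checklist describes, loosely following the HF PEFT convention: `lora_A` acts as a second embedding table that is looked up by the input ids (rather than matmul'ed), and the result is projected through `lora_B`. Tensor shapes, names, and the single-adapter form are assumptions; the PR's SGMV/BGMV replacements batch this across adapters instead.

```python
import torch
import torch.nn.functional as F


def embedding_lora_forward(
    input_ids: torch.Tensor,    # (batch, seq), integer token ids
    base_weight: torch.Tensor,  # (vocab_size, hidden), frozen embedding table
    lora_a: torch.Tensor,       # (vocab_size, r), LoRA lookup table
    lora_b: torch.Tensor,       # (r, hidden), LoRA projection
    scaling: float,
) -> torch.Tensor:
    # Base embedding lookup.
    base = F.embedding(input_ids, base_weight)
    # LoRA path: lora_A is looked up like an embedding table, not multiplied,
    # which is also why its layout is kept untransposed.
    after_a = F.embedding(input_ids, lora_a)        # (batch, seq, r)
    return base + (after_a @ lora_b) * scaling      # (batch, seq, hidden)
```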