
(WIP) Support targeting the embedding layer for LoRA

Open ajtejankar opened this issue 1 year ago • 1 comment

What does this PR do?

  1. Re-organize the code in BatchLoraWeights.load. This function was a bit hard to understand because it contained multiple list comprehensions with almost the same looping logic, so I merged all of them into two loops for improved clarity (a rough sketch of the pattern follows this list). @tgaddair Can you confirm whether this looks good? I can revert to the original code if this change could cause problems.
  2. (WIP) Support the embedding layer as a target module. This is mostly done except for multi-GPU inference.
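
For reference, here is a minimal sketch of the kind of refactor described in point 1. The field names are illustrative placeholders, not the actual BatchLoraWeights attributes:

```python
# Hypothetical sketch: collapsing several list comprehensions that each
# iterate over the same adapter collection into a single loop.
def load(adapter_weights: dict):
    # Before: several comprehensions, each re-iterating adapter_weights
    # lora_a = {idx: w.lora_a for idx, w in adapter_weights.items()}
    # lora_b = {idx: w.lora_b for idx, w in adapter_weights.items()}
    # ranks  = [w.rank for w in adapter_weights.values()]

    # After: one pass that collects everything at once
    lora_a, lora_b, ranks = {}, {}, []
    for idx, weights in adapter_weights.items():
        lora_a[idx] = weights.lora_a
        lora_b[idx] = weights.lora_b
        ranks.append(weights.rank)
    return lora_a, lora_b, ranks
```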

ajtejankar · Jun 06 '24 21:06

@tgaddair I am pushing a partially done commit that supports embedding-layer LoRAs.

  • [x] Similar to the HF implementation, lora_A is used for the embedding lookup while lora_B is multiplied into the result (see the first sketch after this list)
  • [x] Skips the lora_A transpose when in BGMV mode
  • [x] Contains two implementations to replace the two kernels: SGMV and BGMV
  • [ ] Both are currently implemented with for loops (a rough sketch of the loop-based approach also follows this list). How can we optimize them?
  • [ ] Cannot handle multi-GPU inference yet. I will need some help understanding sharding in LoRAX as I found it confusing. :(
  • [ ] Tested crudely by comparing against generation from HF, but a proper test case still needs to be added.
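
For context, a minimal sketch of the embedding-layer LoRA forward pass described in the first item, mirroring the HF/PEFT convention where lora_A acts as an embedding table and lora_B is applied as a matrix multiply. The tensor shapes and the scaling factor here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def lora_embedding_forward(
    input_ids: torch.LongTensor,       # [batch, seq_len]
    base_embedding: torch.nn.Embedding,
    lora_a: torch.Tensor,              # [vocab_size, rank], used as an embedding table
    lora_b: torch.Tensor,              # [rank, hidden_size]
    scaling: float = 1.0,              # assumed lora_alpha / rank
) -> torch.Tensor:
    # Base lookup from the frozen embedding weights
    result = base_embedding(input_ids)               # [batch, seq_len, hidden]
    # lora_A is used for the lookup; lora_B is multiplied into the result
    after_a = F.embedding(input_ids, lora_a)         # [batch, seq_len, rank]
    return result + (after_a @ lora_b) * scaling     # [batch, seq_len, hidden]
```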

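The loop-based kernel replacement mentioned above might look roughly like the following. The per-token adapter bookkeeping (adapter_indices, per-adapter weight lists) is an assumption for illustration, not the actual LoRAX data layout:

```python
import torch
import torch.nn.functional as F

def embedding_lora_loop(
    input_ids: torch.LongTensor,             # [num_tokens]
    result: torch.Tensor,                    # [num_tokens, hidden], base embedding output
    adapter_indices: torch.LongTensor,       # [num_tokens], adapter id per token (assumed)
    lora_a_per_adapter: list[torch.Tensor],  # each [vocab_size, rank_i]
    lora_b_per_adapter: list[torch.Tensor],  # each [rank_i, hidden]
) -> torch.Tensor:
    # Naive for-loop stand-in for a fused SGMV/BGMV kernel: process each
    # adapter's tokens separately. Correct but not optimized.
    for adapter_id, (lora_a, lora_b) in enumerate(
        zip(lora_a_per_adapter, lora_b_per_adapter)
    ):
        mask = adapter_indices == adapter_id
        if not mask.any():
            continue
        after_a = F.embedding(input_ids[mask], lora_a)  # [n_i, rank_i]
        result[mask] += after_a @ lora_b                # [n_i, hidden]
    return result
```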
ajtejankar · Jun 08 '24 01:06