

(WIP) Fix for the LM_HEAD issue

Root Cause. The error is caused by incorrect segments being passed to the lora_b_sgmv kernel during the prefill stage. During prefill we do not forward all of the prompt tokens through lm_head and its associated adapters: only the last token of each sequence is needed to generate the next token, so skipping the rest saves compute and memory. However, dropping those tokens changes the net batch size (batch size times number of tokens) that the LoRA kernel sees, while the segment start and end tensors, which the kernel uses to slice the batch into segments, are left unchanged. Keeping them unchanged is therefore incorrect. Additionally, the resulting error is not reported correctly, since the kernel has a catch-all condition that emits a generic "kernel not found for this dtype" message. My guess is that the actual failure is an out-of-bounds memory access, but confirming this would require changing the kernel, which is out of scope for this investigation.
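To make the mismatch concrete, here is a minimal sketch (the tensor names are illustrative, not taken from the lorax code): with two prompts of 3 and 5 tokens, the prefill batch has 8 rows and token-space segments [0, 3) and [3, 8); after keeping only the last token of each prompt the batch has 2 rows, yet the old segment boundaries still reach up to row 8.

```python
import torch

# Illustrative shapes only; names are not from the lorax codebase.
prompt_lens = [3, 5]
hidden_states = torch.randn(sum(prompt_lens), 16)     # prefill batch: (8, 16)

# Token-space segment boundaries, as currently passed to lora_b_sgmv.
segment_starts = torch.tensor([0, 3])
segment_ends = torch.tensor([3, 8])

# For lm_head only the last token of each prompt is kept.
last_token_idx = torch.tensor([2, 7])
lm_head_input = hidden_states[last_token_idx]          # (2, 16)

# The kernel now sees 2 rows but is still told to read rows up to index 8,
# which matches the suspected out-of-bounds access described above.
print(lm_head_input.shape, segment_ends.max().item())  # torch.Size([2, 16]) 8
```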

Description of the Fix. The fix simply iterates over the lm_head adapters and adjusts the segment start and end tensors so that they point to the correct segments in the reduced (last-token-only) batch. For now, this is done only during the prefill stage and only when a LoRA adapter for lm_head is present.
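As a rough sketch of the adjustment (not the actual patch; the function name, arguments, and layout assumptions are mine), the token-space boundaries can be remapped into the last-token-only batch by counting how many sequences end at or before each boundary:

```python
import torch

def adjust_segments_for_last_tokens(
    segment_starts: torch.Tensor,  # token-space segment starts, one per adapter segment
    segment_ends: torch.Tensor,    # token-space segment ends
    prompt_lens: torch.Tensor,     # prompt length of each sequence, in batch order
):
    """Remap token-space segments to the batch that only contains last tokens.

    Assumes sequences are laid out contiguously and that every segment covers
    a whole number of sequences, as during prefill.
    """
    cum_tokens = torch.cumsum(prompt_lens, dim=0)
    # A boundary at token t maps to the number of sequences fully contained in [0, t).
    new_starts = torch.searchsorted(cum_tokens, segment_starts, right=True)
    new_ends = torch.searchsorted(cum_tokens, segment_ends, right=True)
    return new_starts, new_ends

# Continuing the example above: segments [0, 3] / [3, 8] become [0, 1] / [1, 2],
# matching the 2-row batch that lm_head and its LoRA adapter actually see.
starts, ends = adjust_segments_for_last_tokens(
    torch.tensor([0, 3]), torch.tensor([3, 8]), torch.tensor([3, 5])
)
```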

Tests. TBD

Fixes #163

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ ] Was this discussed/approved via a GitHub issue or the Discord/Slack channel? Please add a link to it if that's the case.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

ajtejankar · May 18, 2024 02:05