[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend
This PR adds a Multi-LoRA implementation that works on the TPU backend, extending the work done in https://github.com/vllm-project/vllm/pull/11100, and supersedes https://github.com/vllm-project/vllm/pull/12623/. It includes a functional but unoptimised Pallas implementation of the bgmv kernel.
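For readers unfamiliar with the op, here is a minimal plain-JAX sketch of the semantics that bgmv (batched grouped matrix-vector multiply) implements for multi-LoRA, i.e. applying the adapter matrix selected per token by an index tensor. This is only an illustration under assumed tensor names and shapes, not the Pallas kernel added in this PR.

```python
# Illustration only: the semantics bgmv implements, not this PR's Pallas kernel.
import jax.numpy as jnp

def bgmv_reference(x, lora_weights, lora_indices):
    # x:            [num_tokens, in_dim]         per-token activations
    # lora_weights: [num_loras, in_dim, out_dim] stacked adapter matrices
    # lora_indices: [num_tokens] int32           adapter slot serving each token
    gathered = lora_weights[lora_indices]          # [num_tokens, in_dim, out_dim]
    return jnp.einsum("td,tdo->to", x, gathered)   # [num_tokens, out_dim]
```

A real kernel would avoid materialising the gathered [num_tokens, in_dim, out_dim] tensor; the sketch above is just the reference behaviour the kernel has to match.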
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
This PR is waiting on https://github.com/vllm-project/vllm/pull/14310 to be merged
Hi @NickLucche this PR should be ready for you to review now
Thanks for the mega work here! Left some comments. I am mostly interested in the jax/torch_xla integration, as this may be useful in the future: can we just pre-compile it into XLA and use it at runtime with no overhead?
Ah, what do you mean? I'm not too familiar with the torch_xla compilation process; is there a way I can see the compilation graphs?
I generally use https://pytorch.org/xla/release/r2.5/debug.html#common-debugging-environment-variables-combinations but I am not sure this is what you're looking for.
Thanks! I looked at the compilation logs with PT_XLA_DEBUG=2 and it looks like it gets compiled in with the main graph? There's an extra mark_step happening in the lora setup, but I don't see it show up in any performance traces.
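For anyone reproducing this, a minimal sketch of that kind of check, assuming torch_xla's documented debug environment variables and metrics API; the toy computation stands in for the LoRA setup path.

```python
# Sketch: inspect torch_xla compilations/executions. Env vars must be set
# before torch_xla is imported; see the debugging doc linked above.
import os
os.environ["PT_XLA_DEBUG"] = "2"    # report what triggers graph executions
os.environ["XLA_IR_DEBUG"] = "1"    # annotate IR with Python frame info
os.environ["XLA_HLO_DEBUG"] = "1"   # carry those annotations into HLO

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(8, 16, device=device)   # stand-in for the LoRA setup work
y = (x @ x.T).sum()
xm.mark_step()                           # cut the graph here

# CompileTime/ExecuteTime counters show whether extra compilations occurred.
print(met.metrics_report())
```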
Looks good, I'll ask for more eyes on this. Thanks for your work!
I am getting
ValueError: invalid literal for int() with base 10: 'A' running python -m pytest -s tests/tpu/test_lora.py. PTAL
What are your jax/libtpu/torch_xla versions? The output seems to be somewhat different depending on the version
Name: torch
Version: 2.8.0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: compressed-tensors, outlines, xgrammar
---
Name: torch-xla
Version: 2.8.0+git4190fc0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: absl-py, numpy, pyyaml, requests
Required-by:
Name: jax
Version: 0.5.2.dev20250303
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by:
Ok this looks right, and I can reproduce the bug.
For context, I've trained a few adapters to always produce a specific output for the question "What is 1+1?". So the first token from the ith adapter should always be i. What's happening here is that the first adapter isn't producing "1" anymore; instead it's something like "A) 1", but the other adapters are producing the expected values. So the adapters are still working, but the output is slightly different from what it was before.
I'm going to further finetune the first adapter to try and get rid of this problem.
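For reference, this is roughly what that check looks like with vLLM's offline API; the base model name, adapter paths, and engine arguments below are placeholders rather than the test's actual configuration.

```python
# Rough sketch of the adapter check described above (model/paths are placeholders).
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder base model
          enable_lora=True,
          max_loras=4,
          max_lora_rank=8)

params = SamplingParams(temperature=0.0, max_tokens=4)

for i, path in enumerate(["/path/to/adapter_1", "/path/to/adapter_2"], start=1):
    out = llm.generate("What is 1+1?",
                       params,
                       lora_request=LoRARequest(f"adapter_{i}", i, path))
    # Each adapter was tuned so its first output token is its own index.
    assert out[0].outputs[0].text.strip().startswith(str(i))
```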
This looks like it might turn into a flaky test, as we're going to update those libs quite frequently along with the attn kernels. Perhaps we should test something simpler, like asserting that the output exists (i.e. it didn't crash running the request).
We could, but wouldn't that also pass if the LoRAs aren't applied, or aren't applied properly? Earlier I had a bug where the outputs with LoRA were corrupted because of a bug in the kernel, so that's why I wanted to assert that the output matched.
Can we get logits out? If so, I could test that the logits for token i are higher than the default value.
We could, but wouldn't that also pass if the loras aren't applied or aren't applied properly?
Wouldn't that also be the case anyway if lora weren't applied and the base model just replied correctly?
Can we get logits out?
Not yet on TPU.
Wouldn't that also be the case anyway if lora weren't applied and the base model just replied 2?
Yeah that's also true.
Not yet on TPU.
Ah ok. I can weaken the test slightly by making sure that the LoRA answer comes before the non-LoRA answer. How does that sound?
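A minimal sketch of that weakened assertion (helper name and arguments are hypothetical):

```python
def lora_answer_comes_first(generated_text: str, lora_answer: str, base_answer: str) -> bool:
    # Weaker than exact matching: the LoRA-specific answer must appear, and if
    # the base model's answer also appears, the LoRA answer must come first.
    lora_pos = generated_text.find(lora_answer)
    base_pos = generated_text.find(base_answer)
    return lora_pos != -1 and (base_pos == -1 or lora_pos < base_pos)
```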
@Mergifyio refresh
✅ Pull request refreshed
Thanks a lot for your hard work and patience with this PR @Akshat-Tripathi ! I think we may want to add the e2e test to the TPU ones here https://github.com/vllm-project/vllm/blob/main/.buildkite/run-tpu-v1-test.sh.
Also, given this PR introduces a big feature, perhaps we could back that up with some numbers from a benchmark run. Let me know if you need help with that.
Thanks @NickLucche! I'll add the test to the CI now, but I don't think this PR is the best to benchmark since it's completely unoptimised. I've been working on an optimised version in parallel here which is much faster.
@NickLucche I've put those optimisations in this draft PR: https://github.com/vllm-project/vllm/pull/15655
There's another issue that needs confirmation: whether fully sharded LoRA and add_bias are supported. If they are not supported, please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/worker/hpu_model_runner.py#L704-L707
Both of them should be supported. How can I test whether fully sharded LoRA works? Would running Llama 3.1 70B tensor-parallel across a few TPUs do the trick?
See: https://github.com/vllm-project/vllm/blob/main/tests/lora/test_llama_tp.py#L164
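A rough sketch of such a run, in the spirit of test_llama_tp.py; it assumes vLLM's fully_sharded_loras and tensor_parallel_size engine arguments, and the model/adapter paths are placeholders.

```python
# Sketch: fully sharded LoRA under tensor parallelism (placeholder model/paths).
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-3.1-70B-Instruct",  # placeholder base model
          enable_lora=True,
          fully_sharded_loras=True,   # exercise the fully sharded LoRA layers
          max_loras=2,
          max_lora_rank=8,
          tensor_parallel_size=4)     # shard across several TPU chips

out = llm.generate("What is 1+1?",
                   SamplingParams(temperature=0.0, max_tokens=8),
                   lora_request=LoRARequest("adapter_1", 1, "/path/to/adapter_1"))
print(out[0].outputs[0].text)
```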
There's another issue that needs confirmation: whether fully sharded LoRA and add_bias are supported. If they are not supported, please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/worker/hpu_model_runner.py#L704-L707
Hi @jeejeelee that was a good spot. I've just added them in now. The test won't pass since the expected outputs are different, but I've been seeing sensible looking outputs on my side.
@Akshat-Tripathi thanks for your awesome contribution!
Left some comments. In addition,
- Could you tell me how to run the test locally?
- I don't expect this PR to solve every issue with multi-LoRA support on TPU; it's already a very large one. Could you add a TODO list to the description of the PR?