[Hardware][TPU][V1] Multi-LoRA implementation for the V1 TPU backend

Open · Akshat-Tripathi opened this pull request 8 months ago • 38 comments

This PR adds a Multi-LoRA implementation that works on the TPU backend, extending the work done in https://github.com/vllm-project/vllm/pull/11100 and superseding https://github.com/vllm-project/vllm/pull/12623/. It includes a functional but unoptimised Pallas implementation of the bgmv kernel.
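For reviewers unfamiliar with bgmv, the operation the kernel computes is roughly the following (a plain PyTorch reference sketch for clarity, not the Pallas code; the argument names and shapes are illustrative):

```python
import torch

def bgmv_reference(x: torch.Tensor,
                   lora_a: torch.Tensor,
                   lora_b: torch.Tensor,
                   lora_indices: torch.Tensor) -> torch.Tensor:
    """Reference semantics for a bgmv-style LoRA apply.

    x:            [num_tokens, hidden_dim]        input activations
    lora_a:       [num_loras, hidden_dim, rank]   per-adapter A matrices
    lora_b:       [num_loras, rank, out_dim]      per-adapter B matrices
    lora_indices: [num_tokens]                    adapter id per token (-1 = none)
    Returns the LoRA delta to add onto the base layer's output.
    """
    out = x.new_zeros(x.shape[0], lora_b.shape[-1])
    for i in range(x.shape[0]):
        idx = int(lora_indices[i])
        if idx < 0:  # token has no adapter assigned
            continue
        out[i] = (x[i] @ lora_a[idx]) @ lora_b[idx]
    return out
```

The Pallas kernel vectorises this gather-then-matmul pattern on the TPU; the Python loop above is only meant to pin down the semantics.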

Akshat-Tripathi avatar Mar 04 '25 21:03 Akshat-Tripathi

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Mar 04 '25 21:03 github-actions[bot]

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 05 '25 08:03 mergify[bot]

This PR is waiting on https://github.com/vllm-project/vllm/pull/14310 to be merged

Akshat-Tripathi avatar Mar 06 '25 14:03 Akshat-Tripathi

Hi @NickLucche this PR should be ready for you to review now

Akshat-Tripathi avatar Mar 10 '25 12:03 Akshat-Tripathi

Thanks for the mega work here! Left some comments. I am mostly interested in the jax/torch_xla integration as it may be useful in the future: can we just pre-compile it into XLA and use it at runtime with no overhead?

Ah, what do you mean? I'm not too familiar with the torch_xla compilation process; is there a way I can see the compilation graphs?

Akshat-Tripathi avatar Mar 11 '25 10:03 Akshat-Tripathi

I generally use https://pytorch.org/xla/release/r2.5/debug.html#common-debugging-environment-variables-combinations but I am not sure this is what you're looking for.

NickLucche avatar Mar 15 '25 17:03 NickLucche

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 15 '25 17:03 mergify[bot]

I generally use https://pytorch.org/xla/release/r2.5/debug.html#common-debugging-environment-variables-combinations but I am not sure this is what you're looking for.

Thanks! I looked at the compilation logs with PT_XLA_DEBUG=2, and it looks like it gets compiled in with the main graph? There's an extra mark_step happening in the LoRA setup, but I don't see it show up in any performance traces.
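For anyone repeating the check, this is roughly how I looked at it (a sketch using the standard torch_xla debug utilities; the matmul is just a stand-in for the LoRA computation under test):

```python
import os
os.environ["PT_XLA_DEBUG"] = "2"  # must be set before torch_xla does any work

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(8, 16, device=device)
y = x @ x.T          # stand-in for the LoRA computation under test
xm.mark_step()       # cut the lazy graph here, forcing compile + execute

# A CompileTime counter that keeps growing across otherwise identical steps
# would indicate the LoRA path is triggering recompilation.
print(met.metrics_report())
```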

Akshat-Tripathi avatar Mar 20 '25 18:03 Akshat-Tripathi

Looks good, I'll ask for more eyes on this. Thanks for your work!

I am getting ValueError: invalid literal for int() with base 10: 'A' running python -m pytest -s tests/tpu/test_lora.py. PTAL

What are your jax/libtpu/torch_xla versions? The output seems to be somewhat different depending on the version.

Akshat-Tripathi avatar Mar 21 '25 14:03 Akshat-Tripathi

Name: torch
Version: 2.8.0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: compressed-tensors, outlines, xgrammar
---
Name: torch-xla
Version: 2.8.0+git4190fc0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: absl-py, numpy, pyyaml, requests
Required-by:
Name: jax
Version: 0.5.2.dev20250303
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by:

NickLucche avatar Mar 21 '25 15:03 NickLucche

Name: torch
Version: 2.8.0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: compressed-tensors, outlines, xgrammar
---
Name: torch-xla
Version: 2.8.0+git4190fc0
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: absl-py, numpy, pyyaml, requests
Required-by:
Name: jax
Version: 0.5.2.dev20250303
Location: /home/nick/vllm/.venv/lib/python3.11/site-packages
Requires: jaxlib, ml-dtypes, numpy, opt-einsum, scipy
Required-by:

Ok this looks right, and I can reproduce the bug.

For context, I've trained a few adapters to always produce a specific output for the question "What is 1+1?". So the first token from the ith adapter should always be i. What's happening here is that the first adapter isn't producing "1" any more; instead it produces something like "A) 1", while the other adapters produce the expected values. So the adapters are still working, but the output is slightly different from what it was before.

I'm going to fine-tune the first adapter further to try to get rid of this problem.

Akshat-Tripathi avatar Mar 24 '25 17:03 Akshat-Tripathi

This looks like it might turn into a flaky test, as we're going to update those libs quite frequently along with the attention kernels. Perhaps we should test something simpler, like asserting that the output exists (i.e. the request didn't crash).

NickLucche avatar Mar 24 '25 17:03 NickLucche

This looks like it might turn into a flaky test, as we're going to update those libs quite frequently along with the attention kernels. Perhaps we should test something simpler, like asserting that the output exists (i.e. the request didn't crash).

We could, but wouldn't that also pass if the LoRAs aren't applied, or aren't applied properly? Earlier I had a bug where the outputs with LoRA were corrupted because of a bug in the kernel, which is why I wanted to assert that the output matched.

Can we get logits out? Then I could test that the logits for token i are higher than those for the default answer.

Akshat-Tripathi avatar Mar 24 '25 17:03 Akshat-Tripathi

We could, but wouldn't that also pass if the LoRAs aren't applied, or aren't applied properly?

Wouldn't that also be the case anyway if LoRA weren't applied and the base model just replied correctly?

Can we get logits out?

Not yet on TPU.

NickLucche avatar Mar 24 '25 17:03 NickLucche

Wouldn't that also be the case anyway if LoRA weren't applied and the base model just replied 2?

Yeah that's also true.

Not yet on TPU.

Ah, ok. I can weaken the test slightly by making sure that the LoRA answer comes before the non-LoRA answer. How does that sound?
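Concretely, something like this is what I have in mind (a sketch; the helper name and expected strings are placeholders rather than the actual test code):

```python
def check_lora_output(output_text: str, lora_answer: str, base_answer: str) -> None:
    """Relaxed check: the LoRA-specific answer must appear, and if the base
    model's answer also shows up, the LoRA answer must come first."""
    assert lora_answer in output_text, "LoRA answer missing from the output"
    if base_answer in output_text:
        assert output_text.index(lora_answer) < output_text.index(base_answer)

# e.g. for the "What is 1+1?" adapters:
check_lora_output("1", lora_answer="1", base_answer="2")
```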

Akshat-Tripathi avatar Mar 24 '25 18:03 Akshat-Tripathi

@Mergifyio refresh

Akshat-Tripathi avatar Mar 24 '25 18:03 Akshat-Tripathi

refresh

✅ Pull request refreshed

mergify[bot] avatar Mar 24 '25 18:03 mergify[bot]

Thanks a lot for your hard work and patience with this PR @Akshat-Tripathi ! I think we may want to add the e2e test to the TPU ones here https://github.com/vllm-project/vllm/blob/main/.buildkite/run-tpu-v1-test.sh.

Also, given this PR introduces a big feature, perhaps we could back that up with some numbers from a benchmark run. Let me know if you need help with that.

Thanks @NickLucche! I'll add the test to the CI now, but I don't think this PR is the best one to benchmark since it's completely unoptimised. I've been working on an optimised version in parallel here, which is much faster.

Akshat-Tripathi avatar Mar 27 '25 21:03 Akshat-Tripathi

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 27 '25 21:03 mergify[bot]

@NickLucche I've put those optimisations in this draft PR: https://github.com/vllm-project/vllm/pull/15655

Akshat-Tripathi avatar Mar 27 '25 23:03 Akshat-Tripathi

There's another issue that needs confirmation: whether fully sharded LoRA and add_bias are supported. If they are not supported, please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/worker/hpu_model_runner.py#L704-L707
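For reference, the kind of guard I mean looks roughly like this (a paraphrase of the HPU runner pattern linked above, not the exact code; the config attributes are read defensively in case the names differ):

```python
from typing import Optional

from vllm.config import LoRAConfig

def assert_lora_features_supported(lora_config: Optional[LoRAConfig]) -> None:
    """Reject LoRA features this backend has not implemented yet."""
    if lora_config is None:
        return
    if getattr(lora_config, "fully_sharded_loras", False):
        raise NotImplementedError(
            "Fully sharded LoRA is not supported on this backend yet.")
    if getattr(lora_config, "bias_enabled", False):
        raise NotImplementedError(
            "LoRA bias (add_bias) is not supported on this backend yet.")
```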

jeejeelee avatar Mar 28 '25 02:03 jeejeelee

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 29 '25 04:03 mergify[bot]

There's another issue that needs confirmation: whether fully sharded LoRA and add_bias are supported. If they are not supported, please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/worker/hpu_model_runner.py#L704-L707

Both of them should be supported. How can I test whether fully sharded LoRA works? Would running Llama 3.1 70B sharded across a few TPUs with tensor parallelism do the trick?
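For concreteness, this is roughly the run I had in mind (a sketch only; the model, adapter path, and flag values are placeholders, and the real multi-device LoRA test lives in tests/lora/test_llama_tp.py):

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model and adapter path; flag values chosen only for illustration.
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",
    enable_lora=True,
    fully_sharded_loras=True,
    max_loras=4,
    tensor_parallel_size=4,
)

outputs = llm.generate(
    ["What is 1+1?"],
    SamplingParams(temperature=0.0, max_tokens=16),
    lora_request=LoRARequest("adapter1", 1, "/path/to/lora/adapter"),
)
print(outputs[0].outputs[0].text)
```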

Akshat-Tripathi avatar Mar 31 '25 19:03 Akshat-Tripathi

See: https://github.com/vllm-project/vllm/blob/main/tests/lora/test_llama_tp.py#L164

jeejeelee avatar Apr 01 '25 01:04 jeejeelee

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 03 '25 10:04 mergify[bot]

There's another issue that needs confirmation: whether fully sharded LoRA and add_bias are supported. If they are not supported, please refer to: https://github.com/vllm-project/vllm/blob/main/vllm/worker/hpu_model_runner.py#L704-L707

Hi @jeejeelee, that was a good spot. I've just added them now. The test won't pass since the expected outputs are different, but I've been seeing sensible-looking outputs on my side.

Akshat-Tripathi avatar Apr 07 '25 14:04 Akshat-Tripathi

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 09 '25 10:04 mergify[bot]

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 10 '25 00:04 mergify[bot]

@Akshat-Tripathi thanks for your awesome contribution!

Left some comments. In addition:

  1. Could you show me how to run the test locally?
  2. I don't expect this PR to solve every issue with multi-LoRA support on TPU; it's already a very large one. Could you add a TODO list to the PR description?

yaochengji avatar Apr 10 '25 22:04 yaochengji