[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend
Summary
This PR optimises the Multi-LoRA implementation from https://github.com/vllm-project/vllm/pull/14238/ and should be merged in after it.
This includes several kernel optimisations:
- Block size tuning 2bb886899f540463a43c47ec81ae77e3d1dd105c d7338f8085cb65d3519ba1b8d0f0964251ace31f
- Faster mask creation 2aacb34030c282c4a66eee7f133ce0a893c64351
- Allowing for some blocks to be skipped 6ee0b57858967e35b84a442121b89a0dff6ef333
- Adding LoRA Laning eb804a0a898c3386772ae1dd832b38e57f5e6748
- Splitting the Pallas kernel into shrink/expand variants (see the sketch after these lists) de6746ac3efb478890bf73735bd0c1beaf589646
- Removing masking when only 1 LoRA adapter is used aad109b4cfa8ada4c4271a5c1b308816dcccf1d9
And a few general ones:
- Pre-transposing the LoRA adapters used in the expand op a82f3fe038c77d43d1ec7713fccf1d7fc9412ec0
- Reducing recompilations 5638e7da75905acebfa3dac3844516c8cd068dc5
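To illustrate the shrink/expand split listed above, here is a minimal reference sketch in plain PyTorch (hypothetical tensor names and shapes, not the actual Pallas kernels, which operate block-wise on the TPU):

```python
import torch

def multi_lora_shrink_expand(x, lora_a, lora_b, token_lora_idx, scaling=1.0):
    """Reference multi-LoRA matmul split into shrink and expand stages.

    x:              [num_tokens, hidden]        input activations
    lora_a:         [num_loras, rank, hidden]   "shrink" weights
    lora_b:         [num_loras, out_dim, rank]  "expand" weights
    token_lora_idx: [num_tokens]                adapter index per token
    """
    # Gather each token's adapter weights (a real kernel does this per block).
    a = lora_a[token_lora_idx]                       # [num_tokens, rank, hidden]
    b = lora_b[token_lora_idx]                       # [num_tokens, out_dim, rank]

    # Shrink: project activations down to the LoRA rank.
    shrunk = torch.einsum("th,trh->tr", x, a)        # [num_tokens, rank]
    # Expand: project back up to the output dimension.
    return scaling * torch.einsum("tr,tor->to", shrunk, b)  # [num_tokens, out_dim]
```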
Things left/RFC
- There are still a few recompilations at the start of a run that I need to track down
- LogitsProcessorWithLoRA introduces a long (~1.5 second) stall when it's enabled, but not much activity seems to happen on the CPU or TPU during this time. I've disabled this for now.
- It seems LogitsProcessorWithLoRA is always created even if there's no LoRA adapter that needs it, is there a reason for this?
- I have microbenchmarks for the kernels, but I'm not sure what the right place to put them is.
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.
🚀
It seems LogitsProcessorWithLoRA is always created even if there's no LoRA adapter that needs it, is there a reason for this?
The LoRA layers in vLLM are created in advance. Which layers get LoRA layers is determined by supported_lora_modules; see: https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py#L470
Yep, that makes sense, but here it seems like LogitsProcessorWithLoRA will always be created; its creation function is different from the other one here.
What model are you testing on? If it's Llama, it's because Llama's supported_lora_modules contains this module.
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
@NickLucche I'm almost done with this PR, there's just a problem with performance and the sampler. My smoke tests show about a 10x performance dip when I enable the sampler.
Profiles are showing that there's a stall of about 80-100ms between model execution and sampling. Have you seen something like this before?
Thanks for your work!
a 10x performance dip when I enable the sampler
What do you mean, the sampler is always "enabled" as of now.
there's a stall of about 80-100ms between model execution and sampling
Likely something being compiled, can you confirm this?
I can help look into the recompilations, but we have to start from the first PR. As long as it's not introducing overhead for the regular case, we can likely have it merged even if performance is suboptimal.
Ah I had it commented out earlier when I was testing. I don't see any recompilations causing the stall when I run with PT_XLA_DEBUG=2 or in the profile. Actually almost nothing seems to run, except for a small memory allocation.
I've put the profile below. The first graph executes the model, while the second graph is the sampler's execution. Without LoRA the gap between them is ~800us long.
Commenting out this code seems to fix the problem. But I can't seem to find any reason for the slowdown.
Turns out the memory allocations needed for padding don't show up properly in traces. I've fixed the issue now
I've got some performance numbers using the MLPerf Llama2-70B inference benchmark. I retokenised the dataset for Llama3.1.
| Model | Parameters | Without LoRA (tok/s) | With LoRA (tok/s) |
|---|---|---|---|
| Llama3.1 | 8B | 1621.62 | 1426.01 |
| Llama3.1 | 70B | 432.964 | 326.709 |
Benchmarking LoRA against baseline (no LoRA) throughput
We use NVIDIA's GenAI-Perf tool to force fixed-length inputs and outputs to produce "heatmap" plots as below. On TPU-v6e and H100 instances, we vary the inputs from 128 to 8k. On L4 instances, we vary the inputs from 128 to 2k.
We calculate the LoRA slowdown as ((LoRA throughput / baseline throughput) - 1) * 100%.
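For example (a quick illustrative calculation, not part of the benchmark tooling), applying this formula to the Llama3.1-8B MLPerf numbers in the table above:

```python
def lora_slowdown(lora_tps: float, baseline_tps: float) -> float:
    """LoRA slowdown as a percentage; negative means LoRA is slower."""
    return (lora_tps / baseline_tps - 1) * 100

# Llama3.1-8B throughput from the MLPerf table above.
print(f"{lora_slowdown(1426.01, 1621.62):.1f}%")  # ~ -12.1%
```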
Llama3.1-8B
1x TPU-v6e
The LoRA slowdown varies from -8.4% to -23.9%.
1x GPU-L4
The LoRA slowdown varies from -17.3% to -32.8%.
1x GPU-H100
The LoRA slowdown varies from -10.0% to -51.8%.
Llama3.1-70B
8x TPU-v6e
The LoRA slowdown varies from -20.7% to -46.3%.
8x GPU-L4
The LoRA slowdown varies from -13.8% (second best: -25.1%) to -49.7%.
4x GPU-H100
Unable to launch VMs due to persistent unavailability across multiple zones and regions.
Thanks for the contribution of multi-LoRA.
I tried running test_lora.py; the compilation time is around 1000s, while it's around 100s when multi-LoRA is not enabled.
This is very surprising, especially since the value of max_loras in the test is only 2.
Oh yes, I've just reproduced that. I think the longer compilation times come from a recent merge; that test used to take ~24 minutes total to run.
Could you explain why it takes so long and what the path to fix it is? I'm fine with addressing it in a separate PR, but we need to understand the issue first.
In general, could you summarize what TPU computation is involved with regard to LoRA? I think it should include, but may not be limited to:
- load lora
- remove lora
- pin lora
- set active lora
- add new lora
- execute model and sampler with lora
What are the possible shapes of the TPU computations, especially when we consider different numbers of LoRAs?
On TPU-v6e and H100 instances, we vary the inputs from 128 to 8k. On L4 instances, we vary the inputs from 128 to 2k.
Thanks for the awesome heatmap! Could you share the script to reproduce the performance benchmarking?
Could you explain why it takes so long and what the path to fix it is?
I still haven't figured out exactly what's causing this. The slowdown seems to come mainly from the model's forward pass: testing with Qwen2.5-3B, max_num_batched_tokens=16, and only 1 layer being compiled, it takes 16s for 1 LoRA, 30s for 2, and 45s for 3.
I don't see any evidence that we're recompiling the model n times, so I need to look into it. My feeling is that it's a problem with how maybe_dummy_run_with_lora works.
The TPU computations are:
- Load the LoRAfied model
- Copy a LoRA adapter from CPU to TPU
- Update the LoRA metadata tensors
- Run LoRA in the forward pass
- Run LoRA during sampling (LogitsProcessorWithLoRA) for the models which have a LogitsProcessor
Recently I wrote one page about how to avoid recompilation when using PyTorch/XLA. https://github.com/vllm-project/vllm/issues/16282
Update the LoRA metadata tensors
And I'd like to recommend that there be no TPU computation involved in the input preparation stage.
Recently I wrote one page about how to avoid recompilation when using PyTorch/XLA. #16282
That looks great thanks! It'd be great to have something like it on the pytorch_xla docs page too.
Unfortunately I don't think the problem is recompilation, but loading dummy lora adapters.
And I'd like to recommend that there be no TPU computation involved in the input preparation stage.
Yep earlier I moved the input prep computation to the CPU, here. But immediately after we copy to the TPU tensors. Does the copy first allocate new TPU tensors? Or would it implicitly donate the buffer?
Thanks for sharing the extensive benchmarks @psyhtest !
Any more work on landing https://github.com/vllm-project/vllm/pull/14238? I think it's going to get increasingly complex for @Akshat-Tripathi to maintain both PRs given the size of the contribution.
If we're happy with the design proposed in #14238 but need to address recompilations, we can still land that while keeping lora unsupported until this PR sorts it out.
Alternatively, we could close #14238 and try to go from 0 to 100 here, but tbh I would prefer having smaller incremental PRs and isolating the work; I feel it would help with reviewing.
Yep earlier I moved the input prep computation to the CPU, here. But immediately after we copy to the TPU tensors. Does the copy first allocate new TPU tensors? Or would it implicitly donate the buffer?
A copy is a copy, but given that everything is traced here, the runtime can pre-allocate the buffers, unlike in eager execution.
but loading dummy lora adapters.
But loading adapters shouldn't take so long. Anyway the weight size is very small. And the weight manipulation should only contain small graph(s).
Yep earlier I moved the input prep computation to the CPU, here.
It seems the link doesn't jump to the correct place.
Does the copy first allocate new TPU tensors? Or would it implicitly donate the buffer?
xla_tensor.copy_(...) will first allocate a new TPU tensor, there's no real inplace cpu-to-tpu copy at the moment.
But loading adapters shouldn't take so long. Anyway the weight size is very small. And the weight manipulation
Yes and no, the input graph ends up being quite big, around 1000 inputs and 500 outputs for 1 layer of Qwen2.5-3B. I've got a fix for that, so I'm just testing it now. We also load padded weights, which makes a difference.
It seems the link doesn't jump to the correct place.
Ah here's the link to the fork
xla_tensor.copy_(...) will first allocate a new TPU tensor, there's no real inplace cpu-to-tpu copy at the moment.
Ah this may cause an issue, but I'll have to think more
Yes and no, the input graph ends up being quite big, around 1000 inputs and 500 outputs for 1 layer of Qwen2.5-3B
Sounds strange. I don't think we even have that many operations in a layer, not to mention the loading-adapters part.
Ah here's the link to the fork
self._token_lora_indices[:base_indices.shape[0]] = base_indices.to(self.device)
Code like this will create a slice operation, and base_indices.to(self.device) will create a new TPU buffer first. What I recommend in the doc is to use self._token_lora_indices = base_indices.to(self.device). We cannot reuse the buffer, but it won't create new XLA ops, which makes handling recompilation easier.
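For illustration, a minimal sketch of the two patterns described above (hypothetical tensor names, assuming a torch_xla environment):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
token_lora_indices = torch.zeros(64, dtype=torch.int32, device=device)
base_indices = torch.arange(16, dtype=torch.int32)  # prepared on the CPU

# Pattern 1: in-place slice assignment. This traces a slice/update op into the
# graph, and base_indices.to(device) still materialises a new TPU buffer.
token_lora_indices[:base_indices.shape[0]] = base_indices.to(device)

# Pattern 2 (recommended above): rebind the name to the freshly copied tensor.
# The old buffer isn't reused, but no extra XLA ops are created, which makes
# recompilation behaviour easier to reason about.
token_lora_indices = base_indices.to(device)
```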
Ah ok, the problem is that the copies are in shared code, e.g. here. We've managed to split up the large graph now, and that gets rid of the large allocated buffers.
Those changes are in another branch right now; I'm not sure whether to make them part of this PR.
If we're happy with the design proposed in #14238 but need to address recompilations, we can still land that while keeping lora unsupported until this PR sorts it out.
Thanks @NickLucche, yeah it's getting tricky maintaining both of these PRs. I'm happy to disable LoRA in #14238 and re-enable it here. There might be one more PR soon that just has some compilation time / memory usage optimisations.
