
[Hardware][TPU][V1] Multi-LoRA Optimisations for the V1 TPU backend

Open · Akshat-Tripathi opened this pull request 8 months ago • 36 comments

Summary

This PR optimises the Multi-LoRA implementation from https://github.com/vllm-project/vllm/pull/14238/ and should be merged after it.

This includes several kernel optimisations:

  • Block size tuning 2bb886899f540463a43c47ec81ae77e3d1dd105c d7338f8085cb65d3519ba1b8d0f0964251ace31f
  • Faster mask creation 2aacb34030c282c4a66eee7f133ce0a893c64351
  • Allowing for some blocks to be skipped 6ee0b57858967e35b84a442121b89a0dff6ef333
  • Adding LoRA Laning eb804a0a898c3386772ae1dd832b38e57f5e6748
  • Splitting the Pallas kernel into shrink/expand variants de6746ac3efb478890bf73735bd0c1beaf589646 (see the sketch below)
  • Removing masking when only 1 LoRA adapter is used aad109b4cfa8ada4c4271a5c1b308816dcccf1d9

And a few general ones:

  • Pre-transposing the LoRA adapters used in the expand op a82f3fe038c77d43d1ec7713fccf1d7fc9412ec0
  • Reducing recompilations 5638e7da75905acebfa3dac3844516c8cd068dc5
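
For reference, the core computation these kernels implement is the per-token shrink/expand pair sketched below. This is a plain PyTorch sketch of the math, not the actual Pallas kernels; all names and shapes are illustrative. It also shows why keeping B pre-transposed and skipping the gather/mask when only one adapter is active pays off.

```python
import torch

# Illustrative PyTorch sketch of the multi-LoRA math the Pallas kernels
# implement; names and shapes are assumptions, not the vLLM implementation.
def lora_shrink(x, lora_a, token_lora_indices):
    # x: [tokens, hidden], lora_a: [num_loras, rank, hidden]
    # token_lora_indices: [tokens] -> which adapter each token uses
    a = lora_a[token_lora_indices]              # gather: [tokens, rank, hidden]
    return torch.einsum("td,trd->tr", x, a)     # [tokens, rank]

def lora_expand(y, lora_b_t, token_lora_indices):
    # lora_b_t: [num_loras, rank, out], i.e. B stored pre-transposed so the
    # expand step is a plain per-token [rank] x [rank, out] product.
    b = lora_b_t[token_lora_indices]            # gather: [tokens, rank, out]
    return torch.einsum("tr,tro->to", y, b)     # [tokens, out]

def lora_delta(x, lora_a, lora_b_t, token_lora_indices, scaling=1.0):
    # With a single active adapter the gathers (and any masking) can be
    # dropped entirely and both steps collapse into two dense matmuls.
    shrunk = lora_shrink(x, lora_a, token_lora_indices)
    return scaling * lora_expand(shrunk, lora_b_t, token_lora_indices)
```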

Things left/RFC

  • There are still a few recompilations at the start of a run that I need to track down
  • LogitsProcessorWithLoRA introduces a long (~1.5 second) stall when it's enabled, but not much activity seems to happen on the CPU or TPU during this time. I've disabled this for now.
  • It seems LogitsProcessorWithLoRA is always created even if there's no LoRA adapter that needs it, is there a reason for this?
  • I have microbenchmarks for the kernels, but I'm not sure what the right place to put them is.

Akshat-Tripathi avatar Mar 27 '25 23:03 Akshat-Tripathi

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

github-actions[bot] avatar Mar 27 '25 23:03 github-actions[bot]

It seems LogitsProcessorWithLoRA is always created even if there's no LoRA adapter that needs it, is there a reason for this?

The LoRA layers in vLLM are created in advance. Which layers get LoRA layers is determined by supported_lora_modules; see: https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py#L470
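
Roughly, the selection works like the sketch below (an illustrative simplification, not the exact code behind that link):

```python
import torch.nn as nn

def select_lora_modules(model: nn.Module, supported_lora_modules: list[str]) -> list[str]:
    """Illustrative sketch only: which submodules would get LoRA wrappers
    created up front, whether or not a loaded adapter actually targets them."""
    selected = []
    for name, _module in model.named_modules():
        if any(name.endswith(target) for target in supported_lora_modules):
            selected.append(name)
    return selected
```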

jeejeelee avatar Mar 27 '25 23:03 jeejeelee

It seems LogitsProcessorWithLoRA is always created even if there's no LoRA adapter that needs it, is there a reason for this?

The LoRA layers in vLLM are created in advance. Which layers get LoRA layers is determined by supported_lora_modules; see: https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py#L470

Yep that makes sense, but here it seems like LogitsProcessorWithLoRA will always be created. The creation function is different to the other one here

Akshat-Tripathi avatar Mar 27 '25 23:03 Akshat-Tripathi

It seems LogitsProcessorWithLoRA is always created even if there's no LoRA adapter that needs it, is there a reason for this?

The LoRA layers in vLLM are created in advance. Which layers get LoRA layers is determined by supported_lora_modules; see: https://github.com/vllm-project/vllm/blob/main/vllm/lora/models.py#L470

Yep that makes sense, but here it seems like LogitsProcessorWithLoRA will always be created. The creation function is different to the other one here

What model are you testing on? If it's Llama, it's because Llama's supported_lora_modules contains this module.

jeejeelee avatar Mar 28 '25 01:03 jeejeelee

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Mar 29 '25 04:03 mergify[bot]

@NickLucche I'm almost done with this PR, there's just a problem with performance and the sampler. My smoke tests show about a 10x performance dip when I enable the sampler.

Profiles are showing that there's a stall of about 80-100ms between model execution and sampling. Have you seen something like this before?

Akshat-Tripathi avatar Mar 31 '25 22:03 Akshat-Tripathi

Thanks for your work!

a 10x performance dip when I enable the sampler

What do you mean, the sampler is always "enabled" as of now.

there's a stall of about 80-100ms between model execution and sampling

Likely something being compiled, can you confirm this?
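
(For instance, one way to confirm is to diff the XLA compile metrics around the step in question. This is a hedged sketch assuming the standard torch_xla debug metrics API; run_suspect_step is a hypothetical stand-in for the code being profiled.)

```python
import torch_xla.debug.metrics as met

met.clear_all()                           # reset XLA counters/metrics
run_suspect_step()                        # hypothetical: model exec + sampling
compile_data = met.metric_data("CompileTime")
if compile_data is not None:
    num_compiles, total_ns, _samples = compile_data
    print(f"{num_compiles} compilations, {total_ns / 1e9:.2f}s total")
print(met.short_metrics_report())         # also lists transfer/execute counters
```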

I can help look into the recompilations, but we have to start from the first PR. As long as it's not introducing overhead for the regular case, we can likely have it merged even if performance is suboptimal.

NickLucche avatar Apr 01 '25 07:04 NickLucche

Thanks for your work!

a 10x performance dip when I enable the sampler

What do you mean, the sampler is always "enabled" as of now.

there's a stall of about 80-100ms between model execution and sampling

Likely something being compiled, can you confirm this?

I can help look into the recompilations, but we have to start from the first PR. As long as it's not introducing overhead for the regular case, we can likely have it merged even if performance is suboptimal.

Ah I had it commented out earlier when I was testing. I don't see any recompilations causing the stall when I run with PT_XLA_DEBUG=2 or in the profile. Actually almost nothing seems to run, except for a small memory allocation.

I've put the profile below. The first graph executes the model, while the second graph is the sampler's execution. Without LoRA the gap between them is ~800us long.

[image: profile trace]

Commenting out this code seems to fix the problem. But I can't seem to find any reason for the slowdown.

Akshat-Tripathi avatar Apr 01 '25 22:04 Akshat-Tripathi

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 02 '25 04:04 mergify[bot]

Thanks for your work!

a 10x performance dip when I enable the sampler

What do you mean, the sampler is always "enabled" as of now.

there's a stall of about 80-100ms between model execution and sampling

Likely something being compiled, can you confirm this? I can help look into the recompilations, but we have to start from the first PR. As long as it's not introducing overhead for the regular case, we can likely have it merged even if performance is suboptimal.

Ah I had it commented out earlier when I was testing. I don't see any recompilations causing the stall when I run with PT_XLA_DEBUG=2 or in the profile. Actually almost nothing seems to run, except for a small memory allocation.

I've put the profile below. The first graph executes the model, while the second graph is the sampler's execution. Without LoRA the gap between them is ~800us long.

[image: profile trace]

Commenting out this code seems to fix the problem. But I can't seem to find any reason for the slowdown.

Turns out the memory allocations needed for padding don't show up properly in traces. I've fixed the issue now

Akshat-Tripathi avatar Apr 02 '25 04:04 Akshat-Tripathi

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 03 '25 09:04 mergify[bot]

I've got some performance numbers using the MLPerf Llama2-70B inference benchmark. I retokenised the dataset for Llama3.1.

Model      Parameters   Without LoRA (tok/s)   With LoRA (tok/s)
Llama3.1   8B           1621.62                1426.01
Llama3.1   70B          432.964                326.709

Akshat-Tripathi avatar Apr 07 '25 08:04 Akshat-Tripathi

Benchmarking LoRA against baseline (no LoRA) throughput

We use NVIDIA's GenAI-Perf tool to force fixed-length inputs and outputs to produce "heatmap" plots as below. On TPU-v6e and H100 instances, we vary the inputs from 128 to 8k. On L4 instances, we vary the inputs from 128 to 2k.

We calculate the LoRA slowdown as ((LoRA throughput / baseline throughput) - 1) * 100%.
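
As a quick worked example of the formula (using the MLPerf Llama3.1-8B numbers from the earlier comment, purely for illustration):

```python
baseline_tok_s = 1621.62   # Llama3.1 8B, without LoRA (MLPerf numbers above)
lora_tok_s = 1426.01       # Llama3.1 8B, with LoRA
slowdown = (lora_tok_s / baseline_tok_s - 1) * 100
print(f"LoRA slowdown: {slowdown:.1f}%")   # ≈ -12.1%
```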

Llama3.1-8B

1x TPU-v6e

The LoRA slowdown varies from -8.4% to -23.9%.

[image: heatmap, Llama3.1-8B on 1x TPU-v6e]

1x GPU-L4

The LoRA slowdown varies from -17.3% to -32.8%.

[image: heatmap, Llama3.1-8B on 1x GPU-L4 (v2)]

1x GPU-H100

The LoRA slowdown varies from -10.0% to -51.8%.

[image: heatmap, Llama3.1-8B on 1x GPU-H100]

Llama3.1-70B

8x TPU-v6e

The LoRA slowdown varies from -20.7% to -46.3%.

[image: heatmap, Llama3.1-70B on 8x TPU-v6e]

8x GPU-L4

The LoRA slowdown varies from -13.8% (second best: -25.1%) to -49.7%.

[image: heatmap, Llama3.1-70B on 8x GPU-L4]

4x GPU-H100

Unable to launch VMs due to persistent unavailability across multiple zones and regions.

psyhtest avatar Apr 07 '25 21:04 psyhtest

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar Apr 10 '25 11:04 mergify[bot]

Thanks for the multi-LoRA contribution.

I tried running test_lora.py; the compilation time is around 1000s, while it's around 100s when multi-LoRA is not enabled.

This is very surprising, especially since max_loras in the test is only 2.

Oh yes, I've just reproduced that. I think the longer compilation times come from a recent merge; that test used to take ~24 minutes total to run.

Akshat-Tripathi avatar Apr 14 '25 12:04 Akshat-Tripathi

Oh yes, I've just reproduced that. I think the longer compilation times come from a recent merge; that test used to take ~24 minutes total to run.

Could you explain why it takes so long and what the path to fix it is? I'm fine with addressing it in a separate PR, but we need to understand the issue first.

In general, could you summarize what TPU computation is involved with regard to LoRA? I think it should include, but may not be limited to:

  1. load lora
  2. remove lora
  3. pin lora
  4. set active lora
  5. add new lora
  6. execute model and sampler with lora

What are the possible shapes of the TPU computations, especially when we consider different numbers of LoRAs?

yaochengji avatar Apr 14 '25 18:04 yaochengji

On TPU-v6e and H100 instances, we vary the inputs from 128 to 8k. On L4 instances, we vary the inputs from 128 to 2k.

Thanks for the awesome heatmap! Could you share the script to reproduce the performance benchmarking?

yaochengji avatar Apr 15 '25 06:04 yaochengji

Oh yes, I've just reproduced that. I think the longer compilation times come from a recent merge; that test used to take ~24 minutes total to run.

Could you explain why it takes so long and what the path to fix it is? I'm fine with addressing it in a separate PR, but we need to understand the issue first.

In general, could you summarize what TPU computation is involved with regard to LoRA? I think it should include, but may not be limited to:

  1. load lora
  2. remove lora
  3. pin lora
  4. set active lora
  5. add new lora
  6. execute model and sampler with lora

What are the possible shapes of the TPU computations, especially when we consider different numbers of LoRAs?

I still haven't figured out exactly what's causing this. The slowdown seems to be mainly coming from the model's forward pass: testing with Qwen2.5-3B, max_num_batched_tokens=16, and only 1 layer being compiled, it takes 16s for 1 LoRA, 30s for 2, and 45s for 3.

I don't see any evidence that we're recompiling the model n times, so I need to look into it. My feeling is that it's a problem with how maybe_dummy_run_with_lora works.

The TPU computations are:

  1. Load the LoRAfied model
  2. Copy a LoRA adapter from CPU to TPU (see the sketch below)
  3. Update the LoRA metadata tensors
  4. Run LoRA in the forward pass
  5. Run LoRA during sampling (LogitsProcessorWithLoRA) for the models which have a LogitsProcessor
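
For step 2, the rough idea is sketched below (illustrative only, not the vLLM code; the names and the pad-to-max-rank scheme are assumptions): padding on the host to a fixed shape keeps the device-side graph shape-stable.

```python
import torch

def copy_adapter_to_device(lora_a_cpu: torch.Tensor, max_rank: int,
                           device: torch.device) -> torch.Tensor:
    # Pad the adapter to a fixed [max_rank, hidden] shape on the CPU so every
    # transfer presents the same shape to the device and avoids recompilation.
    rank, hidden = lora_a_cpu.shape
    padded = torch.zeros(max_rank, hidden, dtype=lora_a_cpu.dtype)
    padded[:rank] = lora_a_cpu
    return padded.to(device)   # one host-to-device copy of a fixed-shape tensor
```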

Akshat-Tripathi avatar Apr 15 '25 17:04 Akshat-Tripathi

I don't see any evidence that we're recompiling the model n times, so I need to look into it. My feeling is that it's a problem with how maybe_dummy_run_with_lora works.

Recently I wrote one page about how to avoid recompilation when using PyTorch/XLA. https://github.com/vllm-project/vllm/issues/16282

Update the LoRA metadata tensors

And I'd like to recommend that there be no TPU computation involved in the input preparation stage.

yaochengji avatar Apr 16 '25 01:04 yaochengji

Recently I wrote one page about how to avoid recompilation when using PyTorch/XLA. #16282

That looks great thanks! It'd be great to have something like it on the pytorch_xla docs page too.

Unfortunately I don't think the problem is recompilation, but loading dummy lora adapters.

And I'd like to recommend that there be no TPU computation involved in the input preparation stage.

Yep earlier I moved the input prep computation to the CPU, here. But immediately after we copy to the TPU tensors. Does the copy first allocate new TPU tensors? Or would it implicitly donate the buffer?

Akshat-Tripathi avatar Apr 16 '25 08:04 Akshat-Tripathi

Thanks for sharing the extensive benchmarks @psyhtest !

Any more work on landing https://github.com/vllm-project/vllm/pull/14238? I think it's going to get increasingly complex for @Akshat-Tripathi to maintain both PRs given the size of the contribution.

If we're happy with the design proposed in #14238 but need to address recompilations, we can still land that while keeping lora unsupported until this PR sorts it out.

Alternatively we could close #14238 and try to go from 0 to 100 here, but I would prefer smaller, incremental PRs tbh to isolate the work; I feel it would help with reviewing.

NickLucche avatar Apr 16 '25 08:04 NickLucche

Yep earlier I moved the input prep computation to the CPU, here. But immediately after we copy to the TPU tensors. Does the copy first allocate new TPU tensors? Or would it implicitly donate the buffer?

A copy is a copy, but given that everything is traced here the runtime can pre-allocate the buffers unlike eager execution.

NickLucche avatar Apr 16 '25 08:04 NickLucche

but loading dummy lora adapters.

But loading adapters shouldn't take so long. Anyway the weight size is very small. And the weight manipulation should only contain small graph(s).

Yep earlier I moved the input prep computation to the CPU, here.

Seems it cannot jump to the correct place.

Does the copy first allocate new TPU tensors? Or would it implicitly donate the buffer?

xla_tensor.copy_(...) will first allocate a new TPU tensor; there's no real in-place CPU-to-TPU copy at the moment.

yaochengji avatar Apr 16 '25 16:04 yaochengji

But loading adapters shouldn't take so long. Anyway the weight size is very small. And the weight manipulation

Yes and no, the input graph ends up being quite big, around 1000 inputs and 500 outputs for 1 layer of Qwen2.5-3B. I've got a fix for that, so I'm just testing it now. We also load padded weights, which makes a difference.

Seems it cannot jump to the correct place.

Ah here's the link to the fork

xla_tensor.copy_(...) will first allocate a new TPU tensor; there's no real in-place CPU-to-TPU copy at the moment.

Ah this may cause an issue, but I'll have to think more

Akshat-Tripathi avatar Apr 16 '25 17:04 Akshat-Tripathi

Yes and no, the input graph ends up being quite big, around 1000 inputs and 500 outputs for 1 layer of Qwen2.5-3B

Sounds strange. I don't think we even have that many operations in a layer, not to mention the adapter-loading part.

Ah here's the link to the fork

self._token_lora_indices[:base_indices.shape[0]] = base_indices.to(self.device)

Code like this will create a slice operation, and base_indices.to(self.device) will create a new TPU buffer first. What I recommend in the doc is to use self._token_lora_indices = base_indices.to(self.device). We cannot reuse the buffer, but it won't create new XLA ops, which makes handling recompilation easier.
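
A minimal sketch contrasting the two patterns (illustrative class, not vLLM's actual code):

```python
import torch

class LoraMetadata:
    """Illustrative only: contrasts the two update patterns discussed above."""

    def __init__(self, max_num_tokens: int, device: torch.device):
        self.device = device
        self._token_lora_indices = torch.zeros(
            max_num_tokens, dtype=torch.int32, device=device)

    def update_with_slice_copy(self, base_indices: torch.Tensor) -> None:
        # Adds a dynamic-slice/copy op to the traced graph; .to(self.device)
        # still allocates a fresh device buffer first.
        self._token_lora_indices[:base_indices.shape[0]] = base_indices.to(self.device)

    def update_by_rebinding(self, base_indices: torch.Tensor) -> None:
        # Rebinds the attribute to the transferred tensor: the old buffer is
        # not reused, but no extra XLA ops are created, which makes
        # recompilation easier to reason about.
        self._token_lora_indices = base_indices.to(self.device)
```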

yaochengji avatar Apr 16 '25 17:04 yaochengji

Code like this will create a slice operation, and base_indices.to(self.device) will create a new TPU buffer first. What I recommend in the doc is to use self._token_lora_indices = base_indices.to(self.device). We cannot reuse the buffer, but it won't create new XLA ops, which makes handling recompilation easier.

Ah ok, the problem is that the copies are in shared code, e.g. here. We've managed to split up the large graph now, and that gets rid of the large allocated buffers.

Those changes are in another branch right now, not sure whether to make them part of this PR

Akshat-Tripathi avatar Apr 22 '25 11:04 Akshat-Tripathi

If we're happy with the design proposed in #14238 but need to address recompilations, we can still land that while keeping lora unsupported until this PR sorts it out.

Thanks @NickLucche, yeah it's getting tricky maintaining both these PRs. I'm happy to disable LoRA in #14238 and re-enable it here. There might also be 1 more PR soon that just has some compilation time / memory usage optimisations.

Akshat-Tripathi avatar Apr 22 '25 11:04 Akshat-Tripathi

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Akshat-Tripathi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] avatar May 06 '25 15:05 mergify[bot]
