vllm
vllm copied to clipboard
[torch.compile] integration with compilation control
closes https://github.com/vllm-project/vllm/issues/8821
this pr defines the way we integrate torch.compile . Some features of vLLM are not tested yet, so I just put the functionality under a test flag VLLM_TEST_TORCH_COMPILE_LEVEL .
the user-facing flags are:
flags to control Dynamo graph capture
| graph capture | VLLM_DYNAMO_USE_CUSTOM_DISPATCHER == 0 | VLLM_DYNAMO_USE_CUSTOM_DISPATCHER == 1 (default) |
|---|---|---|
| VLLM_TEST_TORCH_COMPILE_LEVEL == 0 ( default) | no Dynamo | no Dynamo |
| VLLM_TEST_TORCH_COMPILE_LEVEL > 0 | use Dynamo as-is, sound, but introduces runtime guard evaluation overhead | only let Dynamo run once, and then dispatch the bytecode ourselves. no dynamo overhead, but can have correctness issue |
Most models have static computation graph, and it is safe to use vLLM's custom dispatcher. That's why we can make it default. However, we do recommend users run with the following settings to make sure the outputs are the same:
- default setting
export VLLM_TEST_TORCH_COMPILE_LEVEL=1andexport VLLM_DYNAMO_USE_CUSTOM_DISPATCHER=0export VLLM_TEST_TORCH_COMPILE_LEVEL=1andexport VLLM_DYNAMO_USE_CUSTOM_DISPATCHER=1
if the output from 2 is different from 1, it means a torch.compile bug, and we should inform the pytorch team.
if the output 3 is different from 1, it means the model is not able to be captured as a single static graph. We should investigate from vLLM side.
flags to control Inductor compilation
assuming VLLM_TEST_TORCH_COMPILE_LEVEL > 0
| VLLM_TEST_TORCH_COMPILE_LEVEL | use Inductor | use vLLM custom ops |
|---|---|---|
| 1 | ❎ | ✅ |
| 2 | ✅ | ❎ |
| 3 | ✅ (with max-autotune) | ❎ |
control of the number of compilations
by default, we compile:
- one graph that works for arbitrary number of input tokens
- one graph for each batchsize of cudagraph capture batchsize
👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can do one of these:
- Add
readylabel to the PR - Enable auto-merge.
🚀
for the initial benchmark results, see https://github.com/vllm-project/vllm/pull/8949#issuecomment-2381590243
feature compatibility and TODOs (will be future PRs):
- [ ] support embedding model
- [ ] support encoder-decoder model
- [x] support multi-modality model (llava is supported)
- [x] support attention backend other than flash attention (FLASHINFER is supported)
- [x] support models other than llama (locally tested gemma 2
google/gemma-2-2b-it, just one line change to make it work) - [x] support TP (at the cost of silently disabling custom allreduce. should be fixed after making all-reduce out-of-place, see https://github.com/vllm-project/vllm/pull/9061 )
- [x] support PP
- [ ] test and integrate lora
- [x] test and integrate quantization (compressed-tensors is tested)
- [x] perf testing
- [ ] profile and investigate compilation time reduction
our goal is to turn on torch.compile by default at the second level (i.e. VLLM_TEST_TORCH_COMPILE_LEVEL == 2 )
@bnellnm can you take a look?
new design:
Usage CompilationLevel how vLLM uses Dynamo how vLLM uses Inductor use vLLM's custom ops how to customize the compilation
export VLLM_TORCH_COMPILE_LEVEL=0(default) NO_COMPILATION (0) N/A N/A ✅ N/Aexport VLLM_TORCH_COMPILE_LEVEL=1DYNAMO_AS_IS (1) use as-is N/A ✅vllm.plugins.set_torch_compile_backendexport VLLM_TORCH_COMPILE_LEVEL=2DYNAMO_ONCE (2) use only once, make sure computation graph does not change N/A ✅vllm.plugins.set_torch_compile_backendexport VLLM_TORCH_COMPILE_LEVEL=3INDUCTOR (3) same as 2 compile one graph with symbolic shape, and many graphs for each cudagraph shape ❌vllm.plugins.set_inductor_additional_configsexport VLLM_TORCH_COMPILE_LEVEL=4INDUCTOR_MAX_AUTOTUNE (4) same as 2 same as 3 ❌vllm.plugins.set_inductor_additional_configs, withmax_autotuneforced to beTrueUsage guide:run vLLM with:
- default setting
export VLLM_TORCH_COMPILE_LEVEL=1export VLLM_TORCH_COMPILE_LEVEL=2if the output from 2 is different from 1, it means a
torch.compilebug, and we should inform the pytorch team.if the output 3 is different from 1, it means the model is not able to be captured as a single static graph. We should investigate from vLLM side.
if they all produce the same outputs, then try
export VLLM_TORCH_COMPILE_LEVEL=3orexport VLLM_TORCH_COMPILE_LEVEL=4, depending on the compilation time you want to pay.
I think this would make a good README.md or addition to the docs. Not sure where the best place would be.
I think this would make a good README.md or addition to the docs. Not sure where the best place would be.
sure we will add docs for it, after stable integration.
In general, I think this PR is reasonable. There are a few higher level items that I would like to call out, though.
- It would be nice to have some finer grained control over which custom ops are enabled/disabled when running with higher optimization levels. This is useful because it makes writing inductor patterns easier. Especially since we are matching after post_grad passes run. We don't necessarily want to worry about torch.compile optimizations making it so that our pattern no longer matches. The two usecases we have now are silu_and_mul and fused_add_rms_norm both being fused with static_scaled_fp8_quant via custom kernels.
- Perhaps unsurprisingly, the torch.compile compile times with cuda graphs have significantly increased. The following command spends about 12 minutes compiling for various shapes.
VLLM_TORCH_COMPILE_LEVEL=4 python benchmarks/benchmark_latency.py --model "neuralmagic/Meta-Llama-3-8B-Instruct-FP8" -q "fp8" --dtype "float16"
It would be nice to have some finer grained control over which custom ops are enabled/disabled when running with higher optimization levels.
I'd like to support that if we have a good way to expose it. Adding one env var for each custom op might be too complicated? Currently, when I am experimenting, I just manually change the code of every custom op to select the implementation.
The following command spends about 12 minutes compiling for various shapes
compilation time is the focus of followup PRs. I need to get this merged first, so that pytorch team can start doing their work. it is not polite to ask them doing optimization on an experiment branch.
I'd like to support that if we have a good way to expose it. Adding one env var for each custom op might be too complicated? Currently, when I am experimenting, I just manually change the code of every custom op to select the implementation.
I do the same thing which is completely fine for experimentation. I agree that one env var for each custom op is a bit much. Overriding the dispatch_forward method in custom ops we want to enable could be a reasonable way forward, though.
compilation time is the focus of followup PRs. I need to get this merged first, so that pytorch team can start doing their work. it is not polite to ask them doing optimization on an experiment branch.
Completely understandable, just wanted to call it out.
Overriding the dispatch_forward method in custom ops we want to enable could be a reasonable way forward
this is what I did for experiments. do you have any ideas on how to expose the control to users?
Overriding the dispatch_forward method in custom ops we want to enable could be a reasonable way forward
this is what I did for experiments. do you have any ideas on how to expose the control to users?
Not immediately. I think we should discuss this in tandem with how/if we want to expose enabling/disabling various custom inductor passes, though.
this is what I did for experiments. do you have any ideas on how to expose the control to users?
What about one environment variable that serves as a list of custom ops that are enabled, using + and - to enable/disable:
VLLM_ENABLE_CUSTOM_OPS="+rms_norm,-silu_and_mul"
We could also use all and none to mean all of the ops if we want to
This assumes there's a default set that's enabled/disabled. Even if that's not the case, you could do all,-rms_norm to enable all but rms_norm or none,+rms_norm to only enable rms_norm.
this is what I did for experiments. do you have any ideas on how to expose the control to users?
What about one environment variable that serves as a list of custom ops that are enabled, using
+and-to enable/disable:VLLM_ENABLE_CUSTOM_OPS="+rms_norm,-silu_and_mul"We could also use
allandnoneto mean all of the ops if we want toThis assumes there's a default set that's enabled/disabled. Even if that's not the case, you could do
all,-rms_normto enable all butrms_normornone,+rms_normto only enablerms_norm.
that's a good idea.
there are some errors when we run the test:
Not all values of RelaxedUnspecConstraint(L['input_ids'].size()[0]) are valid because L['input_ids'].size()[0] was inferred to be a constant (2048).
I think it is because quantization kernels don't have good symbolic shape support. Need to investigate further.
run it with TORCH_LOGS="+dynamic", I get:
[rank0]:V1009 22:37:30.433188 139882259515200 torch/fx/experimental/symbolic_shapes.py:2529] [0/0] create_env
[rank0]:I1009 22:37:30.472994 139882259515200 torch/fx/experimental/symbolic_shapes.py:3549] [0/0] create_symbol s0 = 2048 for L['input_ids'].size()[0] [2, 9223372036854775806] at vllm/vllm/model_executor/layers/vocab_parallel_embedding.py:398 in forward (_dynamo/variables/builder.py:2276 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
[rank0]:V1009 22:37:30.473605 139882259515200 torch/fx/experimental/symbolic_shapes.py:5167] [0/0] eval True == True [statically known]
[rank0]:V1009 22:37:30.473842 139882259515200 torch/fx/experimental/symbolic_shapes.py:5167] [0/0] eval False == False [statically known]
[rank0]:V1009 22:37:30.509129 139882259515200 torch/fx/experimental/symbolic_shapes.py:5167] [0/0] eval Ne(s0, 1) == True [statically known]
[rank0]:V1009 22:37:30.509387 139882259515200 torch/fx/experimental/symbolic_shapes.py:5167] [0/0] eval True == True [statically known]
[rank0]:V1009 22:37:30.515596 139882259515200 torch/fx/experimental/symbolic_shapes.py:4697] [0/0] _update_var_to_range s0 = VR[2048, 2048] (update)
[rank0]:I1009 22:37:30.516592 139882259515200 torch/fx/experimental/symbolic_shapes.py:4831] [0/0] set_replacement s0 = 2048 (solve) VR[2048, 2048]
[rank0]:I1009 22:37:30.516976 139882259515200 torch/fx/experimental/symbolic_shapes.py:5082] [0/0] eval Eq(s0, 2048) [guard added] at vllm/vllm/_custom_ops.py:308 in gptq_marlin_24_gemm (_ops.py:1089 in _call_overload_packet_from_python), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, 2048)"
so this line vllm/vllm/_custom_ops.py:308 in gptq_marlin_24_gemm forces the symbolic shape to be a constant, because it feeds the shape to a c++ op torch.ops._C.gptq_marlin_24_gemm .
The tests need to be re-written as comparison tests anyway. So I will comment them out right now.
merge to trigger wheel builds early