
run_vit_b_quant.py runs slower than run_vit_b.py

jerryzh168 opened this issue 1 year ago · 9 comments

run_vit_b_quant.py elapsed_time: 11.0519150390625 milliseconds

run_vit_b.py elapsed_time: 1.2272755432128906 milliseconds

this is with int8_dynamic_activation_int8_weight

jerryzh168 · Sep 17 '24 02:09
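
For reference, a minimal sketch of how this kind of benchmark is typically put together with torchao, assuming the quantize_ / int8_dynamic_activation_int8_weight API from recent torchao releases and torchvision's vit_b_16; the actual run_vit_b_quant.py script may differ in details:

```python
import torch
import torchvision.models.vision_transformer as vit
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# Load a pretrained ViT-B/16 and move it to GPU in bfloat16, roughly as the
# torchao quantize_vit tutorial does.
model = vit.vit_b_16(weights=vit.ViT_B_16_Weights.IMAGENET1K_V1)
model = model.eval().cuda().to(torch.bfloat16)

# Swap Linear weights for int8 dynamic-activation / int8-weight
# quantized tensor subclasses.
quantize_(model, int8_dynamic_activation_int8_weight())

model = torch.compile(model, mode="max-autotune")

x = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    for _ in range(3):  # warm-up runs so compilation time is excluded
        model(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    model(x)
    end.record()
    torch.cuda.synchronize()
    print("elapsed_time:", start.elapsed_time(end), "milliseconds")
```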

It seems torch==2.4.0 does not have the perf drop (with unwrap_tensor_subclass):

run_vit_b_quant.py elapsed_time: 1.288721923828125 milliseconds

run_vit_b.py elapsed_time: 1.561510772705078 milliseconds

With int8_weight_only:

run_vit_b_quant.py elapsed_time: 1.3892197265625 milliseconds

run_vit_b.py elapsed_time: 1.543534698486328 milliseconds

jerryzh168 · Sep 17 '24 03:09
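
For context, a minimal sketch of where unwrap_tensor_subclass fits in, assuming the torchao.utils.unwrap_tensor_subclass helper, which flattens the quantized tensor-subclass weights into plain tensors (via parametrization) so that torch.compile on older PyTorch versions does not have to handle subclass parameters; shown here on a toy model rather than the ViT benchmark:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(
    nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
).eval().cuda().to(torch.bfloat16)

quantize_(model, int8_dynamic_activation_int8_weight())

# On torch<=2.4, flatten the tensor-subclass weights before compiling so that
# dynamo does not see subclass parameters.
unwrap_tensor_subclass(model)

model = torch.compile(model, mode="max-autotune")

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    print(model(x).shape)
```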

hmm. This might be a good example of "subclass runtime overhead" given the fact that you're pointing out that you see the slowdown goes away on 2.4.0 when using unwrap_tensor_subclasses. But it would be nice to have a profiler trace that actually shows us that most of the time is spent in python overhead and not e.g. compile generating a slower artifact. @jerryzh168 any chance you can get a profile output? Also cc @IvanKobzarev

bdhirsh · Sep 17 '24 15:09

Yeah, will try to repro and profile it.

IvanKobzarev · Sep 17 '24 15:09
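
A minimal sketch of how such a profile could be collected with torch.profiler, shown on a toy compiled module as a stand-in for the compiled quantized ViT; the same recipe applies to the actual benchmark:

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Stand-ins for the compiled quantized model and its input.
model = torch.compile(nn.Linear(1024, 1024).cuda())
x = torch.randn(16, 1024, device="cuda")

with torch.no_grad():
    for _ in range(3):  # warm up so compilation time is excluded
        model(x)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        with_stack=True,
    ) as prof:
        model(x)

# High Python/dispatch self-CPU time with an idle GPU would point at subclass
# overhead; long kernels would point at a slower compiled artifact.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
prof.export_chrome_trace("vit_b_quant_trace.json")  # view in Perfetto / chrome://tracing
```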

we'll be able to cherry-pick the change until 9/30

jerryzh168 · Sep 18 '24 00:09

> hmm. This might be a good example of "subclass runtime overhead" given the fact that you're pointing out that you see the slowdown goes away on 2.4.0 when using unwrap_tensor_subclasses. But it would be nice to have a profiler trace that actually shows us that most of the time is spent in python overhead and not e.g. compile generating a slower artifact. @jerryzh168 any chance you can get a profile output? Also cc @IvanKobzarev

@bdhirsh I thought compile would trace through the subclass; are you saying there's still a bunch of overhead for subclasses even after compile?

HDCharles · Sep 20 '24 17:09

Found the problem. The main regression is because dynamo fails to compile with fullgraph=True; as a result it compiles the model partially, with a graph break on every MultiHeadAttention call, and that causes the bad perf.

Compilation fails because the compile path picks the multi-head attention "fast path".

https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/activation.py#L1286 - there is a manual check there to avoid the fast path (native_multi_head_attention) if one of the arguments has __torch_function__ handling.

But this check fails for subclasses during compilation, so compilation tries to compile the fast path via aten.native_multi_head_attention, which results in a NYI (not yet implemented) error for the subclass.

If the non-fast-path is taken during compilation, the benchmark shows 1.21 ms again for me.

So there is no significant runtime overhead for subclasses, just a compilation issue with MultiHeadAttention when there is a subclass as a parameter.

Now thinking about how to fix this so that ao subclasses take only the non-fast-path for MultiHeadAttention during compilation.

IvanKobzarev · Sep 20 '24 20:09
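
For anyone reproducing this, a minimal sketch of ways to surface such graph breaks, shown on a plain nn.MultiheadAttention rather than the quantized model (the module and shapes here are illustrative):

```python
import torch
import torch._dynamo
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True).cuda().eval()
x = torch.randn(1, 197, 768, device="cuda")

# Option 1: ask dynamo where and why it breaks the graph.
explanation = torch._dynamo.explain(mha)(x, x, x)
print(explanation.graph_break_count)
print(explanation.break_reasons)

# Option 2: fail loudly at the first break instead of silently falling back.
compiled = torch.compile(mha, fullgraph=True)
# compiled(x, x, x)  # raises if a graph break (e.g. the fast-path NYI) is hit

# Option 3 (from the shell): TORCH_LOGS="graph_breaks" python run_vit_b_quant.py
```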

@IvanKobzarev - You should be able to use https://pytorch.org/docs/main/backends.html#torch.backends.mha.set_fastpath_enabled to disable the fast path.

cpuhrsch · Sep 21 '24 00:09

@cpuhrsch Thanks, this helps. @jerryzh168, I've verified that adding torch.backends.mha.set_fastpath_enabled(False) at the top of run_vit_b_quant.py gets the performance back without unwrap_tensor_subclasses:

elapsed_time:  1.216195556640625  milliseconds

I will leave it to you where to put torch.backends.mha.set_fastpath_enabled(False) in AO so that AO-quantized models will not take the MHA fast path.

IvanKobzarev · Sep 23 '24 12:09
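
A minimal sketch of the workaround, using a plain nn.MultiheadAttention instead of the quantized ViT; the key line is the set_fastpath_enabled(False) call before quantizing/compiling:

```python
import torch
import torch.nn as nn

# Disable the native_multi_head_attention fast path so torch.compile traces
# the regular Python path of nn.MultiheadAttention instead.
torch.backends.mha.set_fastpath_enabled(False)
assert torch.backends.mha.get_fastpath_enabled() is False

mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True).eval()
compiled = torch.compile(mha, fullgraph=True)

x = torch.randn(1, 197, 768)
with torch.no_grad():
    out, _ = compiled(x, x, x)
print(out.shape)
```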

I'd add this setting to the run_vit_b_quant.py example script. We might also want to consider adding a warning in PyTorch when the fast path is enabled and subclasses are used (i.e. one of the arguments has a __torch_function__ override).

cpuhrsch · Sep 23 '24 21:09
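
A hypothetical sketch of what such a warning check could look like (warn_if_fastpath_with_subclass is an illustrative name, not an existing PyTorch function), using torch.overrides.has_torch_function to detect parameters with a __torch_function__ override:

```python
import warnings
import torch
from torch.overrides import has_torch_function

def warn_if_fastpath_with_subclass(mha: torch.nn.MultiheadAttention) -> None:
    # Hypothetical check: warn when the MHA fast path is on but the module's
    # parameters carry __torch_function__ handling (e.g. torchao-quantized
    # tensor subclass weights), since that combination breaks torch.compile.
    params = tuple(mha.parameters())
    if torch.backends.mha.get_fastpath_enabled() and has_torch_function(params):
        warnings.warn(
            "nn.MultiheadAttention fast path is enabled but some parameters are "
            "tensor subclasses; consider torch.backends.mha.set_fastpath_enabled(False)."
        )
```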