run_vit_b_quant.py runs slower than run_vit_b.py
run_vit_b_quant.py elapsed_time: 11.0519150390625 milliseconds
run_vit_b.py elapsed_time: 1.2272755432128906 milliseconds
this is with int8_dynamic_activation_int8_weight
It seems torch==2.4.0 does not have the perf drop (with unwrap_tensor_subclass):
run_vit_b_quant.py elapsed_time: 1.288721923828125 milliseconds
run_vit_b.py elapsed_time: 1.561510772705078 milliseconds
With int8_weight_only:
run_vit_b_quant.py elapsed_time: 1.3892197265625 milliseconds
run_vit_b.py elapsed_time: 1.543534698486328 milliseconds
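For reference, a minimal repro sketch along the lines of the tutorial script (not the exact file; the model dtype, warmup count, and compile mode are assumptions here):

```python
# Simplified repro sketch: quantize ViT-B/16 with int8_dynamic_activation_int8_weight,
# optionally unwrap the tensor subclasses, compile, and time one forward pass
# with CUDA events.
import torch
import torchvision.models.vision_transformer as vit
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

model = vit.vit_b_16(weights=vit.ViT_B_16_Weights.IMAGENET1K_V1)
model = model.eval().to(device="cuda", dtype=torch.bfloat16)
x = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device="cuda")

quantize_(model, int8_dynamic_activation_int8_weight())
# On torch==2.4.0 the slowdown goes away when the subclasses are unwrapped:
# unwrap_tensor_subclass(model)

model = torch.compile(model, mode="max-autotune")

with torch.no_grad():
    for _ in range(5):  # warmup / trigger compilation
        model(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    model(x)
    end.record()
    torch.cuda.synchronize()
    print("elapsed_time:", start.elapsed_time(end), "milliseconds")
```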
Hmm, this might be a good example of "subclass runtime overhead", given that you see the slowdown go away on 2.4.0 when using unwrap_tensor_subclass. But it would be nice to have a profiler trace that actually shows most of the time is spent in Python overhead and not, e.g., compile generating a slower artifact. @jerryzh168 any chance you can get a profile output? Also cc @IvanKobzarev
Yeah, will try to repro and profile it.
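Something along these lines should do it (a minimal sketch, reusing model and x from the repro above; the trace filename is arbitrary):

```python
# Collect a profile of the compiled, quantized model so we can see whether the
# time is spent in Python/subclass overhead or inside the generated kernels.
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        for _ in range(10):
            model(x)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
prof.export_chrome_trace("vit_b_quant_trace.json")  # open in chrome://tracing or Perfetto
```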
we'll be able to cherry-pick the change until 9/30
@bdhirsh I thought compile would trace through the subclass; are you saying there's still a bunch of overhead for subclasses even after compile?
Found the problem. The main regression is because dynamo fails to compile with fullgraph=True; as a result it compiles the model partially, with a graph break on every MultiHeadAttention call, and that causes the bad perf.
The compilation fails because the compiled path picks the MultiHeadAttention "fast path".
https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/activation.py#L1286 - there is a manual check that avoids the fast path (native_multi_head_attention) if one of the arguments has torch_function handling.
But this check fails for subclasses during compilation, so compilation tries to compile the fast path via aten.native_multi_head_attention and hits an NYI for subclasses.
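For reference, the eager-mode guard is roughly of this shape (a paraphrased sketch, not the exact source):

```python
# Paraphrase of the guard in nn.MultiheadAttention.forward (see the link above):
# in eager mode the fast path is skipped whenever a tensor argument is handled
# via __torch_function__, which is the case for torchao's quantized subclass
# weights. Under dynamo this check is not honored, so the fast path
# (aten.native_multi_head_attention) gets traced and hits the NYI.
import torch

def should_use_fastpath(tensor_args) -> bool:
    if torch.overrides.has_torch_function(tensor_args):
        # e.g. a torchao quantized weight subclass lands here in eager mode
        return False
    # ... plus further checks on dtype, training mode, masks, etc. ...
    return True
```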
If the non-fast-path is taken during compilation, the benchmark shows 1.21 ms again for me.
So there is no significant runtime overhead for subclasses, just a compilation issue with MultiHeadAttention when a subclass is used as a parameter.
Now thinking about how to fix this so that ao subclasses take only the non-fast-path for MultiHeadAttention during compilation.
@IvanKobzarev - You should be able to use https://pytorch.org/docs/main/backends.html#torch.backends.mha.set_fastpath_enabled to disable the fast path.
@cpuhrsch Thanks, this helps.
@jerryzh168 , I've verified that adding torch.backends.mha.set_fastpath_enabled(False) at the top of run_vit_b_quant.py gets the performance back without unwrap_tensor_subclass:
elapsed_time: 1.216195556640625 milliseconds
I'll leave it to you to decide where to put torch.backends.mha.set_fastpath_enabled(False) in AO so that AO-quantized models will not take the mha fast path.
I'd add this setting into the run_vit_b_quant.py example script. We might also want to consider adding a warning to PyTorch when the fast path is enabled and subclasses are used (i.e. one of the arguments has a torch_function).
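For example, a sketch of where the call could go in the example script:

```python
# run_vit_b_quant.py (sketch): disable the MultiheadAttention fast path up front
# so that torch.compile never tries to trace aten.native_multi_head_attention
# with quantized subclass weights.
import torch

torch.backends.mha.set_fastpath_enabled(False)

# ... then quantize_(model, ...), torch.compile(model, ...), and benchmark as usual ...
```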