
`MULTIHEAD_ATTENTION_OUTPUT` ignored patterns don't match "proper" SDPA / Attention

Open ruro opened this issue 2 months ago • 2 comments

🐛 Describe the bug

Currently, the MULTIHEAD_ATTENTION_OUTPUT ignore patterns for onnx and torch only work for "decomposed" versions of attention by matching against MATMUL and SOFTMAX nodes in particular arrangements.

This means that torch models using the fused torch.nn.functional.scaled_dot_product_attention operator and onnx models using the Attention node (opsets 23+) don't get matched.

The edge between MATMUL and SOFTMAX doesn't need to be matched, since it is already "hidden inside" the SDPA / Attention nodes. However, the other MATMUL input should correspond to the V (third) input of SDPA / Attention.
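
For reference, here is a minimal plain-Python sketch of what the "decomposed" form computes, assuming single-head attention on nested lists with the usual 1/sqrt(d) scaling. The helper names (`matmul`, `softmax_rows`, `decomposed_attention`) are made up for illustration and are not NNCF, torch or onnx APIs. It only shows where the two MATMUL nodes and the SOFTMAX node sit, and that the second MATMUL takes V as its other input.

```python
import math


def matmul(a, b):
    # Plain matrix product of two nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(m):
    # Numerically stable softmax applied to each row.
    out = []
    for row in m:
        mx = max(row)
        exps = [math.exp(v - mx) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def decomposed_attention(q, k, v):
    # Hypothetical helper: the MATMUL -> SOFTMAX -> MATMUL chain that a fused
    # SDPA / Attention node computes internally.
    scale = 1.0 / math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]
    scores = [[s * scale for s in row] for row in matmul(q, k_t)]  # first MATMUL (Q, K)
    weights = softmax_rows(scores)                                 # SOFTMAX
    return matmul(weights, v)                                      # second MATMUL, other input is V


q = [[0.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[2.0, 4.0], [6.0, 8.0]]
result = decomposed_attention(q, k, v)
```

With an all-zero query the softmax weights are uniform, so `result` is the row-wise average of `v`. In the fused case only the three inputs (Q, K, V) are visible in the graph, and the SOFTMAX output edge no longer exists as a separate edge.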

I am willing to look into contributing a fix for this, but I am not 100% sure if I can fully figure this out on my own.

Environment

nncf==2.18.0
torch==2.8.0
about-time==4.2.1
alive-progress==3.3.0
anyio==4.11.0
attrs==25.4.0
autograd==1.8.0
certifi==2025.11.12
click==8.3.0
cma==4.4.0
coloredlogs==15.0.1
contourpy==1.3.3
cycler==0.12.1
Deprecated==1.3.1
dill==0.4.0
filelock==3.20.0
flatbuffers==25.9.23
fonttools==4.60.1
fsspec==2025.10.0
graphemeu==0.7.2
h11==0.16.0
hf-xet==1.2.0
httpcore==1.0.9
httpx==0.28.1
huggingface_hub==1.1.4
humanfriendly==10.0
idna==3.11
Jinja2==3.1.6
joblib==1.5.2
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
kiwisolver==1.4.9
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.7
mdurl==0.1.2
ml_dtypes==0.5.3
mpmath==1.3.0
natsort==8.4.0
networkx==3.4.2
ninja==1.13.0
nncf==2.18.0
numpy==2.2.6
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.3.20
nvidia-nvtx-cu12==12.8.90
onnx==1.19.1
onnx-ir==0.1.12
onnxruntime==1.23.2
onnxscript @ git+https://github.com/ruro/onnxscript.git@ae22c2ff1f9816b3559f65b7019cd9f9ad4203ce
openvino-telemetry==2025.2.0
packaging==25.0
pandas==2.3.3
pillow==12.0.0
protobuf==6.33.1
psutil==7.1.3
pydot==3.0.4
Pygments==2.19.2
pymoo==0.6.1.5
pyparsing==3.2.5
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.3
referencing==0.37.0
rich==14.2.0
rpds-py==0.29.0
safetensors==0.7.0
scikit-learn==1.7.2
scipy==1.16.3
setuptools==80.9.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
sympy==1.14.0
tabulate==0.9.0
threadpoolctl==3.6.0
timm==1.0.22
torch==2.8.0
torchvision==0.23.0
tqdm==4.67.1
triton==3.4.0
typer-slim==0.20.0
typing_extensions==4.15.0
tzdata==2025.2
wrapt==2.0.1
Additionally:
OS                  NixOS 25.11
Python              3.13.5
Install             PyPI
RAM                 32.00 GB
CPU                 12th Gen Intel(R) Core(TM) i9-12900HK
CUDA                12.8

Minimal Reproducible Example

import torch
import nncf
import timm

sdpa = timm.layers.attention.Attention(1, 1)
input_sample = {
    "x": torch.zeros(1, 1, 1),
}

sdpa = nncf.quantize(
    sdpa,
    calibration_dataset=nncf.Dataset([input_sample]),
    model_type=nncf.ModelType.TRANSFORMER,
    preset=nncf.QuantizationPreset.PERFORMANCE,
    target_device=nncf.TargetDevice.CPU,
)

Are you going to submit a PR?

  • [x] Yes I'd like to help by submitting a PR!

ruro avatar Nov 21 '25 07:11 ruro

After spending some time trying to figure this out, I think that SDPA actually has its own custom handling logic in _get_scope_overrides in nncf/quantization/algorithms/min_max/algorithm.py, so a separate MULTIHEAD_ATTENTION_OUTPUT pattern is not required.

For onnx, I think that the problem stems from the fact that the onnx backend just doesn't know about the Attention node, so it can be relatively easily fixed by adding

@ONNX_OPERATION_METATYPES.register()
class ONNXAttentionMetatype(ONNXOpMetatype):
    name = "AttentionOp"
    op_names = ["Attention"]
    hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
    target_input_ports = [0, 1]

to nncf/onnx/graph/metatypes/onnx_metatypes.py and changing scaled_dot_product_attention_metatypes in nncf/quantization/algorithms/min_max/onnx_backend.py to be

    @property
    def scaled_dot_product_attention_metatypes(self) -> list[OperatorMetatype]:
        return [om.ONNXAttentionMetatype]


For torch, I think that the issue is actually slightly more complex. In my particular case (Attention module from timm), all three (q, k and v) inputs to scaled_dot_product_attention are taken from a single unbind operator.

This means that when convert_to_nncf_graph in nncf/torch/function_hook/nncf_graph/nncf_graph_builder.py tries to create a PTNNCFGraph, the three parallel edges from unbind to scaled_dot_product_attention are gathered together into a single edge with a parallel_input_port_ids attribute. My understanding is that this is done because NNCFGraph._graph is a nx.DiGraph rather than a nx.MultiDiGraph.

As a result, all three edges end up being represented by a single "real" edge and get assigned a single ActivationQuantizationInsertionPoint. The target_input_ports logic then keeps the quantization config for this edge, because it has input_port_id=0 and parallel_input_port_ids=[1, 2] (the input_port_id matches and the parallel_input_port_ids are ignored).
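
To make the problem concrete, here is a rough plain-Python analogy. It is not NNCF code: the dict-based graph and the `add_edge` / `naive_keep` helpers are hypothetical stand-ins, and only the attribute names `input_port_id` and `parallel_input_port_ids` and the port numbers come from the description above. It mimics a simple directed graph that stores one attribute dict per (source, destination) pair.

```python
# Hypothetical stand-in for a simple directed graph: at most one edge record
# per (src, dst) pair, as opposed to a multigraph.
def add_edge(graph, src, dst, input_port_id):
    key = (src, dst)
    if key not in graph:
        graph[key] = {"input_port_id": input_port_id, "parallel_input_port_ids": []}
    else:
        # A repeated (src, dst) pair cannot become a second edge, so the extra
        # port is folded into the existing edge's attributes.
        graph[key]["parallel_input_port_ids"].append(input_port_id)


graph = {}
for port in (0, 1, 2):  # q, k and v all come from the same unbind node
    add_edge(graph, "unbind", "scaled_dot_product_attention", port)

target_input_ports = [0, 1]  # only q and k should be quantized


def naive_keep(edge):
    # Mirrors the behaviour described above: only input_port_id is checked,
    # parallel_input_port_ids are ignored.
    return edge["input_port_id"] in target_input_ports


edge = graph[("unbind", "scaled_dot_product_attention")]
kept = naive_keep(edge)
```

Here `graph` ends up with a single edge whose `input_port_id` is 0 and whose `parallel_input_port_ids` are `[1, 2]`, and `kept` is true. A single decision is therefore made for all three ports, including port 2 (v), which should have been excluded.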

I am honestly not sure how this could be fixed. It seems to me that it's fundamentally impossible to represent the structure that we want here. We would like to quantize 2 of the 3 parallel edges, but they are represented as a single edge in the PTNNCFGraph.

Is there a reason why NNCFGraph, and by extension PTNNCFGraph, uses nx.DiGraph instead of nx.MultiDiGraph for its internal ._graph?

ruro avatar Nov 21 '25 07:11 ruro

I think I figured out how to fix the torch issue too. While the PTNNCFGraph graph only has a single edge, the InsertionPointGraph and QuantizerPropagationStateGraph graphs that are generated from it can have multiple "parallel" edges, because the vertices in those graphs are determined by pairs (node_name, input_port_id) which would be different for each parallel edge.

The problem was seemingly caused by parallel_input_port_ids not being considered during the default PreHookInsertionPoint generation. I think I was able to fix that.
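
The idea can be sketched roughly as follows. This is a simplified illustration under my own assumptions, not the actual code from the PR: `pre_hook_ports` and the tuple-based insertion points are hypothetical, and only the `(node_name, input_port_id)` keying and the attribute names are taken from the discussion above.

```python
def pre_hook_ports(edge):
    # Hypothetical helper: expand one collapsed edge into every input port it
    # stands for, instead of looking at input_port_id alone.
    return [edge["input_port_id"], *edge["parallel_input_port_ids"]]


# The single collapsed edge between unbind and scaled_dot_product_attention.
edge = {"input_port_id": 0, "parallel_input_port_ids": [1, 2]}

# One pre-hook insertion point per (node_name, input_port_id) pair.
points = [("scaled_dot_product_attention", port) for port in pre_hook_ports(edge)]

target_input_ports = [0, 1]  # q and k only

# Each port now gets its own decision, so v (port 2) can be filtered out.
quantized = [point for point in points if point[1] in target_input_ports]
```

With the ports expanded, `points` contains three separate insertion points and `quantized` keeps only the ones for ports 0 and 1, which is the "2 of the 3 parallel edges" behaviour that a single collapsed edge could not express.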

Assuming I didn't miss anything silly, the PR should be ready to review / merge.

ruro avatar Nov 21 '25 09:11 ruro