Models with loops in their graph can't be converted to DeepSparse after QAT

Open clementpoiret opened this issue 1 year ago • 3 comments

Describe the bug

I train a ViT which has an intermediary output, which is then sent back into the network to modulate the activations as in a feedback loop. Unfortunately, the convert_qat operation is broken for MatMulAddToMatMulIntegerAddCastMul, because it wants to delete nodes that do not exist, and I don't understand why.

Expected behavior

Save the quantized model :)

Environment

Include all relevant environment information:

  1. OS [e.g. Ubuntu 18.04]: Linux
  2. Python version [e.g. 3.7]: 3.10
  3. SparseML version or commit hash [e.g. 0.1.0, f7245c8]: 1.6.1
  4. ML framework version(s) [e.g. torch 1.7.1]: PyTorch 2.1.0
  5. Other Python package versions [e.g. SparseZoo, DeepSparse, numpy, ONNX]:
  6. Other relevant environment information [e.g. hardware, CUDA version]:

To Reproduce

Exact steps to reproduce the behavior:

I also tried using a ModuleExporter with convert_qat; the behavior is the same. Initial model save (output attached):

        exporter = TorchToONNX(
            sample_batch=(dummy_image, dummy_context),
            input_names=["image", "context"],
            output_names=["bodysizes", "predicted_context"],
            opset=17,
        )
        exporter.export(model,
                        file_path="/workspace/models/onnx/model_opt.onnx")

Then the QAT conversion:

from sparseml.exporters.onnx_to_deepsparse import ONNXToDeepsparse

model_path = "/workspace/models/onnx/model_opt.onnx"

exporter = ONNXToDeepsparse()

exporter.export(model_path, "/workspace/models/onnx/model_opt-ds.onnx")

Errors

>>> exporter.export(model_path, "/workspace/models/onnx/model_opt-ds.onnx")
2024-01-26 10:23:17 sparseml.exporters.transforms.onnx_transform INFO     [ConstantsToInitializers] Transformed 523 matches
[10:23:17] INFO     [ConstantsToInitializers] Transformed 523 matches                                                          onnx_transform.py:97
2024-01-26 10:23:18 sparseml.exporters.transforms.onnx_transform INFO     [FoldIdentityInitializers] Transformed 0 matches
[10:23:18] INFO     [FoldIdentityInitializers] Transformed 0 matches                                                           onnx_transform.py:97
2024-01-26 10:23:19 sparseml.exporters.transforms.onnx_transform INFO     [InitializersToUint8] Transformed 92 matches
[10:23:19] INFO     [InitializersToUint8] Transformed 92 matches                                                               onnx_transform.py:97
2024-01-26 10:23:20 sparseml.exporters.transforms.onnx_transform INFO     [FlattenQParams] Transformed 0 matches
[10:23:20] INFO     [FlattenQParams] Transformed 0 matches                                                                     onnx_transform.py:97
2024-01-26 10:23:21 sparseml.exporters.transforms.onnx_transform INFO     [FoldConvDivBn] Transformed 0 matches
[10:23:21] INFO     [FoldConvDivBn] Transformed 0 matches                                                                      onnx_transform.py:97
2024-01-26 10:23:21 sparseml.exporters.transforms.onnx_transform INFO     [DeleteRepeatedQdq] Transformed 215 matches
           INFO     [DeleteRepeatedQdq] Transformed 215 matches                                                                onnx_transform.py:97
2024-01-26 10:23:22 sparseml.exporters.transforms.onnx_transform INFO     [QuantizeQATEmbedding] Transformed 0 matches
[10:23:22] INFO     [QuantizeQATEmbedding] Transformed 0 matches                                                               onnx_transform.py:97
2024-01-26 10:23:23 sparseml.exporters.transforms.onnx_transform INFO     [PropagateEmbeddingQuantization] Transformed 0 matches
[10:23:23] INFO     [PropagateEmbeddingQuantization] Transformed 0 matches                                                     onnx_transform.py:97
2024-01-26 10:23:23 sparseml.exporters.transforms.onnx_transform INFO     [PropagateDequantThroughSplit] Transformed 0 matches
           INFO     [PropagateDequantThroughSplit] Transformed 0 matches                                                       onnx_transform.py:97
2024-01-26 10:23:25 sparseml.exporters.transforms.onnx_transform INFO     [MatMulAddToMatMulIntegerAddCastMul] Transformed 95 matches
[10:23:25] INFO     [MatMulAddToMatMulIntegerAddCastMul] Transformed 95 matches                                                onnx_transform.py:97
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/clementpoiret/micromamba/envs/torch/lib/python3.10/site-packages/sparseml/exporters/onnx_to_deepsparse.py", line 128, in export
    post_transforms_model: onnx.ModelProto = self.apply(pre_transforms_model)
  File "/home/clementpoiret/micromamba/envs/torch/lib/python3.10/site-packages/sparseml/exporters/transforms/base_transform.py", line 43, in apply
    model = self.transform(model)
  File "/home/clementpoiret/micromamba/envs/torch/lib/python3.10/site-packages/sparseml/exporters/base_exporter.py", line 28, in transform
    model = transform.apply(model)
  File "/home/clementpoiret/micromamba/envs/torch/lib/python3.10/site-packages/sparseml/exporters/transforms/base_transform.py", line 44, in apply
    model = self.post_validate(model)
  File "/home/clementpoiret/micromamba/envs/torch/lib/python3.10/site-packages/sparseml/exporters/transforms/onnx_transform.py", line 102, in post_validate
    model.graph.node.remove(node)
ValueError: Item to delete not in list
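
One way a "not in list" error can arise during deferred deletion is when the same node ends up queued for removal more than once. Whether that is what happens here is an assumption, not something confirmed from the SparseML source. The sketch below uses a plain Python list of strings as a stand-in for `model.graph.node`; the node names and the `find_repeated` helper are hypothetical. A plain list raises `ValueError` with a slightly different message than the protobuf container in the traceback above.

```python
from collections import Counter


def find_repeated(to_delete):
    """Return the entries that are queued for deletion more than once."""
    counts = Counter(to_delete)
    return [item for item, n in counts.items() if n > 1]


# Hypothetical node names standing in for the nodes of an ONNX graph.
graph_nodes = ["quantize_0", "dequantize_0", "matmul_0", "add_0"]

# Assumed scenario: two overlapping pattern matches both queue "dequantize_0".
queued = ["dequantize_0", "matmul_0", "dequantize_0"]

print("queued more than once:", find_repeated(queued))

try:
    for node in queued:
        # The second removal of "dequantize_0" fails: it is already gone.
        graph_nodes.remove(node)
except ValueError as err:
    print("removal failed:", err)
```

With real ONNX nodes the same check would have to run on node names, since protobuf messages are not hashable.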

Additional context

Add any other context about the problem here. Also include any relevant files.

The model I want to convert using ONNXToDeepsparse: https://ufile.io/cz95c1oo

clementpoiret avatar Jan 26 '24 09:01 clementpoiret

Hi @clementpoiret, thanks for opening this. Are you able to set any breakpoints to see which node the transform is trying to delete? Loops in graphs are an edge case that was not covered in development or testing, so it is possible to run into an error like this if the function's local representation of the graph goes stale.
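
For anyone who would rather not step through a debugger, a minimal sketch of one alternative: catch the exception and read the local variables of the innermost Python frame from its traceback. The helper name `innermost_locals` and the stand-in function `remove_all` are hypothetical; the sketch only assumes that the failing frame has a local called `node`, as the line `model.graph.node.remove(node)` in the traceback above suggests.

```python
import traceback


def innermost_locals(exc: BaseException) -> dict:
    """Return a copy of the locals of the innermost Python frame in exc's traceback."""
    frames = [frame for frame, _lineno in traceback.walk_tb(exc.__traceback__)]
    return dict(frames[-1].f_locals) if frames else {}


def remove_all(nodes, to_delete):
    """Stand-in for the failing loop: remove every queued node from the list."""
    for node in to_delete:
        nodes.remove(node)


try:
    # "Add_2" is queued twice, so the second removal raises ValueError.
    remove_all(["MatMul_1", "Add_2"], ["Add_2", "Add_2"])
except ValueError as err:
    failing = innermost_locals(err)
    print("node being deleted:", failing["node"])
    print("nodes still in the list:", failing["nodes"])
```

Applied to the export call from the report, the same `try`/`except` would wrap `exporter.export(...)`, and `failing["node"]` should then hold the ONNX node whose removal failed, provided the innermost Python frame is the one shown last in the traceback.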

bfineran avatar Feb 15 '24 19:02 bfineran

Hi @clementpoiret, per @bfineran's last inquiry, let us know if you have any further insights. Thank you!

Jeannie / Neural Magic

jeanniefinks avatar Mar 07 '24 19:03 jeanniefinks

Hi, sorry for the delay! I finally ended up changing the architecture to simplify it and remove its cyclic part. Looking directly into the nodes gave me too many headaches :cry:

clementpoiret avatar Mar 08 '24 07:03 clementpoiret

Hello @clementpoiret Here's to architecture simplification! A few weeks have gone by; I am going to go ahead and close this issue. Feel free to re-open if you want to continue the conversation! Best wishes to you ☕ , Jeannie / Neural Magic

jeanniefinks avatar Mar 22 '24 18:03 jeanniefinks