fuse_matmul_add_bias_into_gemm not working with batch size

Open erelon opened this issue 2 years ago • 2 comments

Hi,

When using fuse_matmul_add_bias_into_gemm, I expect the MatMul and Add layers to be fused even when the inputs have a batch dimension. Apparently this is not supported, and I can't see the reason for it. Even if fusing is a problem for a batch size larger than one, it could at least happen when the batch dimension is 1.

Here is example code that reproduces the issue (heavily based on #58):

    import numpy as np
    import onnx
    from onnx import helper, TensorProto
    import onnxoptimizer

    # Batched MatMul (batch dim = 1) followed by a bias Add
    matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    add = helper.make_node("Add", ["Z", "B"], ["A"])
    graph = helper.make_graph(
        [matmul, add],
        "test",
        # Graph input: X with shape (1, 32, 10)
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 32, 10))],
        # Graph output: A with shape (1, 32, 16)
        [helper.make_tensor_value_info("A", TensorProto.FLOAT, (1, 32, 16))],
        # Initializers: bias B with shape (16,) and weight Y with shape (1, 10, 16)
        [helper.make_tensor("B", TensorProto.FLOAT, (16,), np.ones(16).tolist()),
         helper.make_tensor("Y", TensorProto.FLOAT, (1, 10, 16),
                            np.ones([1, 10, 16]).flatten().tolist())],
    )

    model = helper.make_model(graph)
    onnx.save(model, "gg.onnx")
    optimized_model = onnxoptimizer.optimize(
        model, passes=["fuse_matmul_add_bias_into_gemm"])
    onnx.save(optimized_model, "gg1.onnx")

    print(optimized_model)

    # Expected: MatMul + Add fused into a single Gemm node (these assertions fail)
    assert len(list(optimized_model.graph.node)) == 1
    assert optimized_model.graph.node[0].op_type == "Gemm"
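A minimal NumPy sketch of why the request is numerically reasonable, using the shapes from the repro above. When the batch dimension of both MatMul operands is 1, MatMul + Add computes exactly what a rank-2 Gemm computes after that dimension is dropped; with a larger batch and a shared 2-D weight, the batch can be folded into the row dimension instead. This only checks numerics and does not involve onnxoptimizer; `gemm_2d` is a hypothetical stand-in for Gemm with transA = transB = 0.

```python
import numpy as np

# Shapes follow the repro: X is (1, 32, 10), Y is (1, 10, 16), B is (16,)
rng = np.random.default_rng(0)
X = rng.standard_normal((1, 32, 10)).astype(np.float32)
Y = rng.standard_normal((1, 10, 16)).astype(np.float32)
B = rng.standard_normal(16).astype(np.float32)


def gemm_2d(a, b, c, alpha=1.0, beta=1.0):
    """Hypothetical stand-in for a rank-2 Gemm with transA = transB = 0."""
    return alpha * (a @ b) + beta * c


# Original subgraph: batched MatMul followed by a broadcast bias Add
batched = np.matmul(X, Y) + B  # shape (1, 32, 16)

# Batch dim is 1: drop it, run a rank-2 Gemm, then restore the batch dim
flat = gemm_2d(X.reshape(32, 10), Y.reshape(10, 16), B).reshape(1, 32, 16)

# Batch of 4 with a shared 2-D weight W: fold the batch into the rows
Xn = rng.standard_normal((4, 32, 10)).astype(np.float32)
W = Y.reshape(10, 16)
batched_n = np.matmul(Xn, W) + B  # shape (4, 32, 16)
flat_n = gemm_2d(Xn.reshape(-1, 10), W, B).reshape(4, 32, 16)

print(np.allclose(batched, flat, atol=1e-5), np.allclose(batched_n, flat_n, atol=1e-5))
```

If Y itself had a batch dimension larger than 1 (a different weight per batch entry), the MatMul would be a genuinely batched product that a single rank-2 Gemm over these tensors cannot reproduce, which may be part of why the pass restricts itself to rank-2 inputs.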

erelon, May 17 '22 13:05

@erelon The MatMul inputs must be rank-2 tensors for fuse_matmul_add_bias_into_gemm to apply; you can see this check in the source: https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h#L60
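A small Python sketch of the rank condition described in this comment, applied to the shapes from the repro. `passes_rank_check` and `drop_leading_ones` are hypothetical helpers written for illustration, not part of onnxoptimizer; the real pass is implemented in C++ and may impose further conditions (for example on the bias shape) that are not modelled here.

```python
def passes_rank_check(a_shape, b_shape):
    """Hypothetical mirror of the rank-2 requirement on the MatMul inputs.

    Only the rank condition is modelled; the real C++ pass may check more.
    """
    return len(a_shape) == 2 and len(b_shape) == 2


def drop_leading_ones(shape):
    """Strip leading size-1 dims while rank > 2, e.g. (1, 32, 10) -> (32, 10)."""
    shape = tuple(shape)
    while len(shape) > 2 and shape[0] == 1:
        shape = shape[1:]
    return shape


x_shape, y_shape = (1, 32, 10), (1, 10, 16)  # shapes from the repro
print(passes_rank_check(x_shape, y_shape))  # False: both inputs are rank 3
print(passes_rank_check(drop_leading_ones(x_shape), drop_leading_ones(y_shape)))  # True
print(drop_leading_ones((4, 32, 10)))  # (4, 32, 10): a real batch dim is kept
```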

HSQ79815, Jul 05 '22 11:07

In the Gemm definition, A must have shape (M, K), or (K, M) when transA is set, and B must have shape (K, N), or (N, K) when transB is set. Both inputs are therefore rank 2.
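For reference, a NumPy sketch of the Gemm computation as the ONNX operator specification describes it, Y = alpha * A' * B' + beta * C, where A' and B' are A and B optionally transposed according to `transA` and `transB`, and C is unidirectionally broadcast to (M, N). The name `gemm_reference` is hypothetical and this is not an official implementation; it only makes the shape rules concrete.

```python
import numpy as np


def gemm_reference(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
    """Sketch of ONNX Gemm semantics: Y = alpha * A' @ B' + beta * C."""
    A, B = np.asarray(A), np.asarray(B)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Gemm only accepts rank-2 A and B")
    A_ = A.T if transA else A  # A' has shape (M, K)
    B_ = B.T if transB else B  # B' has shape (K, N)
    Y = alpha * (A_ @ B_)      # shape (M, N)
    if C is not None:
        # C must broadcast *to* (M, N): unidirectional broadcasting
        Y = Y + beta * np.broadcast_to(np.asarray(C), Y.shape)
    return Y


A = np.ones((3, 4))  # (M, K) = (3, 4)
B = np.ones((5, 4))  # stored as (N, K), so transB=1
Y = gemm_reference(A, B, C=np.arange(5), transB=1)
print(Y.shape)  # (3, 5)
print(Y[0])     # [4. 5. 6. 7. 8.]

try:
    gemm_reference(np.ones((1, 3, 4)), np.ones((4, 5)))  # rank-3 A, as in the issue
except ValueError as err:
    print(err)  # Gemm only accepts rank-2 A and B
```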

HSQ79815, Jul 05 '22 11:07