ncnn icon indicating copy to clipboard operation
ncnn copied to clipboard

Gemm之后shape异常

Open Momenta-IPO opened this issue 2 years ago • 1 comments

我有一个onnx模型,其中有若干gemm算子,转换成ncnn推理的时候在第一个gemm算子就出现了错误,预期的输出是shape为1x16的矩阵(如下图) image 但是在ncnn推理的时候这个gemm输出的shape是4x16的,我用的pyncnn,版本是20231027,可以帮看看这个问题吗,模型发你qq邮箱

Momenta-IPO avatar Nov 27 '23 10:11 Momenta-IPO

    import onnx
    from onnx import helper
    import sys
    # 加载ONNX模型
    model = onnx.load(sys.argv[1])
    # Define lists to keep track of new nodes and nodes to remove
    nodes_to_add = []
    nodes_to_remove = []

    # Loop through the nodes in the graph to find Gemm nodes
    for node in model.graph.node:
        if node.op_type == 'Gemm':
            # Extract the node's input and output names
            A = node.input[0]
            B = node.input[1]
            C = node.input[2]
            Y = node.output[0]

            # Create a new MatMul node
            matmul_output = Y + '_matmul_output'
            matmul_node = helper.make_node(
                'MatMul',
                inputs=[A, B],
                outputs=[matmul_output]
            )

            # Create a new Add node
            add_node = helper.make_node(
                'Add',
                inputs=[matmul_output, C],
                outputs=[Y]
            )

            # Add the new nodes to the list for later insertion
            nodes_to_add.append((node, matmul_node, add_node))

            # Mark the current Gemm node for removal
            nodes_to_remove.append(node)

    # Remove the old Gemm nodes and add the new MatMul and Add nodes
    for gemm_node in nodes_to_remove:
        model.graph.node.remove(gemm_node)
        
    for _, matmul_node, add_node in nodes_to_add:
        model.graph.node.append(matmul_node)
        model.graph.node.append(add_node)
    # 保存修改后的模型
    onnx.save(model, 'modified_model.onnx')

已解决,把原始的gemm算子拆成matmul+add即可,注意 我的原始onnx模型里面alpha和beta都是1

Momenta-IPO avatar Nov 27 '23 12:11 Momenta-IPO

针对onnx模型转换的各种问题,推荐使用最新的pnnx工具转换到ncnn In view of various problems in onnx model conversion, it is recommended to use the latest pnnx tool to convert your model to ncnn

pip install pnnx
pnnx model.onnx inputshape=[1,3,224,224]

详细参考文档 Detailed reference documentation https://github.com/pnnx/pnnx https://github.com/Tencent/ncnn/wiki/use-ncnn-with-pytorch-or-onnx#how-to-use-pnnx

nihui avatar Aug 05 '24 09:08 nihui