
[NVIDIA] Need to support custom op mapping through python functions

Open Tom-Zheng opened this issue 1 year ago • 0 comments

Please fill in the information below so that we can solve the problem quickly. Thanks!

Describe the bug

After training an FP8 model with Transformer Engine, we need to convert the Paddle model to ONNX so that it can be deployed with TensorRT (TRT). Transformer Engine implements its fast FP8 GEMM kernels as Paddle custom ops, so we need a Python API that maps custom ops to ONNX nodes.

For example, there is a custom op cast_to_fp8, which has the following inputs: input, amax, scale_inv, idx, otype.

We would like to define a mapper function that creates the ONNX nodes and sets their attributes and inputs as needed:

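# Illustrative only, written in the style of PyTorch symbolic functions.
# One parse_args descriptor per op input: "v" = graph value, "fs" = list of floats, "i" = int.
# make_op_name is assumed to be a helper defined elsewhere.
@symbolic_helper.parse_args("v", "v", "fs", "i", "i")
def onnx_cast_to_fp8(g, input, amax, scale_inv, idx, otype):
    # idx selects the scaling factor that belongs to this tensor.
    scale = g.op("Constant", value_t=torch.tensor(scale_inv[idx]))
    output_shape = symbolic_helper._get_tensor_sizes(input)
    q_op = g.op(
        make_op_name("TRT_FP8QuantizeLinear"), input, scale).setType(
            input.type().with_dtype(torch.uint8).with_sizes(output_shape))
    return q_op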

The mapper would then be registered:

paddle2onnx.register_custom_op_symbolic('cast_to_fp8', onnx_cast_to_fp8)

This request closely mirrors PyTorch's register_custom_op_symbolic (reference link).
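To make the request concrete, here is a minimal, self-contained sketch of how such a registry could dispatch to user mappers during conversion. Everything in it is hypothetical: `Graph`, `convert_op`, and the module-level registry are toy stand-ins, not Paddle2ONNX internals, and the split of `cast_to_fp8` arguments into tensor inputs (`input`, `amax`) and attributes (`scale_inv`, `idx`, `otype`) is an assumption made for illustration.

```python
from typing import Callable, Dict, List

# Hypothetical registry: Paddle op type -> user-supplied mapper function.
_CUSTOM_OP_SYMBOLICS: Dict[str, Callable] = {}


def register_custom_op_symbolic(op_type: str, symbolic_fn: Callable) -> None:
    """Record the mapper to call when the converter meets `op_type`."""
    if not callable(symbolic_fn):
        raise TypeError("symbolic_fn must be callable")
    _CUSTOM_OP_SYMBOLICS[op_type] = symbolic_fn


class Graph:
    """Toy stand-in for the ONNX graph builder passed to mappers as `g`."""

    def __init__(self) -> None:
        self.nodes: List[dict] = []

    def op(self, op_type: str, *inputs: str, **attrs) -> str:
        # Append a node record and return the name of its single output.
        output = f"{op_type}_{len(self.nodes)}"
        self.nodes.append(
            {"op_type": op_type, "inputs": list(inputs),
             "attrs": attrs, "output": output})
        return output


def convert_op(g: Graph, op_type: str, inputs: List[str], attrs: dict) -> str:
    """Dispatch one Paddle op to its registered mapper, or fail loudly."""
    mapper = _CUSTOM_OP_SYMBOLICS.get(op_type)
    if mapper is None:
        raise NotImplementedError(f"no ONNX mapping for op '{op_type}'")
    return mapper(g, *inputs, **attrs)


def onnx_cast_to_fp8(g, input, amax, scale_inv, idx, otype):
    # Assumed semantics: idx picks this tensor's scaling factor.
    scale = g.op("Constant", value=scale_inv[idx])
    return g.op("TRT_FP8QuantizeLinear", input, scale)


register_custom_op_symbolic("cast_to_fp8", onnx_cast_to_fp8)

g = Graph()
out = convert_op(g, "cast_to_fp8", ["x", "amax"],
                 {"scale_inv": [0.5, 0.25], "idx": 1, "otype": 0})
```

With this shape, the converter only needs one lookup per op, and users can add or override mappings without touching the converter itself.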

Information (please complete the following information):

  • Inference engine for deployment:
  • Why convert to onnx:
  • Paddle2ONNX Version:
  • Email/Wechat/Phone:

Screenshots

Additional context

Tom-Zheng, Aug 24 '23 04:08