Paddle2ONNX
[NVIDIA] Need to support custom op mapping through python functions
Describe the bug
After training an FP8 model with Transformer Engine, deploying it with TensorRT requires converting the Paddle model to an ONNX model. Transformer Engine implements its fast FP8 GEMM kernels as Paddle custom ops, so we need a Python API to map those custom ops to ONNX nodes.
For example, there is a custom op `cast_to_fp8` with the following inputs: `input`, `amax`, `scale_inv`, `idx`, `otype`.
We can define a mapper function that creates ONNX nodes and sets their attributes/inputs as we want:
```python
# `make_op_name`, `fp8_tensor`, and `output_shape` are assumed to be defined
# in the surrounding Transformer Engine code. parse_args takes one type
# specifier per argument after `g`.
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
def onnx_cast_to_fp8(g, input, amax, scale_inv, idx, otype):
    scale = g.op("Constant", value_t=scale_inv[fp8_tensor])
    q_op = g.op(
        make_op_name("TRT_FP8QuantizeLinear"), input, scale).setType(
            input.type().with_dtype(torch.uint8).with_sizes(output_shape))
    return q_op
```
Followed by registration:

```python
paddle2onnx.register_custom_op_symbolic('cast_to_fp8', onnx_cast_to_fp8)
```
The request is very similar to PyTorch's `register_custom_op_symbolic` (reference link).
Information (please complete the following):
- Inference engine for deployment:
- Why convert to onnx:
- Paddle2ONNX Version:
- Email/Wechat/Phone: