FunASR
FunASR copied to clipboard
为什么导出 paraformer streaming 模型时只能导出 encoder 和 decoder 部分,不能导出 predictor 部分?
def export_rebuild_model(model, **kwargs):
    """Wrap the sub-modules of ``model`` for export and return two export views.

    The encoder, predictor and decoder of ``model`` are replaced in place by
    the classes registered in ``tables`` under ``<name> + "Export"``. Two
    shallow copies of the modified model are then created: one gets the
    ``export_encoder_*`` functions bound as methods, the other the
    ``export_decoder_*`` functions.

    Args:
        model: Model exposing ``encoder``, ``predictor`` and ``decoder``
            attributes. It is modified in place.
        **kwargs: Must contain ``"encoder"``, ``"predictor"`` and
            ``"decoder"`` (registry names). Optional ``"type"`` selects the
            export backend; the wrappers receive ``onnx=True`` when it is
            ``"onnx"`` (the default).

    Returns:
        Tuple ``(encoder_model, decoder_model)`` whose ``export_name``
        attributes are ``"model"`` and ``"decoder"`` respectively.
    """
    import copy
    import types

    to_onnx = kwargs.get("type", "onnx") == "onnx"

    encoder_export_cls = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = encoder_export_cls(model.encoder, onnx=to_onnx)

    predictor_export_cls = tables.predictor_classes.get(kwargs["predictor"] + "Export")
    model.predictor = predictor_export_cls(model.predictor, onnx=to_onnx)

    # The plain SANM decoder is exported through its "Online" export wrapper.
    decoder_name = kwargs["decoder"]
    if decoder_name == "ParaformerSANMDecoder":
        decoder_name = "ParaformerSANMDecoderOnline"
    decoder_export_cls = tables.decoder_classes.get(decoder_name + "Export")
    model.decoder = decoder_export_cls(model.decoder, onnx=to_onnx)

    from funasr.utils.torch_function import sequence_mask

    model.make_pad_mask = sequence_mask(max_seq_len=None, flip=False)

    # Shallow copies: both views reference the same wrapped sub-modules.
    encoder_model = copy.copy(model)
    decoder_model = copy.copy(model)

    export_views = (
        (
            encoder_model,
            "model",
            {
                "forward": export_encoder_forward,
                "export_dummy_inputs": export_encoder_dummy_inputs,
                "export_input_names": export_encoder_input_names,
                "export_output_names": export_encoder_output_names,
                "export_dynamic_axes": export_encoder_dynamic_axes,
            },
        ),
        (
            decoder_model,
            "decoder",
            {
                "forward": export_decoder_forward,
                "export_dummy_inputs": export_decoder_dummy_inputs,
                "export_input_names": export_decoder_input_names,
                "export_output_names": export_decoder_output_names,
                "export_dynamic_axes": export_decoder_dynamic_axes,
            },
        ),
    )
    for view, export_name, hooks in export_views:
        # Bind each module-level function as a method of this particular view.
        for attr_name, func in hooks.items():
            setattr(view, attr_name, types.MethodType(func, view))
        view.export_name = export_name

    return encoder_model, decoder_model
使用官方教程的导出代码:

```python
from funasr import AutoModel

model = AutoModel(model="paraformer")
res = model.export(quantize=False)
```

可以导出两个 ONNX 文件:model.onnx 和 decoder.onnx。查看输入输出后发现,model.onnx 是 encoder 部分,decoder.onnx 是 decoder 部分。请问怎么导出完整的 paraformer streaming ONNX 模型呢?