keras-onnx
Does keras2onnx support dynamic input?
```python
from transformers import BertTokenizer, TFBertForSequenceClassification
import tensorflow as tf

tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
model = TFBertForSequenceClassification.from_pretrained('./bert-base-chinese')

# Sample Chinese input sentence (roughly: "can it guarantee whether environment differences cause runtime errors")
inputs = tokenizer("rch能保证是不是回因为环境差异带来运行错误。", return_tensors="tf")
print(inputs)

logits = model(inputs)
print(logits)

import keras2onnx

onnx_model = keras2onnx.convert_keras(model, model.name)
output_model_path = "chinese_roberta_l-12_H-768.onnx"
keras2onnx.save_model(onnx_model, output_model_path)
```
I can't find a keras2onnx interface for dynamic inputs (a variable batch size and a variable sequence length). The PyTorch-to-ONNX exporter, by contrast, lets you name the inputs and mark which axes are dynamic:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def export_onnx_model(args, model, tokenizer, onnx_model_path):
    with torch.no_grad():
        # Dummy inputs of shape (1, 128); the dynamic axes below free both dimensions
        inputs = {'input_ids': torch.ones(1, 128, dtype=torch.int64),
                  'attention_mask': torch.ones(1, 128, dtype=torch.int64),
                  'token_type_ids': torch.ones(1, 128, dtype=torch.int64)}
        outputs = model(**inputs)

        symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
        torch.onnx.export(model,                          # model being run
                          (inputs['input_ids'],           # model inputs (a tuple for multiple inputs)
                           inputs['attention_mask'],
                           inputs['token_type_ids']),
                          onnx_model_path,                # where to save the model (file or file-like object)
                          opset_version=11,               # the ONNX opset to export the model to
                          do_constant_folding=True,       # whether to run constant folding for optimization
                          input_names=['input_ids',       # the model's input names
                                       'input_mask',
                                       'segment_ids'],
                          output_names=['output'],        # the model's output names
                          dynamic_axes={'input_ids': symbolic_names,   # variable-length axes
                                        'input_mask': symbolic_names,
                                        'segment_ids': symbolic_names,
                                        'output': {0: 'batch_size'}})  # let the output batch size vary too
    logger.info("ONNX Model exported to {0}".format(onnx_model_path))
```
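The keys of `dynamic_axes` are matched against the names passed in `input_names`/`output_names`, not against the keyword arguments of the model's `forward`. In the function above the dummy inputs are keyed `attention_mask`/`token_type_ids`, while the ONNX names are `input_mask`/`segment_ids`, so the two are easy to mix up. A small sketch that derives the dictionary from the name lists keeps them consistent; `build_dynamic_axes` is a hypothetical helper, not part of PyTorch.

```python
def build_dynamic_axes(input_names, output_names, input_axes, output_axes):
    """Hypothetical helper: build a dynamic_axes dict keyed by the ONNX tensor names.

    Deriving the dict from the same name lists passed to torch.onnx.export
    keeps the dynamic_axes keys and input_names/output_names from drifting apart.
    """
    # Copy the axis mapping per name so later edits to one entry don't leak into others
    dynamic_axes = {name: dict(input_axes) for name in input_names}
    for name in output_names:
        if name in dynamic_axes:
            raise ValueError(f"name {name!r} is used for both an input and an output")
        dynamic_axes[name] = dict(output_axes)
    return dynamic_axes


input_names = ['input_ids', 'input_mask', 'segment_ids']
output_names = ['output']
dynamic_axes = build_dynamic_axes(input_names, output_names,
                                  {0: 'batch_size', 1: 'max_seq_len'},
                                  {0: 'batch_size'})
print(dynamic_axes)
```

The result can then be passed as `dynamic_axes=dynamic_axes` together with the same `input_names` and `output_names` lists in the `torch.onnx.export` call.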
I have the same issue with keras2onnx installed from PyPI.
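@jianqianzhou: If you have the model saved somewhere, you can pass the `from_tf=True` argument when loading it with the PyTorch model class (e.g. `BertForSequenceClassification.from_pretrained('./bert-base-chinese', from_tf=True)`), and the transformers library will automatically convert the TF weights to PyTorch.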