Support onnx
Can the model be converted to an ONNX model?
Tried it just now. Was able to export to ONNX using torch.onnx.export(parseq, dummy_input, 'parseq.onnx', opset_version=14). Not really familiar yet with ONNX so I can't verify if the exported model works as expected (an exported TorchScript model works though, if that matters).
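A minimal verification sketch, assuming onnxruntime is installed and reusing the parseq and dummy_input objects from the export above (note that later comments in this thread show the initial export may not load cleanly):

import numpy as np
import onnxruntime as ort
import torch

# Run the exported model and compare against the PyTorch output
# on the same dummy input (the values should roughly match).
session = ort.InferenceSession('parseq.onnx')
input_name = session.get_inputs()[0].name
onnx_logits = session.run(None, {input_name: dummy_input.numpy()})[0]

with torch.no_grad():
    torch_logits = parseq(dummy_input)

print(np.allclose(onnx_logits, torch_logits.numpy(), atol=1e-4))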
@baudm I cannot convert to ONNX; the main problem comes from the load_from_checkpoint function. The model must be loaded into the architecture first, and then it can be converted to ONNX. The function loads the model from hubconf:
def _load_torch_model(checkpoint_path, checkpoint, **kwargs):
    import hubconf
    name = os.path.basename(checkpoint_path).split('-')[0]
    model_factory = getattr(hubconf, name)
    model = model_factory(**kwargs)
    model.load_state_dict(checkpoint)
    return model

def load_from_checkpoint(checkpoint_path: str, **kwargs):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    try:
        model = _load_pl_checkpoint(checkpoint, **kwargs)
    except KeyError:
        model = _load_torch_model(checkpoint_path, checkpoint, **kwargs)
    return model
Can you share example code for loading a checkpoint into the model architecture? If I can convert it to ONNX, I will publish it so we can test whether it works as expected.
import torch
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size) # (1, 3, 32, 128) by default
# To ONNX
parseq.to_onnx('parseq.onnx', dummy_input, opset_version=14) # opset v14 or newer is required
# To TorchScript
parseq.to_torchscript('parseq-ts.pt')
@baudm the model converted successfully to ONNX, but the ONNX model cannot be loaded. I am asking PyTorch experts to help resolve it. Once done, I will share the final ONNX model.
@baudm after some days, I have tried to fix the ONNX export but could not. I would be very happy if you could share some example code for running inference with the TorchScript model you converted.
# To TorchScript
parseq.to_torchscript('parseq-ts.pt')
I get an error:
model = torch.jit.load("parseq-ts.pt")
File "anaconda3/envs/dl/lib/python3.6/site-packages/torch/jit/_serialization.py", line 161, in load
cpp_module = torch._C.import_ir_module(cu, str(f), map_location, _extra_files)
RuntimeError:
Unknown builtin op: aten::reflection_pad3d.
Here are some suggestions:
aten::reflection_pad1d
aten::reflection_pad2d
The original call is:
File "/home/tupk/anaconda3/envs/ocr/lib/python3.8/site-packages/torch/nn/functional.py", line 4199
elif len(pad) == 6 and (input.dim() == 4 or input.dim() == 5):
if mode == "reflect":
return torch._C._nn.reflection_pad3d(input, pad)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
elif mode == "replicate":
return torch._C._nn.replication_pad3d(input, pad)
Serialized File "code/__torch__/torch/nn/functional.py", line 634
if _175:
if torch.eq(mode, "reflect"):
_180 = torch.reflection_pad3d(input, pad)
~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
_179 = _180
else:
'_pad' is being compiled since it was called from 'multi_head_attention_forward'
File "/home/tupk/anaconda3/envs/ocr/lib/python3.8/site-packages/torch/nn/functional.py", line 5032
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
~~~ <--- HERE
if key_padding_mask is not None:
key_padding_mask = pad(key_padding_mask, (0, 1))
Serialized File "code/__torch__/torch/nn/functional.py", line 235
if torch.__isnot__(attn_mask0, None):
attn_mask6 = unchecked_cast(Tensor, attn_mask0)
attn_mask7 = __torch__.torch.nn.functional._pad(attn_mask6, [0, 1], "constant", 0., )
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
attn_mask5 : Optional[Tensor] = attn_mask7
else:
'multi_head_attention_forward' is being compiled since it was called from 'MultiheadAttention.forward'
Serialized File "code/__torch__/torch/nn/modules/activation.py", line 39
need_weights: bool=True,
attn_mask: Optional[Tensor]=None) -> Tuple[Tensor, Optional[Tensor]]:
_1 = __torch__.torch.nn.functional.multi_head_attention_forward
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
_2 = annotate(List[Tensor], [])
_3 = torch.append(_2, torch.transpose(query, 1, 0))
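For reference, a minimal TorchScript inference sketch, assuming the exported parseq-ts.pt and a loading environment whose PyTorch version matches (or is newer than) the one used for the export; the aten::reflection_pad3d error above typically means the loading PyTorch is older than the exporting one:

import torch

# Load and run the exported TorchScript model; requires a PyTorch version
# at least as new as the one that produced parseq-ts.pt.
model = torch.jit.load('parseq-ts.pt').eval()
dummy_input = torch.rand(1, 3, 32, 128)  # default parseq input size

with torch.no_grad():
    logits = model(dummy_input)  # [batch, sequence length, number of classes]
print(logits.shape)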
@baudm I have a similar issue with loading the converted ONNX model. I am able to successfully convert the model to ONNX, but when I try to load it and check whether the model is well-formed I get the error below.
import torch
import onnx
# Load PyTorch model
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size)
# Convert to ONNX
parseq.to_onnx('pairseq.onnx', dummy_input, opset_version=14)
# Load the ONNX model
onnx_model = onnx.load('pairseq.onnx')
# Check ONNX model
onnx.checker.check_model(onnx_model, full_check=True)
---------------------------------------------------------------------------
InferenceError Traceback (most recent call last)
Input In [1], in <cell line: 15>()
12 onnx_model = onnx.load('pairseq.onnx')
14 # Check ONNX model
---> 15 onnx.checker.check_model(onnx_model, full_check=True)
File /opt/venv/lib/python3.8/site-packages/onnx/checker.py:108, in check_model(model, full_check)
106 C.check_model(protobuf_string)
107 if full_check:
--> 108 onnx.shape_inference.infer_shapes(model, check_type=True, strict_mode=True)
File /opt/venv/lib/python3.8/site-packages/onnx/shape_inference.py:34, in infer_shapes(model, check_type, strict_mode, data_prop)
32 if isinstance(model, (ModelProto, bytes)):
33 model_str = model if isinstance(model, bytes) else model.SerializeToString()
---> 34 inferred_model_str = C.infer_shapes(model_str, check_type, strict_mode, data_prop)
35 return onnx.load_from_string(inferred_model_str)
36 elif isinstance(model, str):
InferenceError: [ShapeInferenceError] (op_type:CumSum, node name: CumSum_2527): x typestr: T, has unsupported type: tensor(bool)
Waiting for onnx and tensorrt conversion
ONNX export successful:
tgt_padding_mask = (((tgt_in == self.eos_id)*2).cumsum(-1) > 0) # mask tokens beyond the first EOS token.
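If that is the line being changed, one possible workaround (my assumption, not necessarily what @mcmingchang did or what the repo later adopted) is to cast the boolean comparison to an integer dtype before the cumulative sum, since the ONNX CumSum op does not support bool inputs:

# Cast to int before cumsum so the exported CumSum node gets an integer
# input; the > 0 comparison keeps the mask semantics unchanged.
tgt_padding_mask = (((tgt_in == self.eos_id).int() * 2).cumsum(-1) > 0)  # mask tokens beyond the first EOS token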
@mcmingchang can you elaborate? I can convert the model to ONNX, but I can't use it with onnxruntime. When building the ONNX model I get the following messages:
C:\Users\1000\.conda\envs\parseg\lib\site-packages\torch\__init__.py:833: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! assert condition, message
C:\Users\1000\.conda\envs\parseg\lib\site-packages\timm\models\vision_transformer.py:201: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of PyTorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
D:\parseg\parseq\strhub\models\parseq\system.py:129: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! if testing and (tgt_in == self.eos_id).any(dim=-1).all():
Did anyone already succeed in building and running the ONNX model? Can someone share it? Thanks.
I am having similar issues with the ONNX, any leads on it?
@ashishpapanai maybe we have to wait for an expert to solve it.
https://github.com/baudm/parseq/blob/8fa51009088da67a23b44c9c203fde52ffc549e5/strhub/models/parseq/system.py#L147
This is the offending code fragment. You can comment this out (disabling the iterative refinement branch of the code) before exporting the ONNX model. I tried it and onnx.checker.check_model(parseq, full_check=True) succeeded.
Thank you for answering @baudm I'll try it and I'll report back when everything works fine or if there are other issues.
Thank you @baudm. I tried the code below with parseq.refine_iters = 0, and there are no onnx::CumSum_3090-related errors now.
fp32_onnx_path = "parseq_tiny_fp32.onnx"
parseq.refine_iters = 0
parseq.to_onnx(fp32_onnx_path, img, opset_version=14)

int8_onnx_path = "parseq_tiny_uint8.onnx"
from onnxruntime.quantization import QuantType, quantize_dynamic
quantize_dynamic(
    model_input=fp32_onnx_path,
    model_output=int8_onnx_path,
    weight_type=QuantType.QUInt8,
)
@allenwu5 oh yeah, this is even better. Setting refine_iters to 0 will make the iterative refinement branch unreachable, achieving the same effect. Will close this now and update the documentation.
UPDATE: As of commit ed3d84720c3f2a7b6c463a40492e0dad93920294, refine_iters=0 is no longer required when exporting to ONNX.
~~In summary, set refine_iters=0 when exporting to ONNX:~~
import torch
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
dummy_input = torch.rand(1, 3, *parseq.hparams.img_size) # (1, 3, 32, 128) by default
# To ONNX
parseq.to_onnx('parseq.onnx', dummy_input, opset_version=14) # opset v14 or newer is required
@baudm Hi, I tried the above solution and could convert the model to ONNX successfully. But I had a problem: the output size of the ONNX model changed. For example, the max_label_length of the base model is 25, but the output size of the ONNX model was only 7.
The output size changes with different conversions. How can I fix this issue? Thanks in advance.
I am facing a similar issue with the output shape. cc: @baudm
max_label_length is exactly that. Autoregressive decoding will terminate once the [E] (EOS) token is generated, which means the output sequence length will be less than the maximum supported label length. If you want to have a constant sequence length, use NAR decoding (decode_ar=False).
@baudm thank you! It works.
I can also convert the ONNX model to TensorRT format and achieve the same result.
For anyone who wants to convert to TensorRT format: simplify the ONNX model using onnx-simplifier first, then convert it to a TRT model using the trtexec tool.
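A short sketch of the simplification step using onnx-simplifier's Python API (file names here are placeholders; the trtexec invocation is a separate command-line step that depends on your TensorRT installation):

import onnx
from onnxsim import simplify

# Simplify the exported graph before building the TensorRT engine.
model = onnx.load('parseq.onnx')
model_sim, ok = simplify(model)
assert ok, 'onnx-simplifier could not validate the simplified model'
onnx.save(model_sim, 'parseq-sim.onnx')

# Then, outside Python, something like:
#   trtexec --onnx=parseq-sim.onnx --saveEngine=parseq.trt [--fp16]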
Benchmark of inference time for the torch, onnx-runtime, and TensorRT models (3x32x128, bs=1, averaged over 100 samples):
| torch | onnx-runtime | tensorrt-fp32 | tensorrt-fp16 |
|---|---|---|---|
| 0.017518 (4.1839x) | 0.015875 (3.7915x) | 0.004187 (1x) | 0.002519 |
The trt-fp32 model is about 4x faster than the torch model. The TRT model was served with triton-inference-server.
@baudm Thanks for the advice, the export to onnx worked now.
@huyhoang17 I'm also running the model on a Triton server and I'm able to make the inference request, which returns a result that I convert back with the Triton client's as_numpy function; this gives me an array of shape [1, 7, 95]. Do you have any advice on how to extract the label and confidence scores from this array?
@dietermaes FYI: https://github.com/baudm/parseq#pretrained-models-via-torch-hub
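For completeness, a sketch of applying the decode step from the linked README to such an output array (output_array here is a placeholder for the [1, 7, 95] numpy array mentioned above; parseq is the PyTorch model loaded via torch.hub, whose tokenizer is reused):

import torch

# output_array has shape [1, 7, 95]: batch, decoding steps, charset size.
# Reuse the PyTorch model's tokenizer to turn it into labels + confidences,
# mirroring the decode step in the linked README.
logits = torch.from_numpy(output_array)
probs = logits.softmax(-1)
labels, confidences = parseq.tokenizer.decode(probs)
print(labels[0], confidences[0])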
@huyhoang17 How did you make the output dimensions equal to [1, max_label_length, 95]? @baudm I tried setting decode_ar=False, but the output dimension is now [1, 6, 95]; it would be helpful if I could make the output shape [1, 25, 95] and then print the recognised characters in postprocessing.
@ashishpapanai here is the example code; you should use both params: decode_ar=False and refine_iters=0.
Lib version
torch==1.12.1
import onnx

from strhub.models.utils import load_from_checkpoint

# To ONNX
device = "cuda"
ckpt_path = "..."
onnx_path = "..."
img = ...

parseq = load_from_checkpoint(ckpt_path)
parseq.refine_iters = 0
parseq.decode_ar = False
parseq = parseq.to(device).eval()

parseq.to_onnx(onnx_path, img, do_constant_folding=True, opset_version=14)  # opset v14 or newer is required

# Check
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model, full_check=True)  # passes
@huyhoang17 I would love to see your code for ONNX inference; I am very interested in and impressed by your speed testing!
I am getting an OpenVINO IR model compile time of 81 minutes with AR decoding enabled, which is way too long. Does the community know of anything I can do to optimize the model?
@huyhoang17 I can't manage to get the same values out of ONNX inference, and I'm wondering if it is because of my older pytorch version. Are you getting identical inference results between your Lightning model and the ONNX version?
Is there any way to keep AR decoding and iterative refinement in the ONNX export?
@baudm @huyhoang17 @ashishpapanai I get an error when running inference with the ONNX model. Have you ever encountered it?
RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'input.1'.
Here is my code:
import cv2
import onnxruntime as ort
from PIL import Image
from torchvision import transforms

target_transform = transforms.Compose([
    transforms.Resize((32, 128), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

ort_sess = ort.InferenceSession('parseq.onnx')
input_tensor = ort_sess.get_inputs()[0]

image = cv2.imread("11.png")
image = Image.fromarray(image).convert("RGB")
image = target_transform(image).unsqueeze(0).cuda()

print(input_tensor.name)
outputs = ort_sess.run(None, {input_tensor.name: image})
print(outputs)
@phamkhactu The common approach seems to be converting the input to numpy before sending it; that works for me. I would be very interested to hear whether your inference results are the same between ONNX and the original Lightning model!
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
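And a usage sketch with the session from the earlier snippet (assuming the same ort_sess, input_tensor, and preprocessed image tensor):

# onnxruntime does not accept torch tensors directly, so convert first.
outputs = ort_sess.run(None, {input_tensor.name: to_numpy(image)})
print(outputs[0].shape)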