deep-text-recognition-benchmark
deep-text-recognition-benchmark copied to clipboard
ONNX
Hi, thank you for your great work. Would you please add code for converting a .pth checkpoint to ONNX?
There is a good guide for converting PyTorch models to ONNX: https://pytorch.org/docs/stable/onnx.html
Thank you. I have successfully converted the .pth checkpoint to ONNX, but I get the error below when running this inference code. Would you please help me fix it? Error
Traceback (most recent call last):
  File "inference_onnx.py", line 45, in <module>
    pred_str = converter.decode(pred_index, length_for_pred)
  File "E:\new2\text-recognition-wii-main\preprocess\converter.py", line 63, in decode
    text = ''.join([self.character[i] for i in text_index[index, :]])
  File "E:\new2\text-recognition-wii-main\preprocess\converter.py", line 63, in <listcomp>
    text = ''.join([self.character[i] for i in text_index[index, :]])
TypeError: list indices must be integers or slices, not numpy.float32
Inference code
#%% import onnxruntime as ort from PIL import Image from torchvision import transforms import numpy as np import matplotlib.pyplot as plt import base64 import io import torch from preprocess.converter import NormalizePAD, TokenLabelConverter from models.model import Model #%%
def preprocess_img(path):
    """Load the image at `path` and turn it into a single-item model input.

    The image is opened as 8-bit grayscale ('L'), resized to 224x224 with
    bicubic interpolation, passed through ``NormalizePAD((1, 224, 224))`` and
    finally given a leading dimension of size 1 via ``unsqueeze(0)``.
    """
    # NOTE(review): NormalizePAD is defined in preprocess.converter; it is
    # presumably returning a torch tensor (the result is unsqueezed) - confirm.
    normalize_pad = NormalizePAD((1, 224, 224))
    grayscale = Image.open(path).convert('L')
    resized = grayscale.resize((224, 224), Image.BICUBIC)
    return normalize_pad(resized).unsqueeze(0)
def to_numpy(tensor):
    """Return `tensor` as a NumPy array living on the CPU.

    A tensor that tracks gradients is detached from the autograd graph
    before the conversion; any other tensor is converted directly.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
def predict_onnx(sess, path):
    """Run the ONNX session on the image at `path` and return its output.

    Args:
        sess: An ``onnxruntime.InferenceSession`` holding the exported model.
        path: Path of the image file to recognise.

    Returns:
        The session output with all size-1 dimensions removed by
        ``np.squeeze``.
    """
    # The exported graph's first input receives the image tensor.
    input_name = sess.get_inputs()[0].name
    print(input_name)
    img = preprocess_img(path)
    # Passing None as the output list asks the session for every output.
    preds = sess.run(None, {input_name: to_numpy(img)})
    preds = np.squeeze(preds)
    return preds


#%%
if __name__ == '__main__':
    sess = ort.InferenceSession('last_model.onnx')
    converter = TokenLabelConverter()
    path = 'examples/1.jpg'
    preds = predict_onnx(sess, path)
    print(preds)
    #%%
    print(preds.shape)
    # argmax, not max: converter.decode() uses these values to index its
    # character list, so they must be integer class indices. ndarray.max()
    # returns the float32 scores themselves, which raised
    # "TypeError: list indices must be integers or slices, not numpy.float32".
    pred_index = preds.argmax(1)
    pred_index = pred_index.reshape(1, 25)
    print(pred_index)
    length_for_pred = np.array([25 - 1])
    pred_str = converter.decode(pred_index, length_for_pred)
    print(pred_str)
#%%