parseq
Unable to get eos_id properly after converting to torchscript
I have converted the parseq.pt model to a TorchScript model, parseq.torchscript. After basic testing I found an issue on some images: the model does not return the correct eos_idx in C++ when using the parseq.torchscript model.
In Python, with the parseq.pt model:
# Load image and prepare for input
image = Image.open(fname).convert('RGB')
image = img_transform(image).unsqueeze(0).to(args.device)
p = model(image).softmax(-1)
probs, ids = p.max(-1)  # per-position confidence and predicted token id (printed below)
pred, p = model.tokenizer.decode(p)
print(ids)
print(probs)
print(f'{fname}: {pred[0]}')
Output
tensor([ 6, 76, 1, 1, 9, 3, 4, 6, 7, 10, 2, 1, 1, 4, 0, 4, 3, 2,
4, 4, 7, 10, 1, 4, 4, 0])
tensor([0.9991, 0.8549, 0.9999, 0.9998, 0.9997, 0.9998, 0.9993, 0.9999, 0.9997,
0.9999, 0.9996, 0.9997, 0.9999, 0.9988, 0.9986, 0.3257, 0.2137, 0.2015,
0.3253, 0.4497, 0.2691, 0.7619, 0.2822, 0.4551, 0.7156, 0.8741])
cropped.jpeg: 5.008235691003
Converted parseq.pt to parseq.torchscript as below:
dummy_input = torch.rand(1, 3, 32, 128) # (1, 3, 32, 128) by default
traced_script_module = torch.jit.trace(model, dummy_input)
traced_script_module.save("parseq.torchscript")
In C++
......
......
tensor_image = tensor_image.toType(c10::kFloat).div(255);
tensor_image = tensor_image.permute({2, 0, 1}); // HWC -> CHW
tensor_image.unsqueeze_(0);
std::cout << "input shape : " << tensor_image.sizes() << std::endl;
std::vector<torch::jit::IValue> inputs;
inputs.push_back(tensor_image);
at::Tensor output = module.forward(inputs).toTensor();
output.softmax(-1);
std::tuple<at::Tensor, at::Tensor> probs_ids = output.max(-1);
std::cout << "ids : " << std::get<1>(probs_ids) << std::endl;
std::cout << "probs : " << std::get<0>(probs_ids) << std::endl;
std::string word;
at::Tensor ids = std::get<1>(probs_ids);
for (int c = 0; c < ids.sizes()[1]; c++)
{
int id = ids[0][c].item<int>();
if (id == 0)
{
break;
}
word += char_set[id - 1];
}
std::cout<< word << std::endl;
Output
ids : Columns 1 to 20 6 76 1 1 9 3 4 6 7 10 1 1 1 4 4 4 4 4 4 4
Columns 21 to 26 4 10 1 4 4 0
[ CPULongType{1,26} ]
probs : Columns 1 to 8 14.4657 11.2162 17.1043 17.0559 15.4707 15.9846 15.6380 15.6537
Columns 9 to 16 15.2115 13.4674 11.0046 13.6310 12.1103 13.6405 8.6431 6.6635
Columns 17 to 24 6.6002 5.9998 8.9770 10.0183 9.4945 11.0694 8.3291 8.9739
Columns 25 to 26 9.1511 8.8696
[ CPUFloatType{1,26} ]
5-00823569000333333339033
I would reckon that the difference in outputs is due to differences in the internal data representations. Can't really say, since I'm not familiar with the implementation of Torch in both Python and C++. You'd also notice the same difference in outputs if you run the exact same model weights on different hardware and/or software stacks.
I have tried the torchscript model in Python and got the same output as in C++:
model = jit.load("parseq.torchscript")
model.eval()
img_transform = SceneTextDataModule.get_transform((32, 128))
image = Image.open('cropped.jpeg').convert('RGB')
image = img_transform(image).unsqueeze(0).to('cpu')
p = model(image)
p.softmax(-1)
pred, probs = decode(p)
print(f'{pred[0]}')
Output
tensor([ 6, 76, 1, 1, 9, 3, 4, 6, 7, 10, 1, 1, 1, 4, 4, 4, 3, 3,
4, 4, 4, 10, 2, 4, 4, 0])
tensor([14.4258, 11.3851, 17.1263, 17.0329, 15.8301, 15.8701, 15.8736, 16.0863,
15.7481, 14.2250, 10.7678, 13.9986, 12.8545, 14.0670, 8.8345, 6.2844,
6.2651, 5.6718, 8.0785, 9.5057, 9.3010, 11.1427, 8.4386, 9.0718,
9.1752, 8.5811], grad_fn=<MaxBackward0>)
5-00823569000333223339133
I was making a mistake here: p.softmax(-1) should be p = p.softmax(-1), and in C++ it should be output = output.softmax(-1); since softmax is not an in-place operation, the result must be assigned back.
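For reference, a minimal sketch of the corrected Python decoding with the torchscript model, reusing the same model, image, and decode helper from the snippet above:
p = model(image)
p = p.softmax(-1)        # softmax returns a new tensor, so the result must be assigned back
probs, ids = p.max(-1)   # per-position confidence and predicted token id
pred, probs = decode(p)
print(pred[0])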
After trying different versions of PyTorch in C++, I got the correct output with torch version 1.13.1, but there is still an issue with inference time, which is almost 7 times higher in C++ than in Python. The issue is mentioned here.
I found a torch version of parseq which can also be converted to ONNX and TensorRT: https://github.com/bharatsubedi/PARseq_torch
The reference implementation is PyTorch. PyTorch-Lightning is only used for training and has nothing to do with the model during inference. The model code here is designed to be easily importable and usable in other projects; there's no need to remove the PL references because the model instance would work just like any other vanilla nn.Module.
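For illustration, a minimal sketch of using the model as a plain nn.Module, assuming the pretrained weights are fetched via torch.hub as in the other snippets in this thread:
import torch
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
with torch.no_grad():
    logits = parseq(torch.rand(1, 3, 32, 128))  # dummy (1, 3, 32, 128) batch, as in the export example
print(logits.shape)  # (batch, sequence positions, number of classes)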
Maybe a little late, but the torch version is not enough for the solution here. I think the main problem comes from your random dummy_input when converting:
dummy_input = torch.rand(1, 3, 32, 128) # (1, 3, 32, 128) by default
traced_script_module = torch.jit.trace(model, dummy_input)
traced_script_module.save("parseq.torchscript")
The forward function of Parseq uses conditions and loops, so the dummy_input plays an important role in generating the final code of the traced TorchScript model. A random dummy_input means the conditions and loops in the decoding process of Parseq will be traced arbitrarily. Here are my solutions:
Solution 1: Use a special dummy_input:
from PIL import Image
import torch
from strhub.data.module import SceneTextDataModule
img_transform = SceneTextDataModule.get_transform((32, 128))
img_path = "/path/to/special/image"
image = Image.open(img_path).convert('RGB')
dummy_input = img_transform(image).unsqueeze(0)
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
torch_jit = parseq.to_torchscript(file_path='./torchscript_model.ptl', method='trace', example_inputs=dummy_input)
Assuming the pretrained parseq has a maximum prediction sequence length of 26 (including eos), we need to find an image which contains 25 characters and which the pretrained parseq predicts correctly.
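As a rough sketch of how to check a candidate image before tracing (the path is a placeholder, and the 25-character requirement follows from the assumption above):
from PIL import Image
import torch
from strhub.data.module import SceneTextDataModule

img_transform = SceneTextDataModule.get_transform((32, 128))
candidate = img_transform(Image.open('/path/to/special/image').convert('RGB')).unsqueeze(0)

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
with torch.no_grad():
    probs = parseq(candidate).softmax(-1)
pred, _ = parseq.tokenizer.decode(probs)
# A suitable image is read correctly and uses the full 25-character length,
# so tracing exercises every decoding iteration before the eos token.
print(pred[0], len(pred[0]))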
Solution 2: Remove all conditions and loops in the forward function of Parseq
Not recommended, but you can try removing them and using data-independent functions instead, then use a new script to load the model and do the conversion using the trace method.
Unfortunately, these two solutions make the model decode 26 times if max_label_length = 26 instead of stopping early after the eos token, so the inference time will be much longer in most cases.
@gioivuathoi is correct. The forward method of the PARSeq model is dynamic in the sense that the actual code path is dependent on the input, especially for the autoregressive decoding mode. When converting to TorchScript, do not use the trace mode.
You can use the built-in method available in Lightning:
# create the model
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
script = parseq.to_torchscript()
# save for use in production environment
torch.jit.save(script, "parseq.pt")
See: https://lightning.ai/docs/pytorch/stable/deploy/production_advanced_2.html
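As a quick sanity check (a sketch, assuming parseq is the eager model created above and image is a preprocessed (1, 3, 32, 128) tensor as in the earlier snippets), the scripted model can be loaded back and compared with the eager model:
import torch
scripted = torch.jit.load("parseq.pt")
scripted.eval()
with torch.no_grad():
    out_eager = parseq(image)
    out_scripted = scripted(image)
print(torch.allclose(out_eager, out_scripted, atol=1e-4))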
In addition to Solution 2 presented above, you can opt to use the non-autoregressive decoding mode (decode_ar=False). This code path is static and does not depend on the input.
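For example, a sketch assuming the torch.hub entry point forwards keyword arguments such as decode_ar to the model constructor:
import torch
# Non-autoregressive decoding gives a static code path, so the export does not depend on the example input
parseq_nar = torch.hub.load('baudm/parseq', 'parseq', pretrained=True, decode_ar=False).eval()
script = parseq_nar.to_torchscript()
torch.jit.save(script, "parseq_nar.pt")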