
Unable to get eos_id properly after converting to torchscript

Open EMHussain opened this issue 2 years ago • 6 comments

I converted the parseq.pt model to TorchScript (parseq.torchscript). After basic testing I found a problem with some images: in C++, the parseq.torchscript model does not return the correct eos_idx.

In Python, with the parseq.pt model:

# Load image and prepare for input
image = Image.open(fname).convert('RGB')
image = img_transform(image).unsqueeze(0).to(args.device)

p = model(image).softmax(-1)
probs, ids = p[0].max(-1)  # per-position confidences and token ids for the single image (printed below)
pred, p = model.tokenizer.decode(p)
print(ids)
print(probs)
print(f'{fname}: {pred[0]}')

Output

tensor([ 6, 76,  1,  1,  9,  3,  4,  6,  7, 10,  2,  1,  1,  4,  0,  4,  3,  2,
         4,  4,  7, 10,  1,  4,  4,  0])
tensor([0.9991, 0.8549, 0.9999, 0.9998, 0.9997, 0.9998, 0.9993, 0.9999, 0.9997,
        0.9999, 0.9996, 0.9997, 0.9999, 0.9988, 0.9986, 0.3257, 0.2137, 0.2015,
        0.3253, 0.4497, 0.2691, 0.7619, 0.2822, 0.4551, 0.7156, 0.8741])
cropped.jpeg: 5.008235691003

I converted parseq.pt to parseq.torchscript as follows:

dummy_input = torch.rand(1, 3, 32, 128)  # (1, 3, 32, 128) by default
traced_script_module = torch.jit.trace(model, dummy_input)
traced_script_module.save("parseq.torchscript")

In C++

 ......
 ......
 tensor_image = tensor_image.toType(c10::kFloat).div(255);
 tensor_image = tensor_image.permute({2, 0, 1});  // HWC -> CHW
 tensor_image.unsqueeze_(0);
 std::cout << "input shape : " << tensor_image.sizes() << std::endl;
 std::vector<torch::jit::IValue> inputs;
 inputs.push_back(tensor_image);
 at::Tensor output = module.forward(inputs).toTensor();
 
 output.softmax(-1);
 std::tuple<at::Tensor, at::Tensor> probs_ids = output.max(-1);
 std::cout << "ids : " << std::get<1>(probs_ids) << std::endl;
 std::cout << "probs : " << std::get<0>(probs_ids) << std::endl;
 std::string word;
 at::Tensor ids = std::get<1>(probs_ids);
 for (int c = 0; c < ids.sizes()[1]; c++)
 {
     int id = ids[0][c].item<int>();
     if (id == 0)
     {
         break;
     }
     word += char_set[id - 1];
 }
 std::cout<< word << std::endl;

Output

ids : Columns 1 to 20  6  76   1   1   9   3   4   6   7  10   1   1   1   4   4   4   4   4   4   4

Columns 21 to 26  4  10   1   4   4   0
[ CPULongType{1,26} ]
probs : Columns 1 to 8 14.4657  11.2162  17.1043  17.0559  15.4707  15.9846  15.6380  15.6537

Columns 9 to 16 15.2115  13.4674  11.0046  13.6310  12.1103  13.6405   8.6431   6.6635

Columns 17 to 24  6.6002   5.9998   8.9770  10.0183   9.4945  11.0694   8.3291   8.9739

Columns 25 to 26  9.1511   8.8696
[ CPUFloatType{1,26} ]
5-00823569000333333339033

EMHussain avatar Feb 10 '23 14:02 EMHussain

I would reckon that the difference in outputs is due to differences in the internal data representations. Can't really say, since I'm not familiar with the implementation of Torch in both Python and C++. You'd also notice the same difference in outputs if you run the exact same model weights on different hardware and/or software stacks.

baudm avatar Feb 11 '23 07:02 baudm

I tried the TorchScript model in Python and got the same output as in C++:

from PIL import Image
from torch import jit
from strhub.data.module import SceneTextDataModule

model = jit.load("parseq.torchscript")
model.eval()
img_transform = SceneTextDataModule.get_transform((32, 128))
image = Image.open('cropped.jpeg').convert('RGB')
image = img_transform(image).unsqueeze(0).to('cpu')
p = model(image)
p.softmax(-1)
pred, probs = decode(p)  # decode: presumably a local helper equivalent to the tokenizer's decode (not shown)
print(f'{pred[0]}')

Output

tensor([ 6, 76,  1,  1,  9,  3,  4,  6,  7, 10,  1,  1,  1,  4,  4,  4,  3,  3,
         4,  4,  4, 10,  2,  4,  4,  0])
tensor([14.4258, 11.3851, 17.1263, 17.0329, 15.8301, 15.8701, 15.8736, 16.0863,
        15.7481, 14.2250, 10.7678, 13.9986, 12.8545, 14.0670,  8.8345,  6.2844,
         6.2651,  5.6718,  8.0785,  9.5057,  9.3010, 11.1427,  8.4386,  9.0718,
         9.1752,  8.5811], grad_fn=<MaxBackward0>)
5-00823569000333223339133

EMHussain avatar Feb 14 '23 09:02 EMHussain

I was making a mistake here: p.softmax(-1) should be p = p.softmax(-1), and in C++ it should be output = output.softmax(-1);
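
For reference, the corrected Python lines would look like this; the C++ fix is the same reassignment:

# softmax is not in-place, so its result must be reassigned before taking the max
p = model(image)
p = p.softmax(-1)
probs, ids = p.max(-1)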

EMHussain avatar Feb 14 '23 12:02 EMHussain

After trying different versions of PyTorch in C++, I got the correct output with torch 1.13.1, but there is still an issue with inference time, which is almost 7 times higher in C++ than in Python. The issue is mentioned here.

EMHussain avatar Feb 16 '23 11:02 EMHussain

I found a torch version of parseq that can also be converted to ONNX and TensorRT: https://github.com/bharatsubedi/PARseq_torch

WongVi avatar Mar 06 '23 07:03 WongVi

I found a torch version of parseq that can also be converted to ONNX and TensorRT: https://github.com/bharatsubedi/PARseq_torch

The reference implementation is PyTorch. PyTorch-Lightning is only used for training and has nothing to do with the model during inference. The model code here is designed to be easily importable and usable in other projects—there's no need to remove the PL references because the model instance would work just like any other vanilla nn.Module.
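
For example (a minimal sketch, not taken from the repo's docs), the hub entry point already returns a plain nn.Module that can be called directly:

import torch

# The hub entry point returns an ordinary nn.Module; no Lightning machinery
# is needed at inference time.
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
assert isinstance(parseq, torch.nn.Module)
with torch.no_grad():
    logits = parseq(torch.rand(1, 3, 32, 128))  # dummy batch with the (1, 3, 32, 128) shape used above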

baudm avatar Apr 28 '23 19:04 baudm

I would reckon that the difference in outputs is due to differences in the internal data representations. Can't really say, since I'm not familiar with the implementation of Torch in both Python and C++. You'd also notice the same difference in outputs if you run the exact same model weights on different hardware and/or software stacks.

Maybe a little late, but the torch version alone is not enough to solve this. I think the main problem comes from your random dummy_input when converting:

dummy_input = torch.rand(1, 3, 32, 128)  # (1, 3, 32, 128) by default
traced_script_module = torch.jit.trace(model, dummy_input)
traced_script_module.save("parseq.torchscript")
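
As an illustration of the general pitfall (a minimal sketch with a made-up function f, not PARSeq code): torch.jit.trace records only the code path exercised by the example input, so a data-dependent branch is frozen into the trace.

import torch

def f(x):
    # Data-dependent branch: tracing keeps only the path taken for the example input
    if x.sum() > 0:
        return x * 2
    return x - 1

traced = torch.jit.trace(f, torch.ones(3))  # sum > 0, so the "* 2" branch is traced
print(traced(-torch.ones(3)))               # still multiplies by 2: tensor([-2., -2., -2.])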

The forward function of PARSeq uses conditions and loops, so dummy_input plays an important role in generating the final code of the TorchScript model. A random dummy_input means the conditions and loops in the decoding process of PARSeq will be traced along an arbitrary path. Here are my solutions:

Solution 1: Use a special dummy_input:

from PIL import Image
import torch
from strhub.data.module import SceneTextDataModule
img_transform = SceneTextDataModule.get_transform((32, 128))
img_path = "/path/to/special/image"
image = Image.open(img_path).convert('RGB')
dummy_input = img_transform(image).unsqueeze(0)
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
torch_jit = parseq.to_torchscript(file_path='./torchscript_model.ptl', method='trace', example_inputs=dummy_input)

Assuming the pretrained parseq has a maximum prediction length of 26 (including eos), we need to find an image that contains 25 characters and that the pretrained parseq predicts correctly.
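
One way to check whether the traced path generalizes is to compare the traced module against the eager model on a different image (a sketch only; test_input below is a placeholder for another preprocessed crop):

# Sanity check: the traced and eager models should agree on images other than
# the one used for tracing; test_input is a placeholder for another preprocessed crop.
with torch.no_grad():
    eager_probs = parseq(test_input).softmax(-1)
    traced_probs = torch_jit(test_input).softmax(-1)
print(torch.allclose(eager_probs, traced_probs, atol=1e-4))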

Solution 2: Remove all conditions and loops in the forward function of PARSeq

Not recommended, but you can try removing them and using data-independent functions instead, then use a new script to load the model and do the conversion with the trace method.

Unfortunately, both solutions make the model decode 26 times if max_label_length = 26, instead of stopping early after the eos token, so the inference time will be much longer in most cases.
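
Even with the fixed-length decode, the eos cut can still be applied afterwards on the consumer side; a minimal sketch, assuming eos is token id 0 as in the outputs above (traced_model and image are placeholders):

import torch

with torch.no_grad():
    probs = traced_model(image).softmax(-1)  # (1, max_label_length, num_classes); names are placeholders
conf, ids = probs.max(-1)
ids = ids[0].tolist()
eos_pos = ids.index(0) if 0 in ids else len(ids)  # first eos (id 0), or keep everything
kept = ids[:eos_pos]  # token ids before eos; map to characters with id - 1, as in the C++ loop above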

gioivuathoi avatar Mar 26 '24 22:03 gioivuathoi

@gioivuathoi is correct. The forward method of the PARSeq model is dynamic in the sense that the actual code path is dependent on the input, especially for the autoregressive decoding mode. When converting to TorchScript, do not use the trace mode.

You can use the built-in method available in Lightning:

# create the model
parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
script = parseq.to_torchscript()

# save for use in production environment
torch.jit.save(script, "parseq.pt")

See: https://lightning.ai/docs/pytorch/stable/deploy/production_advanced_2.html

In addition to Solution 2 presented above, you can opt to use the non-autoregressive decoding mode (decode_ar=False). This code path is static and does not depend on the input.
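
A minimal sketch of that option, assuming the loaded model exposes the decode_ar and refine_iters attributes as in the reference implementation:

import torch

parseq = torch.hub.load('baudm/parseq', 'parseq', pretrained=True).eval()
parseq.decode_ar = False   # single non-autoregressive decoding pass; code path no longer depends on the input
parseq.refine_iters = 0    # optionally skip iterative refinement as well (assumption: attribute name as in the repo)
dummy_input = torch.rand(1, 3, 32, 128)
traced = torch.jit.trace(parseq, dummy_input)
traced.save("parseq_nar.torchscript")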

baudm avatar Mar 27 '24 03:03 baudm