
Conversion to ONNX

Open · Darshcg opened this issue 4 years ago · 4 comments

Hi,

I am trying to convert the PyTorch model to ONNX, but I am facing this error: RuntimeError: Exporting the operator grid_sampler to ONNX opset version 12 is not supported.
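
For reference, the error can be reproduced with any module whose forward pass calls F.grid_sample; a minimal sketch below (the toy module, output file name, and tensor shapes are illustrative only, not this repo's code):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Toy module for illustration: any forward pass that calls F.grid_sample
    # hits the same unsupported-operator error when exported with opset 12.
    class GridSampleModule(nn.Module):
        def forward(self, img, grid):
            return F.grid_sample(img, grid, padding_mode='border', align_corners=True)

    img = torch.randn(1, 1, 32, 100)          # (N, C, H, W); shape chosen arbitrarily
    grid = torch.rand(1, 32, 100, 2) * 2 - 1  # sampling grid with values in [-1, 1]

    torch.onnx.export(GridSampleModule(), (img, grid), 'grid_sample.onnx', opset_version=12)
    # raises: RuntimeError: Exporting the operator grid_sampler to ONNX opset version 12 is not supported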

May I know how to resolve this issue?

Thanks and Regards, Darshan C G

Darshcg · Jun 05 '21 12:06

Hi,

This operation has not been supported by the PyTorch ONNX export for a long time, but you can mimic it with a set of other operations, as done in the mmcv PR linked in the code below. Here is an example of how I did it:

            if not torch.onnx.is_in_onnx_export():
                batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True)
            else:
                # workaround for export to ONNX
                # see here for details: https://github.com/open-mmlab/mmcv/pull/953/
                n, c, h, w = batch_I.shape
                gn, gh, gw, _ = build_P_prime_reshape.shape
                assert n == gn

                x = build_P_prime_reshape[:, :, :, 0]
                y = build_P_prime_reshape[:, :, :, 1]

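                # map normalized grid coordinates in [-1, 1] to pixel coordinates (align_corners=True)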
                x = ((x + 1) / 2) * (w - 1)
                y = ((y + 1) / 2) * (h - 1)

                x = x.view(n, -1)
                y = y.view(n, -1)

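                # integer coordinates of the four neighbouring pixels (x0/y0 is the top-left corner)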
                x0 = torch.floor(x).long()
                y0 = torch.floor(y).long()
                x1 = x0 + 1
                y1 = y0 + 1

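                # bilinear interpolation weights for the four neighbouring pixels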
                wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
                wb = ((x1 - x) * (y - y0)).unsqueeze(1)
                wc = ((x - x0) * (y1 - y)).unsqueeze(1)
                wd = ((x - x0) * (y - y0)).unsqueeze(1)

                # replicate-pad by one pixel so out-of-range points mimic grid_sample's padding_mode='border'
                im_padded = F.pad(batch_I, pad=[1, 1, 1, 1], mode='replicate')
                padded_h = h + 2
                padded_w = w + 2
                # shift point coordinates to account for the one-pixel padding
                x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

                # Clip coordinates to padded image size
                x0 = torch.where(x0 < 0, torch.tensor(0), x0)
                x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
                x1 = torch.where(x1 < 0, torch.tensor(0), x1)
                x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
                y0 = torch.where(y0 < 0, torch.tensor(0), y0)
                y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
                y1 = torch.where(y1 < 0, torch.tensor(0), y1)
                y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)

                im_padded = im_padded.view(n, c, -1)

                x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
                x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
                x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
                x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

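                # gather the four neighbouring pixel values from the flattened image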
                Ia = torch.gather(im_padded, 2, x0_y0)
                Ib = torch.gather(im_padded, 2, x0_y1)
                Ic = torch.gather(im_padded, 2, x1_y0)
                Id = torch.gather(im_padded, 2, x1_y1)

                batch_I_r = (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
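
As a quick way to convince yourself that the else-branch reproduces the original call, the same steps can be repackaged as a standalone function and compared against F.grid_sample. The sketch below is not part of the original workaround (tensor shapes are arbitrary, and it uses clamp in place of the torch.where clipping purely for brevity, since it only runs eagerly):

    import torch
    import torch.nn.functional as F

    def grid_sample_workaround(batch_I, grid):
        """Gather-based bilinear sampling with border padding, same steps as above."""
        n, c, h, w = batch_I.shape
        gn, gh, gw, _ = grid.shape
        assert n == gn
        # normalized [-1, 1] grid coordinates -> pixel coordinates (align_corners=True)
        x = ((grid[..., 0] + 1) / 2) * (w - 1)
        y = ((grid[..., 1] + 1) / 2) * (h - 1)
        x, y = x.view(n, -1), y.view(n, -1)
        x0, y0 = torch.floor(x).long(), torch.floor(y).long()
        x1, y1 = x0 + 1, y0 + 1
        # bilinear weights of the four neighbouring pixels
        wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
        wb = ((x1 - x) * (y - y0)).unsqueeze(1)
        wc = ((x - x0) * (y1 - y)).unsqueeze(1)
        wd = ((x - x0) * (y - y0)).unsqueeze(1)
        # replicate-pad by one pixel to mimic padding_mode='border', then flatten
        im = F.pad(batch_I, [1, 1, 1, 1], mode='replicate').view(n, c, -1)
        ph, pw = h + 2, w + 2
        x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
        x0, x1 = x0.clamp(0, pw - 1), x1.clamp(0, pw - 1)
        y0, y1 = y0.clamp(0, ph - 1), y1.clamp(0, ph - 1)

        def gather(xi, yi):
            return torch.gather(im, 2, (xi + yi * pw).unsqueeze(1).expand(-1, c, -1))

        out = gather(x0, y0) * wa + gather(x0, y1) * wb + gather(x1, y0) * wc + gather(x1, y1) * wd
        return out.reshape(n, c, gh, gw)

    img = torch.randn(2, 3, 32, 100)
    grid = torch.rand(2, 16, 50, 2) * 2 - 1
    ref = F.grid_sample(img, grid, padding_mode='border', align_corners=True)
    print((grid_sample_workaround(img, grid) - ref).abs().max())  # expect a value around 1e-6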

morkovka1337 · Jul 26 '21 07:07

Hello @morkovka1337, where did you insert this?

rafaelagrc · Jun 21 '22 10:06


Hi, as far as I remember, it was in the forward function of the model
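
For reference, in this repo that forward function is TPS_SpatialTransformerNetwork.forward in modules/transformation.py; a rough sketch of where the branch goes, assuming the file's standard layout (abridged, shape comments approximate):

    # modules/transformation.py (abridged sketch; only the sampling step changes)
    class TPS_SpatialTransformerNetwork(nn.Module):
        def forward(self, batch_I):
            batch_C_prime = self.LocalizationNetwork(batch_I)                # predicted fiducial points
            build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)  # sampling grid, (N, I_r_H * I_r_W, 2)
            build_P_prime_reshape = build_P_prime.reshape(
                [build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])

            if not torch.onnx.is_in_onnx_export():
                batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape,
                                          padding_mode='border', align_corners=True)
            else:
                # the gather-based workaround from the comment above goes here,
                # ending with the line that assigns batch_I_r
                ...

            return batch_I_r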

morkovka1337 · Jun 29 '22 08:06


Hi, is there a similar replacement write-up for the 5-D case?

dingjingzhen · Aug 03 '22 06:08