
PyTorch ONNX export

Open · albertz opened this issue 2 years ago · 98 comments

This issue is to track all aspects and open issues of the PyTorch (#1120) ONNX export.

  • [x] Working script for conversion (export_to_onnx.py)
  • [x] Fix issue with convolution
  • [x] Rename script to torch_export_to_onnx.py
  • [x] Check model_outputs.
  • [x] Use model_outputs dims when not specified in mark_as_output (especially in the case of PT).
  • [x] The input names for sizes should be better (currently they are like "classes:size1" but should be like "data:size1"). (Fixed via #1362.)
  • [x] The batch dim (data:size0), or any scalar dyn_size_ext, should not be an input, as it is redundant. (Fixed via #1362.)
  • [x] model_outputs: Ignore dim tag equality. Check static dims and dyn dims information.
  • [x] Out seq lens are actually relevant, and should not be filled with dummy values? Clarify. (https://github.com/rwth-i6/returnn/issues/1333#issuecomment-1625179457) (Fixed via #1362.)
  • [x] Working demo-rf + test case
  • [x] Test case: Also perform ONNX inference (Fixed via #1362.)
  • [ ] RF Conformer works
  • [ ] Some real-world pure PT model works
  • [ ] Running the resulting ONNX model in RASR
  • [x] Writing a sisyphus job to run ONNX export in i6_core. (Fixed via https://github.com/rwth-i6/i6_core/pull/429.)

Other things which are less important for now:

  • Avoiding TracerWarnings for some sanity checking code: I think it's not really possible in an easy way.

albertz, May 19 '23 10:05

The initial version (export_to_onnx.py) is already done and works for pure PT code, so I think the main work is done. We could already close the issue, but let's still collect further aspects here anyway.

albertz, May 19 '23 10:05

One aspect is that convolution with padding="same" is not supported by the ONNX export. (See also: https://github.com/rwth-i6/returnn/issues/1120#issuecomment-1541471146) However, we can simply work around that by doing the padding manually. The code for this already exists in the RF and just needs to be used. We can use torch.onnx.is_in_export_mode() to check whether we should use it.

albertz, May 19 '23 10:05

As a summary: the tool is mostly done, but some operators aren't supported yet, one of which is convolutional layer padding specified as a string ("same" or "valid"). This prevents the demos from being run through the tool, so fixing it is one of my top priorities.

I found that it's not padding="same" specifically which is not implemented; rather, any string padding (type(padding) is str) fails. I tried padding="valid" (which we had already discussed), and it failed with the same message, even though padding=0 worked fine.

I've been investigating this issue and I'm currently trying to fix it. There are some people online with the same issue, and some of them also provide implementations of custom padding, so it shouldn't be very hard to fix. I'll try to push a version with my changes before the weekend.

Icemole, May 19 '23 10:05

padding="valid" is just the same as padding=0, or not?

albertz, May 19 '23 10:05

Yes, it is, but it seems that internally any padding given as a string dispatches to an operator called aten::_convolution_mode, which isn't implemented in the ONNX conversion.

Anyway, padding="valid" is the easy case to fix, since then we only have to set padding = 0, as you said. padding="same" is a little more tricky in my view: it involves not only calculating the padding for each dimension (which is the easy part) but also ordering it properly in the result.

Also, padding is given as a tuple of at most 3 entries ([[depth,] height,] width), depending on whether the convolutional layer is 3D, 2D or 1D. So I have to look into the internals a bit to check whether unwanted redundant padding gets inserted: for instance, with a kernel size of 2x2 and a stride of 1, I think we'd only need to insert one additional row and one additional column, but maybe this tuple implies symmetric padding and inserts one row at the start and another at the end (and likewise for the column), so that the output size would not match the input size.
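On the ordering question, here is a minimal pure-Python sketch, assuming stride 1 and the convention PyTorch uses for padding="same" as far as I know: the total padding per dim is dilation * (kernel_size - 1), and any odd remainder goes to the end. F.pad expects a flat list starting from the last dimension, so the (left, right) pairs come out in reverse dim order. same_padding_for_f_pad is a hypothetical helper name, not existing RETURNN or PyTorch code.

```python
from typing import List, Sequence


def same_padding_for_f_pad(kernel_size: Sequence[int], dilation: Sequence[int]) -> List[int]:
    """Flat padding list for F.pad that emulates padding="same" (stride 1 only).

    kernel_size and dilation are in the conv layer's order ([[depth,] height,] width).
    F.pad expects [left_last, right_last, left_second_last, right_second_last, ...].
    """
    pads: List[int] = []
    # Walk the spatial dims from last to first, as F.pad expects.
    for k, d in zip(reversed(kernel_size), reversed(dilation)):
        total = d * (k - 1)  # total padding that keeps the length unchanged
        left = total // 2  # smaller half at the start
        right = total - left  # odd remainder at the end
        pads.extend([left, right])
    return pads


# 2x2 kernel: one extra column and one extra row, both only at the end.
print(same_padding_for_f_pad([2, 2], [1, 1]))  # [0, 1, 0, 1]
# Height 3, width 5: width is padded (2, 2), height (1, 1).
print(same_padding_for_f_pad([3, 5], [1, 1]))  # [2, 2, 1, 1]
```

So for the 2x2 case, no symmetric row/column duplication is needed: the single extra row and column both go at the end.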

Icemole, May 19 '23 11:05

> Anyway, padding="valid" is the easy case to fix, since then we only have to set padding = 0, as you said.

Yes, my question was just to confirm this. So we can always do this then, no matter whether we export to ONNX or not. Just:

```python
if padding == "valid":
    padding = 0
```

> padding="same" is a little more tricky in my view: it involves not only calculating the padding for each dimension (which is the easy part) but also ordering it properly in the result.

We already have the code for adding the padding manually (see the code). It basically does:

```python
if padding == "same" and manual_padding_required:  # e.g. during ONNX export
    ...  # add padding to input
    padding = "valid"
```

Nothing else needs to be done, or not?
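The two snippets above can be combined into one dispatch function. This is a hypothetical sketch (resolve_conv_padding is not an existing RETURNN function), assuming stride 1 and dilation 1. In the RF code, manual_padding_required would come from the ONNX export-mode check discussed above; here it is a plain flag so the logic runs without torch. It returns 0 rather than "valid" after the manual padding, since, per the discussion above, any string padding is what trips the ONNX export.

```python
from typing import List, Optional, Sequence, Tuple, Union

Padding = Union[str, int]


def resolve_conv_padding(
    padding: Padding,
    kernel_size: Sequence[int],
    manual_padding_required: bool,
) -> Tuple[Optional[List[int]], Padding]:
    """Return (F.pad amounts or None, padding value to pass to the conv op)."""
    if padding == "valid":
        # "valid" means no padding; 0 avoids the string code path entirely.
        return None, 0
    if padding == "same" and manual_padding_required:
        pads: List[int] = []
        for k in reversed(kernel_size):  # F.pad order: last dim first
            total = k - 1
            pads.extend([total // 2, total - total // 2])
        # The input gets padded explicitly, so the conv itself uses no padding.
        return pads, 0
    # Otherwise pass the value through unchanged (e.g. "same" outside export).
    return None, padding


print(resolve_conv_padding("same", [5], True))  # ([2, 2], 0)
print(resolve_conv_padding("valid", [5], True))  # (None, 0)
```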

albertz, May 19 '23 11:05

Btw, next step after the current demo works: Test our RF Conformer implementation.

albertz, May 19 '23 11:05

Actually, I now wonder: instead of manually adding the padding to the input, could we maybe just set padding = (filter_size - 1) // 2? Probably not in all cases, e.g. it only works with an odd filter_size and no striding. But it might still be worth keeping this case, as it might be faster than the manual padding.
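To see where this shortcut holds, here is a small check using the 1D convolution output-length formula from the torch.nn.Conv1d documentation, L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1, with stride 1 and dilation 1 (conv1d_out_len is just an illustrative name). Symmetric padding of (kernel_size - 1) // 2 keeps the length for odd kernel sizes but comes out one frame short for even ones, which need asymmetric padding.

```python
def conv1d_out_len(in_len: int, kernel_size: int, padding: int,
                   stride: int = 1, dilation: int = 1) -> int:
    """Output length of a 1D conv with symmetric zero padding on both sides."""
    return (in_len + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


for kernel_size in (3, 4, 5):
    padding = (kernel_size - 1) // 2  # the proposed shortcut
    print(kernel_size, padding, conv1d_out_len(10, kernel_size, padding))
# Output:
# 3 1 10  (length kept, like "same")
# 4 1 9   (one frame short: even kernels need asymmetric padding)
# 5 2 10  (length kept, like "same")
```

Under these assumptions, an explicit padding=2 behaves like "same" for kernel size 5.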

albertz, May 19 '23 11:05

I was able to insert torch.onnx.is_in_export_mode() into the existing code which does the manual padding. However, this only works for the RF demo: the PT demo can't be executed as it is, since it also uses padding="same", and we have no handler for pure PT code. Should we just leave these details to the user, or also map "same" to an explicit integer padding size?

Icemole, May 19 '23 13:05

> could we maybe just set padding = (filter_size - 1) // 2?

Internally, the forward() implementation of convolutional layers is as follows:

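```python
# Excerpt from torch.nn.Conv1d._conv_forward(), which forward() calls
if self.padding_mode != 'zeros':
    return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                    weight, bias, self.stride,
                    _single(0), self.dilation, self.groups)
# padding_mode == 'zeros': the padding is passed directly to the conv op
return F.conv1d(input, weight, bias, self.stride,
                self.padding, self.dilation, self.groups)
```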

where, for padding = (i, j), self._reversed_padding_repeated_twice is [j, j, i, i]: the padding tuple in reverse order (F.pad starts from the last dimension), with each entry repeated twice (once per side). This attribute is computed at initialization time. If we introduced a similar attribute in our convolutional layers, we could save quite a few operations, since we wouldn't have to compute the same paddings in every forward() call, right?

Or maybe we could cache it somehow, or add an additional layer attribute self.padding_sizes at runtime?
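For reference, the content of that attribute for an integer padding can be reproduced in pure Python. This sketch mirrors PyTorch's private helper torch.nn.modules.utils._reverse_repeat_tuple (private, so it may change between versions); reverse_repeat is my own name for it.

```python
from typing import Sequence, Tuple


def reverse_repeat(padding: Sequence[int], n: int = 2) -> Tuple[int, ...]:
    """Reverse the per-dim padding and repeat each entry n times.

    With n=2 this gives the (left, right) pairs F.pad wants, last dim first.
    """
    return tuple(p for p in reversed(padding) for _ in range(n))


print(reverse_repeat((1, 2)))  # (2, 2, 1, 1)
```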

Icemole, May 19 '23 13:05

> I was able to insert torch.onnx.is_in_export_mode() into the existing code which does the manual padding. However, this only works for the RF demo: the PT demo can't be executed as it is, since it also uses padding="same", and we have no handler for pure PT code. Should we just leave these details to the user, or also map "same" to an explicit integer padding size?

Change padding="same" to padding=2 in the Torch demo. That should fix this issue.

albertz, May 19 '23 14:05

> Internally, the forward() implementation of convolutional layers is as follows ...

This is only for the case padding_mode != 'zeros', which we don't support at all currently, right?

albertz, May 19 '23 14:05

Ah you're right, I confused padding with padding_mode. Still, aren't we calculating the same padding numbers for every forward()?

Icemole, May 19 '23 14:05

I'm not exactly sure what you mean. The padding amounts don't need to be calculated at all: F.pad just adds them to the input. And the input is different in every forward() call of course, so there is nothing we can cache.

albertz, May 19 '23 14:05

Oh I see. Yes, I neglected the effect of the dynamic dimensions on convolution; I was thinking more of image processing, where the input size is usually the same. Never mind.

Icemole, May 19 '23 14:05

I've pushed some commits that should fix some of the issues with the exporting tool. Hopefully I explained everything nicely.

With the current state of the tool (4785332), I can properly run both configs through the ONNX exporting tool. However, the graphs differ greatly. The number of operations in the RF graph is much bigger than the number of operations in the PT graph. At first sight, the PT graph looks much cleaner than the RF graph. I haven't studied the differences yet; I'll look at them on Monday. In the meantime, you can find both graphs below.

onnx_pt_graph.txt onnx_rf_graph.txt
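To compare the two graphs more systematically, one quick option is to count nodes by op type. A minimal sketch: op_type_histogram is a hypothetical helper, and with the onnx package, onnx.load(...).graph.node yields the graph's nodes, each with an op_type field. The file names in the comment are placeholders; the offline example uses namedtuples as stand-ins so it runs without onnx installed.

```python
from collections import Counter, namedtuple
from typing import Iterable


def op_type_histogram(nodes: Iterable) -> Counter:
    """Count graph nodes by op type; works for anything with an .op_type attribute."""
    return Counter(node.op_type for node in nodes)


# With the onnx package installed (file names are placeholders):
#   import onnx
#   pt_ops = op_type_histogram(onnx.load("pt_model.onnx").graph.node)
#   rf_ops = op_type_histogram(onnx.load("rf_model.onnx").graph.node)
#   print(rf_ops - pt_ops)  # ops that occur more often in the RF graph

# Offline example with stand-in nodes instead of real ONNX nodes:
FakeNode = namedtuple("FakeNode", "op_type")
demo = op_type_histogram([FakeNode("Conv"), FakeNode("Shape"), FakeNode("Shape")])
print(demo)  # Counter({'Shape': 2, 'Conv': 1})
```

Counter subtraction keeps only positive counts, so rf_ops - pt_ops directly lists the surplus ops.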

Icemole, May 19 '23 15:05

> The number of operations in the RF graph is much bigger than the number of operations in the PT graph.

That's very interesting. Thanks for this. We should study this in detail. I already pushed a few things which should reduce this a bit. Some comments from a quick look:

  • A lot of it seems to be related to calculating the output sizes. This is something PyTorch doesn't do, or at least not this simple demo. In practice, you might have something similar in pure PyTorch models; it's just that we don't do it in the simple demo. Let's not add this to the PT demo now, and when comparing complexity, let's ignore this part for now. But we should still check a bit whether it is reasonable.
  • There might be redundant computations. Maybe some output sizes are computed multiple times. Not sure.
  • Some constants like -9223372036854775807 ($-2^{63}+1$, min int64 + 1) look quite suspicious. I'm not sure if this is correct.
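Regarding the first point: the size computations are per-sequence arithmetic on the sequence lengths. A minimal sketch of what they amount to for a 1D conv, assuming dilation 1 (conv_out_seq_lens is a hypothetical name, not the RF implementation). With padding "same" and stride 1 the lengths are unchanged; with "valid" each length shrinks by kernel_size - 1 before striding. In the exported graph, each such step presumably becomes a few elementwise nodes on the length tensor.

```python
from typing import List


def conv_out_seq_lens(seq_lens: List[int], kernel_size: int, padding: str,
                      stride: int = 1) -> List[int]:
    """Per-sequence output lengths of a 1D conv over a padded batch (dilation 1)."""
    if padding == "same":
        if stride != 1:
            raise ValueError("this sketch only handles 'same' with stride 1")
        return list(seq_lens)  # identity: lengths are unchanged
    if padding == "valid":
        # Clip at 0 for sequences shorter than the kernel.
        return [max(0, (n - kernel_size) // stride + 1) for n in seq_lens]
    raise ValueError(f"unknown padding {padding!r}")


print(conv_out_seq_lens([10, 7, 3], kernel_size=5, padding="valid"))  # [6, 3, 0]
print(conv_out_seq_lens([10, 7, 3], kernel_size=5, padding="same"))  # [10, 7, 3]
```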

albertz, May 19 '23 20:05

I tried exporting the RF demo model again with the current changes, and the graph is much smaller now! It's basically almost equal to the PyTorch graph: onnx_rf_graph_new.txt 🥳

Icemole, May 22 '23 07:05

> I tried exporting the RF demo model again with the current changes, and the graph is much smaller now! It's basically almost equal to the PyTorch graph: onnx_rf_graph_new.txt 🥳

The actual computation is not just almost identical, but exactly identical, as far as I can see.

In addition, it defines the output sizes, which are just the same as the input sizes because of padding "same", so this part is just an identity function here.

When there are actual new calculations of dims, I think this will still be unnecessarily complex. We can probably improve this further, but I'm not sure how important this is.

albertz, May 22 '23 07:05

I'll reopen this to keep track of the remaining issues.

One remaining issue now is that we do not really check whether model_outputs matches what the user actually returns in the forward-step function. And in the case of the PT demo, where the user does not provide the dims in the mark_as_output call, the output does not exactly match the dims from model_outputs.

albertz, May 22 '23 07:05

Edit: you beat me to it 🙂

There seems to be one last thing to figure out before the tool is fully operational. I get an error while running the PyTorch demo:

RuntimeError: number of output names provided (3) exceeded number of outputs (1)

This is the full stack trace.

I think this is happening because somehow the RF model is outputting the output sizes as well, while the PT model isn't.

PT and RF don't behave exactly the same when calling rf.get_run_ctx().outputs.as_raw_tensor_dict(). PT instantiates the dimensions of rf.get_run_ctx().outputs.data to actual values, while RF keeps the batch dim and the time dim as dynamic dimensions. Not sure if this difference in behavior is intended.

Since in the PT demo we don't have access to batch/time dims, the output dimensions are instantiated to actual values. However, this isn't the case in the RF demo, where we have access to the shape of the original tensor as Dim classes. Since these dimensions are also marked as output, I think this is the essential difference.

I introduced a workaround in 47853321, which is why I was able to post the PT graph above, but it was reverted in bb40b4ca. Is there any other way to fix this?

Icemole, May 22 '23 07:05

I would maybe extend init_forward_step_run_ctx to pass a model_outputs template. Then mark_as_output would check against it and follow the logic I described above. And at the end, some function check_outputs_match or so, which checks whether all the outputs are given.
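A hypothetical sketch of that logic on plain dicts (not RETURNN's actual API: fill_and_check_outputs and the tuple-of-names dim representation are stand-ins). Dims missing in the mark_as_output call are taken from the model_outputs template, and any mismatch in output names or number of dims raises an error. Following the checklist item on ignoring dim tag equality, it compares only the number of dims, not dim identity.

```python
from typing import Dict, Optional, Tuple

Dims = Tuple[str, ...]


def fill_and_check_outputs(template: Dict[str, Dims],
                           marked: Dict[str, Optional[Dims]]) -> Dict[str, Dims]:
    """template: name -> dims from model_outputs.
    marked: name -> dims given to mark_as_output, or None if not specified."""
    if set(marked) != set(template):
        missing = sorted(set(template) - set(marked))
        extra = sorted(set(marked) - set(template))
        raise ValueError(f"outputs do not match model_outputs: missing={missing}, extra={extra}")
    result: Dict[str, Dims] = {}
    for name, dims in marked.items():
        expected = template[name]
        if dims is None:
            dims = expected  # not specified by the user: take over the template's dims
        elif len(dims) != len(expected):
            raise ValueError(f"{name}: got {len(dims)} dims, model_outputs declares {len(expected)}")
        result[name] = dims
    return result


template = {"output": ("batch", "time", "feature")}
print(fill_and_check_outputs(template, {"output": None}))
# {'output': ('batch', 'time', 'feature')}
```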

albertz, May 22 '23 07:05

Another thing: this script export_to_onnx.py is actually PT-specific, right? Then we should rename it to torch_export_to_onnx.py or so.

albertz, May 22 '23 07:05

> some function check_outputs_match or so, which checks whether all the outputs are given

And if they don't match, which is what we're seeing here, what do we do? We've detected the error, but should we fix it, or leave it to the user? If we fix it, do we give priority to the expected outputs or to the actual outputs? I would argue that the actual outputs are more important, but this is an "unexpected" situation in which an automatic fix might not be the best thing to do...

Icemole, May 22 '23 07:05

> This script export_to_onnx.py is actually PT-specific, right? Then we should rename it to torch_export_to_onnx.py or so.

Yes, it works for the PT backend, independently of whether the RF frontend or actual PT code is being used. I'm fine with the name change.

Icemole, May 22 '23 07:05

> some function check_outputs_match or so, which checks whether all the outputs are given

> And if they don't match, which is what we're seeing here, what do we do?

No, we are not seeing this here. If we do it as I described, there would not be an error either, as the outputs would actually match then. And if they do not match, we would just throw an error. What we are seeing now is an error from ONNX, which is not what I mean.

> If we fix it, do we give priority to the expected outputs or to the actual outputs?

It's just an error if the outputs do not match. We don't need to handle this further.

albertz, May 22 '23 07:05

There's one more item which I would probably add to the list of items to check and that we briefly discussed: running the resulting ONNX model in RASR.

Icemole, May 22 '23 09:05

3c22503 addresses "Check model_outputs" in the main checklist.

Icemole, May 24 '23 08:05

> 3c22503 addresses "Check model_outputs" in the main checklist.

There were some problems with that. See my updated commit.

In this commit, I also already take over the dims from model_outputs in case they are not specified, so this should be implemented now as well.

albertz, May 24 '23 11:05

I saw the updated commit. The RF demo works, but the PT demo still doesn't work with a now-familiar error:

Exception: Dim{B}: need placeholder, self.dimension or self.dyn_size for dim value

Stack trace here. (Edit Meta: A Gist is better because you can directly see it.)

It seems that in the __call__ method of the wrapper module ForwardModulePT, the line return rf.get_run_ctx().outputs.as_raw_tensor_dict() can't convert to raw tensor dict because the batch/time dimension tags don't have proper values, which is what happened in the past when we didn't have the function that filled the tensor dimensions with random values.

If rf.init_forward_step_run_ctx is to receive a TensorDict, it should at least receive the dimensions of extern_data that were already filled in the call to tensor_dict_fill_random_numpy.

Getting the dims that coincide in name won't work, as they are different objects. Only the batch dim is the same...

Do you think something like the following would work?

  1. Declare the numpy seed
  2. Fill model_outputs
  3. Declare the same numpy seed
  4. Fill extern_data

If dynamic dimensions can change throughout the forward call, as happens, for instance, in a convolutional layer with padding="valid", then this would probably not work...

If I understood the problem correctly, then the question would be: how can we get the output dim objects with the actual, proper values filled before we run the data through the model (i.e. at rf.init_run_ctx_forward() time)?
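The seeding idea from the steps above can be illustrated with a small sketch. It uses Python's random.Random as a stand-in for the numpy seed, and draw_seq_lens is a hypothetical helper. Re-seeding reproduces the same lengths only when the same draws happen in the same order, and even then a layer such as a conv with padding="valid" changes the lengths, so the reproduced input lengths are not the output lengths.

```python
import random
from typing import List


def draw_seq_lens(seed: int, batch_size: int, max_len: int) -> List[int]:
    """Hypothetical stand-in for filling a dynamic dim with random lengths."""
    rng = random.Random(seed)  # fresh RNG, like re-declaring the numpy seed
    return [rng.randint(1, max_len) for _ in range(batch_size)]


# Same seed and same sequence of draws: identical lengths.
in_lens = draw_seq_lens(seed=42, batch_size=3, max_len=20)
assert in_lens == draw_seq_lens(seed=42, batch_size=3, max_len=20)

# A layer can still change the lengths, e.g. a conv with padding="valid",
# kernel size 5 and stride 1: every length shrinks by 4 (clipped at 0).
kernel_size = 5
out_lens = [max(0, n - (kernel_size - 1)) for n in in_lens]
# The reproduced input lengths therefore differ from the output lengths.
print(in_lens, out_lens)
```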

Icemole, May 24 '23 16:05