Fix ONNX export of RAFT optical flow model
🚀 The feature
Two changes are proposed to fix ONNX export of the RAFT model with dynamic `batch_size` and dynamic `num_flow_updates`.

Change needed for dynamic `batch_size`, in `CorrBlock._compute_corr_volume`: change

```python
corr / torch.sqrt(torch.tensor(num_channels))
```

to

```python
corr / torch.sqrt(torch.tensor(num_channels).float())
```
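For context, a minimal sketch of the dtype behind this cast (the value 256 is made up for illustration; RAFT takes `num_channels` from the feature map shape):

```python
import torch

num_channels = 256  # illustrative value
print(torch.tensor(num_channels).dtype)          # torch.int64
print(torch.tensor(num_channels).float().dtype)  # torch.float32

# In eager mode torch.sqrt promotes the int64 tensor to float automatically,
# so both variants compute the same value; the explicit cast presumably just
# keeps the Sqrt in the traced/exported graph on a float input.
print(torch.sqrt(torch.tensor(num_channels)))          # tensor(16.)
print(torch.sqrt(torch.tensor(num_channels).float()))  # tensor(16.)
```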
Change needed for dynamic `num_flow_updates`, in `RAFT.forward`: change

```python
flow_predictions = []
for _ in range(num_flow_updates):
    coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
    corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)
    flow = coords1 - coords0
    hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)
    coords1 = coords1 + delta_flow
    up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
    upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
    flow_predictions.append(upsampled_flow)
```
to
```python
flow_predictions = torch.zeros((num_flow_updates, batch_size, 2, h, w))
for i in range(num_flow_updates):
    coords1 = coords1.detach()  # Don't backpropagate gradients through this branch, see paper
    corr_features = self.corr_block.index_pyramid(centroids_coords=coords1)
    flow = coords1 - coords0
    hidden_state, delta_flow = self.update_block(hidden_state, context, corr_features, flow)
    coords1 = coords1 + delta_flow
    up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
    upsampled_flow = upsample_flow(flow=(coords1 - coords0), up_mask=up_mask)
    flow_predictions[i] = upsampled_flow
```
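For reference, a hedged sketch of the kind of export call these changes are meant to enable. The input/output names, the `dynamic_axes` mapping, and the opset version are illustrative choices (not from torchvision), and the single stacked `flow_predictions` output assumes the preallocated-tensor version above is applied:

```python
import torch
from torchvision.models.optical_flow import raft_small

model = raft_small().eval()
img1 = torch.rand(1, 3, 128, 128)  # RAFT expects H and W divisible by 8
img2 = torch.rand(1, 3, 128, 128)

torch.onnx.export(
    model,
    (img1, img2),
    "raft.onnx",
    input_names=["image1", "image2"],
    output_names=["flow_predictions"],
    dynamic_axes={
        "image1": {0: "batch"},
        "image2": {0: "batch"},
        # With the preallocated tensor above, the output is stacked as
        # (num_flow_updates, batch, 2, H, W).
        "flow_predictions": {0: "num_flow_updates", 1: "batch"},
    },
    opset_version=16,
)
```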
Thanks for all your work :)
Motivation, pitch
Being able to use the model in onnxruntime with dynamic inputs.
Alternatives
No response
Additional context
No response
Thanks for the report/request @farresti. By any chance, is there a more pythonic way to still enable ONNX export here? We try to be reasonably conservative about changes like this that make the code less readable, as we already have to balance a lot of targets in our code (torchscript + mypy + onnx + torch.compile + [whatever comes next]), all of which require ugly workarounds that quickly add up and lead to a massive maintenance burden.
Could you specify which part you would like to be more pythonic, please @NicolasHug? Is it the `num_flow_updates` part, where the append is changed to a preallocation of the tensor? Unfortunately, without this preallocation the ONNX tracer does not register the `num_flow_updates` parameter, so it does not end up as an input of the exported graph and stays constant. I tried keeping the first implementation as it was (with append) and stacking the list into a tensor afterward, but the tracer does not see the link with the parameter and still exports it as a constant.
I based my proposal on this Stack Overflow answer: https://stackoverflow.com/a/76134353
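To illustrate the difference in isolation (a toy sketch, not the RAFT code; the claim about what the exporter keeps dynamic is the behavior reported above and in that answer, not something this snippet proves):

```python
import torch

def with_append(x: torch.Tensor, n: int) -> torch.Tensor:
    # Pattern the original RAFT code uses: list append + stacking afterward.
    outs = []
    for _ in range(n):
        x = x * 2.0
        outs.append(x)
    return torch.stack(outs)

def with_prealloc(x: torch.Tensor, n: int) -> torch.Tensor:
    # Pattern proposed above: preallocate the output, then write by index.
    outs = torch.zeros((n, *x.shape))
    for i in range(n):
        x = x * 2.0
        outs[i] = x
    return outs
```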
> Could you specify which part you would like to be more pythonic, please @NicolasHug?
I was referring to
```python
flow_predictions = torch.zeros((num_flow_updates, batch_size, 2, h, w))
flow_predictions[i] = upsampled_flow
```
But if there's no decent alternative then that's OK, thanks for trying.
Would you like to submit a PR with those changes? Ideally with a short non-regression test, so we can be sure not to inadvertently "clean" that part later.
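For what it's worth, a hedged sketch of what such a non-regression test could look like. The test name and axis names are hypothetical (not from torchvision's test suite), and the single stacked output assumes the changes proposed above are applied:

```python
import io

import onnxruntime as ort
import torch
from torchvision.models.optical_flow import raft_small


def test_raft_onnx_export_dynamic_batch():
    model = raft_small().eval()
    img1, img2 = torch.rand(1, 3, 128, 128), torch.rand(1, 3, 128, 128)

    buf = io.BytesIO()
    torch.onnx.export(
        model,
        (img1, img2),
        buf,
        input_names=["image1", "image2"],
        output_names=["flow_predictions"],
        dynamic_axes={
            "image1": {0: "batch"},
            "image2": {0: "batch"},
            "flow_predictions": {1: "batch"},
        },
        opset_version=16,
    )

    sess = ort.InferenceSession(buf.getvalue(), providers=["CPUExecutionProvider"])
    # Run with a batch size different from the one used at export time.
    img1, img2 = torch.rand(2, 3, 128, 128), torch.rand(2, 3, 128, 128)
    (flows,) = sess.run(None, {"image1": img1.numpy(), "image2": img2.numpy()})
    # Output is (num_flow_updates, batch, 2, H, W); check the batch dimension
    # followed the input rather than staying baked in at 1.
    assert flows.shape[1] == 2
```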