Fix ONNX Gather codegen for Shape input

Open hexd0t opened this issue 1 year ago • 1 comments

Pull Request Template

Checklist

  • [x] Confirmed that the run-checks all script has been executed.
  • [x] Made sure the book is up to date with changes in this PR.

Related Issues/PRs

This finishes some work still left after #2128 for ONNX Gather nodes to accept Shape inputs.

Without it, codegen fails with the panic below. Note that the error message is misleading: the Arg here is not a scalar but a Shape(2). The message predates Shapes being used in ONNX imports.

  ERROR burn_import::logger: PANIC => panicked at burn\crates\burn-import\src\onnx\to_burn.rs:1245:18:
  Can't transform scalar to tensor.

  --- stderr
  thread 'main' panicked at burn\crates\burn-import\src\onnx\to_burn.rs:1245:18:
  Can't transform scalar to tensor.
  stack backtrace:
     0: std::panicking::begin_panic_handler
               at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library\std\src\panicking.rs:652
     1: core::panicking::panic_fmt
               at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library\core\src\panicking.rs:72
     2: burn_import::onnx::to_burn::<impl core::convert::From<&onnx_ir::ir::Argument> for burn_import::burn::ty::TensorType>::from
     3: burn_import::onnx::to_burn::ParsedOnnxGraph::gather_conversion
     4: burn_import::onnx::to_burn::ParsedOnnxGraph::into_burn
     5: burn_import::onnx::to_burn::ModelGen::run_from_script
     6: burn_import::onnx::to_burn::ModelGen::run_from_script
     7: burn_import::onnx::to_burn::ModelGen::run_from_script
     8: build_script_build::burn::import_models
     9: build_script_build::burn::import_models
    10: build_script_build::burn::import_models
    11: core::ops::function::FnOnce::call_once
  note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.

It will probably also be needed for the model in issue #2115 to import successfully.

Changes

Check if the input value is a Shape instead of a Tensor, and if so, convert the Shape to a tensor before doing the select operation.
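The branch described above can be sketched in plain Rust. This is a standalone model, not the code generated by burn-import: the `GatherInput` enum and the `gather` function are hypothetical names, and a rank-1 tensor is modeled as a `Vec<i64>` so that the example runs without any crate.

```rust
/// Hypothetical stand-in for the two kinds of Gather input discussed here.
#[derive(Debug, Clone, PartialEq)]
enum GatherInput {
    /// A rank-1 tensor, modeled as a plain vector for this sketch.
    Tensor(Vec<i64>),
    /// A shape value held on the CPU, one entry per dimension.
    Shape(Vec<usize>),
}

/// Bring the input into tensor form, then select the requested indices.
/// Panics if an index is out of bounds, like ordinary slice indexing.
fn gather(input: GatherInput, indices: &[usize]) -> Vec<i64> {
    let tensor: Vec<i64> = match input {
        // Already a tensor: use it directly.
        GatherInput::Tensor(values) => values,
        // A Shape: convert it to a tensor first instead of failing.
        GatherInput::Shape(dims) => dims.into_iter().map(|d| d as i64).collect(),
    };
    // The select step: pick the entries named by `indices`.
    indices.iter().map(|&i| tensor[i]).collect()
}
```

For example, gathering indices `[0, 2]` from a Shape of `[2, 3, 4]` yields `[2, 4]`, the same result as gathering from the equivalent tensor.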

Testing

Added an ONNX test for this case, as well as a codegen unit test.

hexd0t avatar Aug 09 '24 18:08 hexd0t

Codecov Report

Attention: Patch coverage is 94.25287% with 5 lines in your changes missing coverage. Please review.

Project coverage is 86.11%. Comparing base (ff8d030) to head (da2c934). Report is 2 commits behind head on main.

Files                                             Patch %   Lines
crates/burn-import/src/burn/node/gather.rs        96.96%    2 Missing :warning:
crates/burn-import/src/onnx/op_configuration.rs   60.00%    2 Missing :warning:
crates/burn-import/src/onnx/to_burn.rs            50.00%    1 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2148      +/-   ##
==========================================
- Coverage   86.21%   86.11%   -0.10%     
==========================================
  Files         694      694              
  Lines       88854    88822      -32     
==========================================
- Hits        76606    76490     -116     
- Misses      12248    12332      +84     

codecov[bot] avatar Aug 09 '24 19:08 codecov[bot]

Just a minor comment for now.

For the newly introduced shape type, this LGTM.

I'd like to get @antimora's feedback though. I know you wanted to avoid having simple shape operations done on the GPU and then retrieved to the CPU, but with this typical sequence of tensor -> shape -> gather ops we have the shape results on the CPU and the indices for gather as a tensor. Hence we have to either convert the shape back to a tensor (as per this PR) or get the indices as a vec... haven't we just come full circle to your initial concern?

We spoke offline. It's true we want to use Shape as much as possible (only when dealing with shapes), but in this instance the conversion is unavoidable, since select needs a multidimensional tensor and the shape information comes from the CPU (a struct).

antimora avatar Aug 12 '24 15:08 antimora

It's true we want to use Shape as much as possible

I haven't implemented this here, but at least for the case where a Scalar index is selected from a Shape, a special path could be added that outputs the size as a Scalar, or as another Shape with just the one size, completely on the CPU. At least in the (admittedly few) ONNX models I've seen where Shape and Gather are combined, the index was usually a Scalar anyway.
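A minimal sketch of what such a CPU-only path might look like, assuming the shape is available as a slice of dimension sizes. The function name `gather_scalar_from_shape` is hypothetical and does not exist in burn-import. Negative indices are resolved from the end, since ONNX Gather accepts indices in the range [-s, s-1] for an axis of size s.

```rust
/// Hypothetical helper: read one dimension size from a CPU-side shape
/// without building a tensor. Returns `None` for an out-of-range index.
fn gather_scalar_from_shape(shape: &[usize], index: i64) -> Option<usize> {
    let rank = shape.len() as i64;
    // Negative indices count from the end of the shape.
    let resolved = if index < 0 { index + rank } else { index };
    if resolved < 0 || resolved >= rank {
        return None;
    }
    Some(shape[resolved as usize])
}
```

With a shape of `[2, 3, 4]`, index `1` gives `Some(3)` and index `-1` gives `Some(4)`. Returning an `Option` rather than panicking is a choice made for this sketch only.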

Of course, adding this will probably expose even more Ops that cannot handle Shape inputs correctly yet, so implementing something like that would likely introduce regressions in terms of model compatibility at first.

hexd0t avatar Aug 14 '24 14:08 hexd0t