Fix ONNX Gather codegen for Shape input
Pull Request Template
Checklist
- [x] Confirmed that `run-checks all` script has been executed.
- [x] Made sure the book is up to date with changes in this PR.
Related Issues/PRs
This finishes the work left over from #2128 so that ONNX Gather nodes accept Shape inputs.
Without it, the actual codegen fails. Note that the error message is misleading: the Arg here is not a scalar but a Shape(2), and the message predates Shapes being used in ONNX imports:
ERROR burn_import::logger: PANIC => panicked at burn\crates\burn-import\src\onnx\to_burn.rs:1245:18:
Can't transform scalar to tensor.
--- stderr
thread 'main' panicked at burn\crates\burn-import\src\onnx\to_burn.rs:1245:18:
Can't transform scalar to tensor.
stack backtrace:
0: std::panicking::begin_panic_handler
at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library\std\src\panicking.rs:652
1: core::panicking::panic_fmt
at /rustc/051478957371ee0084a7c0913941d2a8c4757bb9/library\core\src\panicking.rs:72
2: burn_import::onnx::to_burn::<impl core::convert::From<&onnx_ir::ir::Argument> for burn_import::burn::ty::TensorType>::from
3: burn_import::onnx::to_burn::ParsedOnnxGraph::gather_conversion
4: burn_import::onnx::to_burn::ParsedOnnxGraph::into_burn
5: burn_import::onnx::to_burn::ModelGen::run_from_script
6: burn_import::onnx::to_burn::ModelGen::run_from_script
7: burn_import::onnx::to_burn::ModelGen::run_from_script
8: build_script_build::burn::import_models
9: build_script_build::burn::import_models
10: build_script_build::burn::import_models
11: core::ops::function::FnOnce::call_once
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
It will probably also be needed for the model in issue #2115 to import successfully.
Changes
Check if the input value is a Shape instead of a Tensor, and if so, convert the Shape to a tensor before doing the select operation.
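The change described above can be sketched as a self-contained model in plain Rust. This is not the code burn-import generates: `GatherInput`, `select`, and `gather` are hypothetical stand-ins, tensors are modelled as plain vectors, and only a rank-1 gather along axis 0 is covered.

```rust
/// Hypothetical stand-in for the two kinds of Gather input discussed in this PR.
/// These are illustrative types, not burn-import's real ones.
#[derive(Debug, Clone, PartialEq)]
pub enum GatherInput {
    /// A rank-1 integer tensor, modelled here as a plain vector.
    Tensor(Vec<i64>),
    /// A shape: a list of dimension sizes that lives on the CPU.
    Shape(Vec<usize>),
}

/// Stand-in for a rank-1 `select`: picks `data[i]` for every `i` in `indices`.
/// Panics if an index is out of range.
fn select(data: &[i64], indices: &[usize]) -> Vec<i64> {
    indices.iter().map(|&i| data[i]).collect()
}

/// Gather along axis 0. A Shape input is converted to a tensor first,
/// so both input kinds go through the same `select` path afterwards.
pub fn gather(input: &GatherInput, indices: &[usize]) -> Vec<i64> {
    let tensor: Vec<i64> = match input {
        GatherInput::Tensor(values) => values.clone(),
        GatherInput::Shape(dims) => dims.iter().map(|&d| d as i64).collect(),
    };
    select(&tensor, indices)
}
```

With this model, gathering indices `[0, 2]` from a shape of `[2, 3, 4]` yields the first and last dimension sizes, exactly as it would for a tensor holding the same values.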
Testing
Added an ONNX test for this case, as well as a codegen unit test.
Codecov Report
Attention: Patch coverage is 94.25287% with 5 lines in your changes missing coverage. Please review.
Project coverage is 86.11%. Comparing base (ff8d030) to head (da2c934). Report is 2 commits behind head on main.
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2148      +/-   ##
==========================================
- Coverage   86.21%   86.11%   -0.10%
==========================================
  Files         694      694
  Lines       88854    88822      -32
==========================================
- Hits        76606    76490     -116
- Misses      12248    12332      +84
Just a minor comment for now.
For the new shape type introduced this LGTM.
I'd like to get @antimora's feedback though. I know you wanted to avoid simple shape operations being done on the GPU and read back to the CPU, but with this typical sequence of tensor -> shape -> gather ops we have the shape results on the CPU and the indices for gather as a tensor. Hence we have to either convert the shape back to a tensor (as per this PR) or get the indices as a vec... haven't we just come full circle to your initial concern?
We spoke offline. It's true we want to use Shape as much as possible (only when dealing with shapes), but in this instance the conversion is unavoidable, since select needs a multidimensional tensor and the shape information comes from the CPU (a struct).
> It's true we want to use Shape as much as possible
I haven't implemented this here, but at least for the case where a Scalar index is selected from a Shape, a special path could be added that outputs the size as a Scalar, or as another Shape holding just that one size, entirely on the CPU. At least in the (admittedly few) ONNX models I've seen where Shape and Gather are combined, the index was usually a Scalar anyway.
Of course, adding this will probably expose even more Ops that cannot handle Shape inputs correctly yet, so implementing something like that would likely introduce regressions in terms of model compatibility at first.
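For illustration, the CPU-only special path for a Scalar index could look roughly like the sketch below. The function name `gather_dim` is hypothetical and nothing like it exists in this PR. The sketch assumes the shape is available as a slice of dimension sizes and follows the ONNX Gather convention that negative indices count from the end.

```rust
/// Hypothetical CPU-only path: gather a single scalar index from a shape.
/// Returns `None` when the index falls outside `[-rank, rank - 1]`.
pub fn gather_dim(shape: &[usize], index: i64) -> Option<usize> {
    let rank = shape.len() as i64;
    // Negative indices count from the end of the shape.
    let resolved = if index < 0 { index + rank } else { index };
    if resolved < 0 || resolved >= rank {
        return None;
    }
    Some(shape[resolved as usize])
}
```

Because the result is a plain `usize`, no tensor is created and nothing has to be moved between devices. Any downstream op would then need to accept a Scalar (or a one-element Shape) where it previously received a tensor.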