ONNX op TopK
Pull Request Template
Checklist
- [x] Confirmed that the `run-checks all` script has been executed.
  - Some checks failed, I think because of my machine. For example, `test_rotary_encoding_forward` failed but passed when I ran it via the UI.
- [x] Made sure the book is up to date with changes in this PR.
Related Issues/PRs
#1714
https://github.com/tracel-ai/burn/issues/1714
Changes
- Added support for the ONNX TopK op. It relies on version 1 of the TopK node; other versions take K as an input instead of an attribute (ref).
- Made `fn tanh_should_not_have_numerical_bugs_on_macos()` run only on macOS.
- Updated `IOEntry::Node` in the `input_names_map`. This was important because it was previously never incremented by the number of outputs, so the output name suffix was always `_1`, even when a node had two or more outputs.
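To make the op's semantics concrete for reviewers, here is a minimal plain-Rust sketch of what TopK version 1 computes on a 1-D input. In that version `k` is a fixed attribute, and the op returns the `k` largest values in descending order together with their original indices. This is not burn's implementation. `top_k` is a hypothetical helper, and the sketch makes two simplifying assumptions: it leaves out the `axis` attribute (default `-1`), and it does not handle NaN.

```rust
/// Hypothetical sketch of ONNX TopK (version 1) on a 1-D slice.
/// `k` is a plain attribute here, not a runtime tensor input.
/// Returns the `k` largest values (descending) and their int64 indices.
fn top_k(data: &[f32], k: usize) -> (Vec<f32>, Vec<i64>) {
    assert!(k <= data.len(), "k must not exceed the axis length");
    // Pair each value with its original position.
    let mut indexed: Vec<(usize, f32)> = data.iter().copied().enumerate().collect();
    // Sort descending by value. `sort_by` is stable, so equal values
    // keep their original order (lower index first).
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("NaN is not handled in this sketch"));
    indexed.truncate(k);
    let values: Vec<f32> = indexed.iter().map(|&(_, v)| v).collect();
    let indices: Vec<i64> = indexed.iter().map(|&(i, _)| i as i64).collect();
    (values, indices)
}

fn main() {
    let (values, indices) = top_k(&[1.0, 4.0, 2.0, 4.0, 3.0], 3);
    println!("values = {:?}, indices = {:?}", values, indices);
}
```

Version 1 treats `k` as an attribute, so the sketch takes it as a plain `usize` rather than as a tensor argument.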
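A small counter can illustrate the output-naming fix. This is a hypothetical sketch and not burn's actual `IOEntry::Node` or `input_names_map` code. Both `name_outputs` and the `{node}_{n}` naming scheme are assumptions made for illustration. The key point is that the per-node counter advances once per output. A node with two outputs therefore gets `_1` and `_2`, rather than `_1` twice.

```rust
use std::collections::HashMap;

/// Hypothetical per-node bookkeeping: maps a node name to how many of its
/// outputs have been named so far (an illustration, not burn's `IOEntry`).
fn name_outputs(counters: &mut HashMap<String, usize>, node: &str, num_outputs: usize) -> Vec<String> {
    let mut names = Vec::with_capacity(num_outputs);
    for _ in 0..num_outputs {
        // Increment once per output so each output gets a distinct suffix.
        let count = counters.entry(node.to_string()).or_insert(0);
        *count += 1;
        names.push(format!("{node}_{count}"));
    }
    names
}

fn main() {
    let mut counters = HashMap::new();
    let names = name_outputs(&mut counters, "topk1", 2);
    println!("{names:?}");
}
```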
Testing
Ran the following:
- `cargo test`
- `./run-checks.sh all`