burn
Support switching burn-wgpu between WebGPU implementations (wgpu <-> Dawn)
Checklist
- [ ] Confirmed that the `run-checks all` script has been executed.
- [ ] Made sure the book is up to date with changes in this PR.
Related Issues/PRs
Split from: https://github.com/tracel-ai/burn/pull/1475
Changes
Currently, Burn can only be built against the wgpu implementation of WebGPU. I found the ability to switch to Dawn useful, e.g. to get float16 support, which helps when running Llama2 on top of Burn.
I've:
- split out any uses of wgpu into a separate module, and defined traits which provide all operations needed by burn-wgpu,
- added Dawn as a submodule,
- modified the build system to be able to build, generate bindings for, and link against Dawn,
- added a module which wraps the usage of the Dawn bindings to make it look more like wgpu.
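The core idea above is that backend-agnostic code depends only on a set of traits, with one implementation backed by wgpu and another by a Dawn wrapper made to look like wgpu. The sketch below illustrates that pattern only. The trait `GpuApi`, the types `WgpuApi`, `DawnApi`, and `DawnBuffer`, and the function `round_trip` are hypothetical names, not the PR's actual API, and the "buffers" live in host memory rather than on a GPU.

```rust
/// Hypothetical trait listing the operations backend-agnostic code needs
/// from a WebGPU implementation. Names and signatures are illustrative only.
pub trait GpuApi {
    /// Each implementation chooses its own buffer handle type.
    type Buffer;

    fn name(&self) -> &'static str;
    fn create_buffer(&self, data: &[u8]) -> Self::Buffer;
    fn read_buffer(&self, buffer: &Self::Buffer) -> Vec<u8>;
}

/// Stand-in for a wgpu-backed implementation (host memory, no GPU calls).
pub struct WgpuApi;

/// Stand-in for a Dawn-backed implementation wrapped to look like wgpu.
pub struct DawnApi;

/// Hypothetical handle type owned by the Dawn wrapper.
pub struct DawnBuffer {
    bytes: Vec<u8>,
}

impl GpuApi for WgpuApi {
    type Buffer = Vec<u8>;

    fn name(&self) -> &'static str {
        "wgpu"
    }
    fn create_buffer(&self, data: &[u8]) -> Vec<u8> {
        data.to_vec()
    }
    fn read_buffer(&self, buffer: &Vec<u8>) -> Vec<u8> {
        buffer.clone()
    }
}

impl GpuApi for DawnApi {
    type Buffer = DawnBuffer;

    fn name(&self) -> &'static str {
        "dawn"
    }
    fn create_buffer(&self, data: &[u8]) -> DawnBuffer {
        DawnBuffer { bytes: data.to_vec() }
    }
    fn read_buffer(&self, buffer: &DawnBuffer) -> Vec<u8> {
        buffer.bytes.clone()
    }
}

/// Backend-agnostic code is generic over the trait, so switching the
/// implementation does not require touching it. A real crate could instead
/// pick a concrete type at build time, e.g. behind a `dawn` Cargo feature.
pub fn round_trip<A: GpuApi>(api: &A, data: &[u8]) -> Vec<u8> {
    let buffer = api.create_buffer(data);
    api.read_buffer(&buffer)
}
```

Using an associated `Buffer` type lets each backend keep its native handle type while callers stay generic; `round_trip(&WgpuApi, &[1, 2, 3])` and `round_trip(&DawnApi, &[1, 2, 3])` both return the same bytes.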
Testing
I ran unit tests for the `burn-wgpu` crate. Dawn test cases should be generated when the `dawn` feature is enabled for `burn-wgpu`. Some of them will fail because Dawn performs out-of-bounds accesses if no explicit bounds checks are emitted for I/O arrays (I've pushed a branch with a fix for that here: https://github.com/p1-0tr/burn/tree/ps-allow-using-dawn-and-wgpu-w-bounds).
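To make "explicit bounds checks for I/O arrays" concrete, here is a hypothetical sketch of a shader generator emitting guarded WGSL reads and writes. It is not the code from the linked branch; the helper names `emit_checked_read` and `emit_checked_write` are invented for illustration. It assumes each array is a runtime-sized storage array (so WGSL's `arrayLength` applies) and that the index expression is a `u32`.

```rust
/// Hypothetical helper: emits a WGSL statement that reads `array[index]`
/// into `out` only when `index` is in bounds, otherwise assigns `zero`.
/// Assumes `array` is a runtime-sized storage array and `index` is a u32.
pub fn emit_checked_read(array: &str, index: &str, out: &str, zero: &str) -> String {
    format!(
        "if ({index} < arrayLength(&{array})) {{ {out} = {array}[{index}]; }} else {{ {out} = {zero}; }}"
    )
}

/// Hypothetical helper: emits a WGSL store that is skipped entirely when
/// `index` is out of bounds.
pub fn emit_checked_write(array: &str, index: &str, value: &str) -> String {
    format!("if ({index} < arrayLength(&{array})) {{ {array}[{index}] = {value}; }}")
}
```

For example, `emit_checked_read("input", "i", "x", "0.0")` produces a guarded read of `input[i]` that falls back to `0.0`, so the generated shader never relies on the implementation's own handling of out-of-range accesses.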
Codecov Report
Attention: Patch coverage is 78.02198% with 80 lines in your changes missing coverage. Please review.
Project coverage is 86.40%. Comparing base (5bbc5ea) to head (ebd6963). Report is 67 commits behind head on main.
:exclamation: Current head ebd6963 differs from pull request most recent head 3212b1a
Please upload reports for the commit 3212b1a to get more accurate results.
| Files | Patch % | Lines |
|---|---|---|
| crates/burn-wgpu/src/compute/wgpu_api_shim.rs | 75.58% | 73 Missing :warning: |
| crates/burn-wgpu/src/runtime.rs | 68.18% | 7 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #1583 +/- ##
==========================================
- Coverage 86.61% 86.40% -0.21%
==========================================
Files 700 695 -5
Lines 83427 80618 -2809
==========================================
- Hits 72257 69656 -2601
+ Misses 11170 10962 -208
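As a consistency check on the figures above, project coverage is hits divided by tracked lines, and the patch figures imply how many changed lines were instrumented (plain arithmetic on the reported numbers, nothing beyond them):

```math
\begin{aligned}
\text{coverage}_{\text{main}} &= \frac{72257}{83427} \approx 86.61\% \\
\text{coverage}_{\text{head}} &= \frac{69656}{80618} \approx 86.40\% \\
\Delta &= 86.40\% - 86.61\% = -0.21\% \\
\text{patch lines} &= \frac{80}{1 - 0.7802198} \approx 364, \qquad \frac{364 - 80}{364} = \frac{284}{364} \approx 78.02\%
\end{aligned}
```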
@p1-0tr Just an update:
As you may have seen, we updated quite a lot in the wgpu server, which solved a lot of issues with different graphics APIs and improved performance a bit. We also changed how the element types are handled in the JIT backend, which paves the way for quantization and our new GPU Rust API. We still want to go further with this PR and eventually merge it. Things are a bit more stable now, so we can better review and integrate it.
Note: I don't think using a git submodule is optimal; I would prefer downloading a tagged version of Dawn to the cache directory using this method. However, this could be done in a follow-up PR.
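One way a build script could fetch a tagged Dawn into a cache directory instead of using a submodule is sketched below. This is a hypothetical illustration, not necessarily the method referenced in the comment: the helper names, the `<cache_root>/dawn/<tag>` layout, and the use of a shallow `git clone` are all assumptions, and the repository URL and tag are supplied by the caller.

```rust
use std::path::{Path, PathBuf};
use std::process::Command;

/// Hypothetical layout: <cache_root>/dawn/<tag>, so each tag gets its own checkout.
pub fn dawn_cache_dir(cache_root: &Path, tag: &str) -> PathBuf {
    cache_root.join("dawn").join(tag)
}

/// Only fetch when the tagged checkout is not already cached.
pub fn needs_fetch(dir: &Path) -> bool {
    !dir.exists()
}

/// Builds (but does not run) a shallow clone of `tag` from `repo_url` into `dest`.
pub fn dawn_fetch_command(repo_url: &str, tag: &str, dest: &Path) -> Command {
    let mut cmd = Command::new("git");
    cmd.args(["clone", "--depth", "1", "--branch", tag, repo_url])
        .arg(dest);
    cmd
}
```

Keying the cache directory on the tag means bumping the pinned Dawn version triggers a fresh fetch, while repeated builds reuse the existing checkout; a build script would call `.status()` on the returned command only when `needs_fetch` is true.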
@nathanielsimard thanks for the update. I noticed the changes during my last rebase :D
Thanks :) I'll look into that.
This PR has been marked as stale because it has not been updated for over a month
This belongs to cubecl now 😅, so I'll close this PR due to the massive changes in this area.