Add cumulative sum tensor operation
Pull Request Template
Starting a draft PR to align on a few things with maintainers as I dive into this.
Context: Per this conversation, I wanted to add a cumulative product operation to burn.
My plan is to start with a cumulative sum operation. Then cumulative product can be developed using cumulative sum, log, and exp.
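The decomposition above rests on the identity `cumprod(x) == exp(cumsum(ln(x)))` for strictly positive inputs. Here is a minimal sketch of that identity over plain `Vec<f64>` values — not the burn tensor API, and note the log/exp route needs separate handling for zeros and negative values and loses some floating-point precision:

```rust
// Scalar sketch of the proposed decomposition: cumprod via cumsum + ln + exp.
// Only valid directly for strictly positive inputs (assumption for this sketch).
fn cumsum(xs: &[f64]) -> Vec<f64> {
    xs.iter()
        .scan(0.0, |acc, x| {
            *acc += x;
            Some(*acc)
        })
        .collect()
}

fn cumprod_via_cumsum(xs: &[f64]) -> Vec<f64> {
    let logs: Vec<f64> = xs.iter().map(|x| x.ln()).collect();
    cumsum(&logs).into_iter().map(f64::exp).collect()
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];
    let expected = [1.0, 2.0, 6.0, 24.0]; // direct cumulative product
    for (a, b) in cumprod_via_cumsum(&x).iter().zip(expected.iter()) {
        assert!((a - b).abs() < 1e-9);
    }
    println!("cumprod via cumsum matches direct cumprod");
}
```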
@nathanielsimard, Items to align on upfront:
- The name of the function should be `cumsum_dim`. `cumsum` aligns with the PyTorch API, but in burn, operations that take an explicit `dim` argument seem to have a `_dim` suffix. Alternatively, we could remove the suffix.
- The function should be implemented for the Float and Int tensor kinds, but not Bool.
- Backends: While the implementations for tch, candle, and ndarray seem straightforward, I have questions about jit. For jit, I cannot find an existing WGSL implementation for cumulative sum. Is the right approach in this situation to create a WGSL compute shader for it? Cumulative sum is probably hard to do well on GPUs given the dependencies between elements, but I'm open to trying. I'm new to WGSL.
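For reference, here is the semantics I have in mind, mirroring `torch.cumsum` on a row-major 2-D tensor. This is a standalone sketch over plain slices, and the `cumsum_dim` name follows the naming question above rather than any finalized burn API:

```rust
// Hypothetical reference implementation of cumulative sum along one
// dimension of a row-major 2-D tensor (semantics matching torch.cumsum).
fn cumsum_dim(data: &[i64], shape: [usize; 2], dim: usize) -> Vec<i64> {
    let [rows, cols] = shape;
    let mut out = data.to_vec();
    match dim {
        // Accumulate down each column.
        0 => {
            for r in 1..rows {
                for c in 0..cols {
                    out[r * cols + c] += out[(r - 1) * cols + c];
                }
            }
        }
        // Accumulate across each row.
        1 => {
            for r in 0..rows {
                for c in 1..cols {
                    out[r * cols + c] += out[r * cols + c - 1];
                }
            }
        }
        _ => panic!("dim out of range"),
    }
    out
}

fn main() {
    let x = [1, 2, 3, 4, 5, 6]; // shape [2, 3]
    assert_eq!(cumsum_dim(&x, [2, 3], 1), vec![1, 3, 6, 4, 9, 15]);
    assert_eq!(cumsum_dim(&x, [2, 3], 0), vec![1, 2, 3, 5, 7, 9]);
    println!("cumsum_dim reference semantics verified");
}
```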
Checklist
- [ ] Confirmed that the `run-checks all` script has been executed.
- [ ] Made sure the book is up to date with changes in this PR.
Related Issues/PRs
Provide links to relevant issues and dependent PRs.
Changes
Summarize the problem being addressed and your solution.
Testing
Describe how these changes have been tested.
Hi @allenqm. Although you tagged Nathaniel, I think I can answer in his stead:
- It's great that you made a default implementation for cumprod using log and exp 👍
- For the naming, I think we can leave the `_dim` out because there cannot be a version for all dims at once. That's what we did, for instance, with sort, which acts on a dimension but for which no "global sort" exists.
- Indeed, it must be implemented for int and float only. I saw you put it in `numeric`, which is the way to go 👍
- For JIT, it is going to be a pain at the moment. There is no hand-written WGSL code anymore; the WGSL is always auto-generated from the intermediate JIT representation (kernels using the `gpu!` macro). It's honestly a pain to work with, and it's not designed for new contributors to learn. I'm working on a language to rewrite the kernels in an accessible way (see #1665), but it's not ready.
- For the GPU algorithm, you're right that the dependency between elements will make it difficult to write an efficient kernel. The straightforward way would be to spawn one thread per slice along the summed dim; that thread fills all the output spots while accumulating the inputs in a local sum variable. But for a large summed dim, this can be slow. I'm not sure if there are better solutions; I haven't done any research.
I'm willing to write that kernel in the JIT intermediate representation if you want, so the operation becomes available soon; then we can optimize it later and with the upcoming language.
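The one-thread-per-slice approach described above can be simulated on the CPU with scoped threads: one worker per row, each sequentially accumulating into its own disjoint output slice. A hypothetical sketch of the parallelization structure, not the actual JIT kernel:

```rust
// CPU simulation of the proposed GPU kernel: one "thread" per slice along
// the scanned dim; each thread walks its slice sequentially with a running
// sum. Rows are independent, so the threads never share mutable state.
fn cumsum_rows_parallel(input: &[f32], cols: usize) -> Vec<f32> {
    let mut output = vec![0.0_f32; input.len()];
    std::thread::scope(|s| {
        for (in_row, out_row) in input.chunks(cols).zip(output.chunks_mut(cols)) {
            s.spawn(move || {
                let mut acc = 0.0;
                for (x, y) in in_row.iter().zip(out_row.iter_mut()) {
                    acc += x;
                    *y = acc;
                }
            });
        }
    });
    output
}

fn main() {
    // 2 x 4 matrix, cumulative sum along dim 1 (within each row).
    let input = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
    let out = cumsum_rows_parallel(&input, 4);
    assert_eq!(out, vec![1.0, 3.0, 6.0, 10.0, 10.0, 30.0, 60.0, 100.0]);
    println!("{:?}", out);
}
```

As noted above, this is the slow-but-simple variant: each slice is scanned serially, so a very long summed dim gets no parallelism; work-efficient parallel scans exist but add complexity.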
@louisfd Thanks so much for the guidance.
I will remove the `_dim` suffix.
Thanks for offering to step in and write the kernel in the JIT intermediate representation. I'll take you up on that.
I'm going to try to get the tch, candle, ndarray, and autodiff implementations done by EoD tomorrow.
Just to be clear: I haven't written anything specific for cumprod yet. I was proposing that if we implement cumsum, then cumprod becomes more straightforward, as it could be implemented without new backend implementations (with the exception of autodiff), using the existing implementations of cumsum, exp, and log. Let me know if my assessment here seems off.
tch, candle, ndarray, autodiff + tests, and tensor tests have been added. Going to work on the onnx section of the contributor book next.
no action needed, just fyi @louisfd
This PR has been marked as stale because it has not been updated for over a month
Sorry for not flipping this to "Ready for Review" @louisfd. I think I've got the required onnx files in place. Can you take a look?
Labeled this PR as needing help to complete it. It would be a great feature to have in Burn.
Sorry for the delay. I am planning to devote time this Friday to make the requested changes.