add associative_scan
Addresses #19904; adds `associative_scan` support for all backends.
This is my first time contributing, so I would greatly appreciate any feedback or pointers to anything I may have missed (and would be more than happy to apply them)! Thanks!
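For anyone who wants to try the op, here is a minimal usage sketch. It assumes the op is exposed as `keras.ops.associative_scan` with a signature mirroring `jax.lax.associative_scan` (an associative binary callable, the scanned elements, and an `axis` argument):

```python
import numpy as np
from keras import ops

# Cumulative sum expressed as an associative scan: the combine function
# (elementwise addition) is associative, which is all the op requires.
x = np.arange(1.0, 9.0, dtype="float32")

# Assumed signature: associative_scan(f, elems, reverse=False, axis=0),
# mirroring jax.lax.associative_scan.
result = ops.associative_scan(lambda a, b: a + b, x, axis=0)

print(result)        # [ 1.  3.  6. 10. 15. 21. 28. 36.]
print(np.cumsum(x))  # same values, computed sequentially for comparison
```

Because the combine function only needs to be associative, the scan can be evaluated in O(log n) parallel steps rather than sequentially, which is the main motivation for a dedicated op.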
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
Codecov Report
Attention: Patch coverage is 69.43005% with 59 lines in your changes missing coverage. Please review.
Project coverage is 72.52%. Comparing base (0c69b3b) to head (efbce61).
:exclamation: There is a different number of reports uploaded between BASE (0c69b3b) and HEAD (efbce61). Click for more details.
HEAD has 2 fewer uploads than BASE:

| Flag | BASE (0c69b3b) | HEAD (efbce61) |
|---|---|---|
| keras | 4 | 3 |
| keras-tensorflow | 1 | 0 |
Additional details and impacted files
Coverage diff between master and #19938:

| | master | #19938 | +/- |
|---|---|---|---|
| Coverage | 79.00% | 72.52% | -6.48% |
| Files | 499 | 499 | |
| Lines | 46531 | 46724 | +193 |
| Branches | 8561 | 8617 | +56 |
| Hits | 36761 | 33888 | -2873 |
| Misses | 8039 | 11106 | +3067 |
| Partials | 1731 | 1730 | -1 |
| Flag | Coverage Δ | |
|---|---|---|
| keras | 72.43% <69.43%> (-6.43%) | :arrow_down: |
| keras-jax | 62.26% <19.68%> (-0.18%) | :arrow_down: |
| keras-numpy | 57.29% <44.04%> (-0.06%) | :arrow_down: |
| keras-tensorflow | ? | |
| keras-torch | 62.32% <44.04%> (-0.08%) | :arrow_down: |
Flags with carried forward coverage won't be shown. Click here to find out more.
:umbrella: View full report in Codecov by Sentry.
It seems there is a device placement issue with torch on GPU, could you take a look? https://btx.cloud.google.com/invocations/0b42aedc-c253-401a-b110-1f441959ccc8/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Ftorch%2Fcontinuous/log
Thanks for the info! I made a PR with a fix: #19940
@SeKim12 thanks for this nice contribution!
Did you test using the PyTorch backend? Are you able to compile the associative scan operator? I will be testing this myself on an isolated project, but wanted to get your feedback in case you already ran some benchmarks. Thanks!
Hi @carlosluis! I was able to compile it and use it for the PyTorch backend. Here is a script that I used to quickly benchmark performance: https://gist.github.com/SeKim12/0b5a77fbb05c707e60dcee03cfd7c24b
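For reference, a rough sketch of that kind of benchmark is below (the gist has the actual script; the shape, device handling, and combine function here are illustrative assumptions):

```python
import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before Keras is imported

import time
import torch
from keras import ops

def cumsum_scan(x):
    # Addition is associative, so this is equivalent to a cumulative sum.
    return ops.associative_scan(lambda a, b: a + b, x, axis=0)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1_000_000, device=device)

def timed(fn):
    # Synchronize around the call so GPU kernels are included in the timing.
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(x)
    if device == "cuda":
        torch.cuda.synchronize()
    return out, time.perf_counter() - start

_, eager_s = timed(cumsum_scan)

compiled_fn = torch.compile(cumsum_scan)
compiled_fn(x)  # warm-up call pays the one-time compilation cost
_, compiled_s = timed(compiled_fn)

print(f"eager: {eager_s:.4f}s  compiled: {compiled_s:.4f}s")
```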
Hi @SeKim12!
This PR imports optree directly (example), which breaks Google Colab usage. We recommend using `from keras.src import tree` instead! Could you refactor the parts of your code that use optree to use Keras' tree instead?
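For illustration, the change we are asking for looks roughly like this; the `optree` calls below are an assumed pattern rather than the exact call sites in the diff:

```python
# Before (assumed pattern): calling optree directly, which breaks in some
# environments such as Google Colab (per the note above).
#   import optree
#   leaves = optree.tree_leaves(structure)
#   doubled = optree.tree_map(lambda x: x * 2, structure)

# After: route through Keras' tree utilities, which handle nested structures
# without requiring a direct optree import.
from keras.src import tree

structure = {"a": [1, 2], "b": (3, 4)}

leaves = tree.flatten(structure)                          # [1, 2, 3, 4]
doubled = tree.map_structure(lambda x: x * 2, structure)  # {"a": [2, 4], "b": (6, 8)}
print(leaves, doubled)
```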
cc @hertschuh @fchollet
Hi @SamanehSaadat,
My apologies -- I just saw this notification! It seems like the optree imports were handled by @fchollet. Thank you! Please let me know if there is anything else I can do!