add associative_scan

Open SeKim12 opened this issue 1 year ago • 2 comments

Addresses #19904; adds `associative_scan` support for all backends.

This is my first time contributing, so I would greatly appreciate any feedback or pointers to anything I may have missed (I would be more than happy to apply them)! Thanks!

SeKim12, Jun 30 '24 01:06
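For context on what the new op computes: an associative scan returns every prefix combination of a sequence under an associative binary function `fn`, i.e. a cumulative sum generalized to any associative `fn`. The pure-Python sketch below is not the Keras implementation, and `sequential_scan` and `parallel_scan` are hypothetical names. It contrasts a left-to-right reference loop with a recursive odd/even formulation, a standard parallel-prefix strategy in which the combines within each step are independent of one another. Associativity (not commutativity) is what makes the regrouping valid, so an order-sensitive `fn` such as string concatenation still works.

```python
from itertools import accumulate
from operator import add


def sequential_scan(fn, elems):
    """Inclusive scan: out[i] = fn(...fn(fn(elems[0], elems[1]), elems[2])..., elems[i])."""
    out = []
    for x in elems:
        # Combine the running prefix with the next element, left to right.
        out.append(x if not out else fn(out[-1], x))
    return out


def parallel_scan(fn, elems):
    """Recursive odd/even inclusive scan; correct only if fn is associative."""
    n = len(elems)
    if n < 2:
        return list(elems)
    # 1) Combine adjacent pairs; the pairs are independent, so this step is parallel.
    pairs = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    # 2) Recursively scan the half-length sequence of pair results.
    scanned_pairs = parallel_scan(fn, pairs)
    # 3) Odd positions (1, 3, 5, ...) are exactly the scanned pair prefixes.
    out = [None] * n
    out[0] = elems[0]
    for j, v in enumerate(scanned_pairs):
        out[2 * j + 1] = v
    # 4) Even positions > 0 combine the preceding odd prefix with their own element.
    for i in range(2, n, 2):
        out[i] = fn(out[i - 1], elems[i])
    return out


data = [3, 1, 4, 1, 5, 9, 2, 6]
print(sequential_scan(add, data))                      # [3, 4, 8, 9, 14, 23, 25, 31]
print(parallel_scan(add, data) == list(accumulate(data)))  # True
print(parallel_scan(add, list("abcde")))               # ['a', 'ab', 'abc', 'abcd', 'abcde']
```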

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot], Jun 30 '24 01:06

Codecov Report

Attention: Patch coverage is 69.43005% with 59 lines in your changes missing coverage. Please review.

Project coverage is 72.52%. Comparing base (0c69b3b) to head (efbce61).

| Files | Patch % | Lines |
|---|---|---|
| keras/src/backend/tensorflow/core.py | 7.14% | 52 Missing ⚠️ |
| keras/src/backend/common/backend_utils.py | 60.00% | 1 Missing and 1 partial ⚠️ |
| keras/src/backend/numpy/core.py | 96.00% | 1 Missing and 1 partial ⚠️ |
| keras/src/backend/torch/core.py | 96.29% | 1 Missing and 1 partial ⚠️ |
| keras/api/_tf_keras/keras/ops/__init__.py | 0.00% | 1 Missing ⚠️ |

❗ A different number of reports was uploaded for BASE (0c69b3b) and HEAD (efbce61).

HEAD has 2 fewer uploads than BASE:

| Flag | BASE (0c69b3b) | HEAD (efbce61) |
|---|---|---|
| keras | 4 | 3 |
| keras-tensorflow | 1 | 0 |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19938      +/-   ##
==========================================
- Coverage   79.00%   72.52%   -6.48%     
==========================================
  Files         499      499              
  Lines       46531    46724     +193     
  Branches     8561     8617      +56     
==========================================
- Hits        36761    33888    -2873     
- Misses       8039    11106    +3067     
+ Partials     1731     1730       -1     

| Flag | Coverage Δ |
|---|---|
| keras | `72.43% <69.43%> (-6.43%)` ⬇️ |
| keras-jax | `62.26% <19.68%> (-0.18%)` ⬇️ |
| keras-numpy | `57.29% <44.04%> (-0.06%)` ⬇️ |
| keras-tensorflow | `?` |
| keras-torch | `62.32% <44.04%> (-0.08%)` ⬇️ |

Flags with carried-forward coverage won't be shown.

codecov-commenter, Jun 30 '24 01:06

It seems there is a device placement issue with torch on GPU; could you take a look? https://btx.cloud.google.com/invocations/0b42aedc-c253-401a-b110-1f441959ccc8/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Ftorch%2Fcontinuous/log

fchollet, Jun 30 '24 21:06

Thanks for the info! I made a PR with a fix: #19940

SeKim12, Jun 30 '24 21:06

@SeKim12 thanks for this nice contribution!

Did you test using the PyTorch backend? Were you able to compile the associative scan operator? I will be testing this myself in an isolated project, but I wanted to get your feedback in case you have already run some benchmarks. Thanks!

carlosluis, Jul 09 '24 10:07

Hi @carlosluis! I was able to compile it and use it with the PyTorch backend. Here is the script I used to quickly benchmark performance: https://gist.github.com/SeKim12/0b5a77fbb05c707e60dcee03cfd7c24b

SeKim12, Jul 09 '24 16:07
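Some general context for benchmarks like these (not derived from the linked gist): a parallel scan trades extra work for fewer dependent steps. The Hillis-Steele variant below, a hypothetical pure-Python sketch, finishes in ceil(log2 n) rounds, each made of independent combines that could run as a single vectorized step, but it performs O(n log n) combines overall versus n - 1 for a sequential loop. Whether that pays off therefore depends on the hardware and compiler (for example, GPU versus CPU).

```python
from operator import add


def hillis_steele_scan(fn, elems):
    """Inclusive scan in ceil(log2(n)) rounds; returns (prefixes, rounds).

    Each round reads only the previous round's values, so all of its combines
    are independent and could run as one vectorized step on parallel hardware.
    """
    values = list(elems)
    rounds = 0
    offset = 1
    while offset < len(values):
        # Combine with the partial result `offset` positions to the left; the
        # left operand covers earlier elements, so element order is preserved.
        values = [
            fn(values[i - offset], values[i]) if i >= offset else values[i]
            for i in range(len(values))
        ]
        offset *= 2
        rounds += 1
    return values, rounds


prefixes, rounds = hillis_steele_scan(add, range(1, 9))
print(prefixes)  # [1, 3, 6, 10, 15, 21, 28, 36]
print(rounds)    # 3 rounds, versus 7 dependent steps for a sequential loop
```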

Hi @SeKim12!

This PR imports optree directly (example), which breaks Google Colab usage. We recommend using `from keras.src import tree` instead! Could you refactor the parts of your code that use optree to use Keras' `tree` module?

cc @hertschuh @fchollet

SamanehSaadat, Aug 08 '24 18:08
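A tree utility presumably comes up here because `elems` may be a nested structure rather than a single array (as it can be in `jax.lax.associative_scan`), so the implementation has to split the structure into leaves, scan, and reassemble it. The stdlib-only sketch below illustrates that pattern on a tuple of two sequences `(a, b)` and uses it to solve the linear recurrence h_t = a_t * h_{t-1} + b_t (with h_{-1} = 0), a common application of associative scans. `compose_affine` and `scan_structure` are hypothetical helpers, the scan is a plain sequential loop for clarity, and in Keras code, per the recommendation above, the regrouping would go through `keras.src.tree` rather than hand-written `zip` calls.

```python
from typing import Sequence, Tuple

# (a, b) encodes the affine map h -> a * h + b.
Affine = Tuple[float, float]


def compose_affine(first: Affine, second: Affine) -> Affine:
    """Apply `first`, then `second`: h -> a2 * (a1 * h + b1) + b2.

    Function composition is associative, so this is a valid scan operator.
    """
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def scan_structure(fn, elems: Tuple[Sequence[float], ...]) -> Tuple[list, ...]:
    """Hypothetical helper: inclusive scan over a tuple of equal-length sequences.

    Regroups the leaves into one tuple per step, scans sequentially, then
    regroups the results back into a tuple of lists with the same layout.
    """
    if not elems or len(elems[0]) == 0:
        return tuple([] for _ in elems)
    steps = list(zip(*elems))  # e.g. [(a_0, b_0), (a_1, b_1), ...]
    out = [steps[0]]
    for x in steps[1:]:
        out.append(fn(out[-1], x))
    return tuple(list(leaf) for leaf in zip(*out))


# Linear recurrence h_t = a_t * h_{t-1} + b_t with h_{-1} = 0:
# applying each prefix composition to h = 0 leaves just its b component.
a = [0.5, 2.0, 1.0, 0.25]
b = [1.0, 1.0, 3.0, 4.0]
_, h = scan_structure(compose_affine, (a, b))
print(h)  # [1.0, 3.0, 6.0, 5.5]
```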

Hi @SamanehSaadat,

My apologies -- I just saw this notification! It seems like the optree imports were handled by @fchollet. Thank you! Please let me know if there is anything else I can do!

SeKim12, Aug 13 '24 05:08