
Add keras.ops.searchsorted

Open LarsKue opened this issue 1 year ago • 6 comments

This is commonly used for spline transformations.
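As a rough illustration of that use case (not taken from the PR), a spline transform usually needs to know which knot interval each input falls into. The sketch below uses NumPy's `np.searchsorted`; the knot values and the helper name `find_bin` are made up for the example.

```python
import numpy as np


def find_bin(knots, x):
    """Index of the knot interval containing each x (hypothetical helper)."""
    # side="right" counts the knots that are <= x; subtracting 1 gives the
    # index of the interval's left edge.
    idx = np.searchsorted(knots, x, side="right") - 1
    # Clip so that x == knots[-1] lands in the last interval, not past it.
    return np.clip(idx, 0, len(knots) - 2)


knots = np.array([0.0, 1.0, 2.0, 4.0])
x = np.array([0.5, 1.0, 3.9, 4.0])
bins = find_bin(knots, x)  # one interval index per input value
```

The per-interval spline parameters can then be gathered with `bins` as the index.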

LarsKue avatar Jun 26 '24 13:06 LarsKue

Codecov Report

Attention: Patch coverage is 72.50000% with 11 lines in your changes missing coverage. Please review.

Project coverage is 79.07%. Comparing base (558d38c) to head (607eb68). Report is 26 commits behind head on master.

Files                                        Patch %   Lines
keras/src/backend/jax/numpy.py               60.00%    1 Missing and 1 partial :warning:
keras/src/backend/numpy/numpy.py             60.00%    1 Missing and 1 partial :warning:
keras/src/backend/tensorflow/numpy.py        60.00%    1 Missing and 1 partial :warning:
keras/src/backend/torch/numpy.py             66.66%    1 Missing and 1 partial :warning:
keras/src/ops/numpy.py                       88.23%    1 Missing and 1 partial :warning:
keras/api/_tf_keras/keras/ops/__init__.py    0.00%     1 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19922      +/-   ##
==========================================
+ Coverage   79.02%   79.07%   +0.05%     
==========================================
  Files         499      499              
  Lines       46436    46771     +335     
  Branches     8548     8626      +78     
==========================================
+ Hits        36695    36984     +289     
- Misses       8015     8049      +34     
- Partials     1726     1738      +12     
Flag               Coverage Δ
keras              78.93% <72.50%> (+0.05%) :arrow_up:
keras-jax          62.26% <55.00%> (-0.16%) :arrow_down:
keras-numpy        57.29% <60.00%> (+0.07%) :arrow_up:
keras-tensorflow   63.56% <57.50%> (-0.09%) :arrow_down:
keras-torch        62.32% <57.50%> (-0.06%) :arrow_down:

Flags with carried forward coverage won't be shown.


codecov-commenter avatar Jun 26 '24 13:06 codecov-commenter

Converted to draft because I will add a test

LarsKue avatar Jun 26 '24 13:06 LarsKue

I opted to try to maximize support for N-D searchsorted (because this is my use case). However, NumPy does not support it. JAX supports it by vmapping, which I implemented.
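For readers unfamiliar with what "N-D searchsorted" means here, the following is a minimal sketch, assuming the batched semantics are "one independent 1-D search per row". It uses a plain NumPy loop as a stand-in for vmapping; the function name `searchsorted_rows` is hypothetical and this is not the code from the PR.

```python
import numpy as np


def searchsorted_rows(sorted_sequences, values, side="left"):
    """Row-wise searchsorted for 2-D inputs (hypothetical helper)."""
    sorted_sequences = np.asarray(sorted_sequences)
    values = np.asarray(values)
    if sorted_sequences.ndim != 2 or values.ndim != 2:
        raise ValueError("Both inputs must be 2-D in this sketch.")
    if sorted_sequences.shape[0] != values.shape[0]:
        raise ValueError("Inputs must have the same number of rows.")
    # np.searchsorted only accepts a 1-D sorted array, so search each row
    # separately and stack the results back into a 2-D array.
    return np.stack(
        [
            np.searchsorted(row, row_values, side=side)
            for row, row_values in zip(sorted_sequences, values)
        ]
    )


a = np.array([[1, 3, 5], [2, 4, 6]])
v = np.array([[3, 6], [1, 4]])
indices = searchsorted_rows(a, v)
```

In JAX, the analogous mapping over the leading axis could be written with `jax.vmap(jnp.searchsorted)`, which avoids the Python loop.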

If you have better suggestions on how we can support N-D searchsorted, I would be happy to implement them.

The tests still need to be updated, because self.assertAllEqual does not support multi-dimensional tensors.

Nevertheless, I am marking this as ready for review now, so that you can give feedback. Thank you.

LarsKue avatar Jun 26 '24 14:06 LarsKue

Thanks for the PR!

> I opted to try to maximize support for N-D searchsorted (because this is my use case). However, NumPy does not support it. JAX supports it by vmapping, which I implemented.

Since we support vmapping APIs, we could simply not implement N-D support for this op. We should try to stay as close to NumPy as possible in order to minimize user surprise.

A TF test seems to be failing:

>           assert knp.all(knp.searchsorted(a, v) == expected)

keras/src/ops/numpy_test.py:3990: 
...
E     tensorflow.python.framework.errors_impl.InvalidArgumentError:
cannot compute Equal as input #1(zero-based) was expected to be a int32 tensor but is a int64 tensor
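The error above is a dtype mismatch: TensorFlow's `Equal` op does not promote an int32 operand to int64, whereas NumPy would. A minimal sketch of the usual workaround, shown with NumPy arrays and made-up test values (the actual test data in `numpy_test.py` is not visible in this thread), is to cast both sides to one integer dtype before comparing:

```python
import numpy as np

a = np.array([1, 3, 5, 7])
v = np.array([0, 4, 8])

# Assumed expected values, deliberately stored as int32 for the illustration.
expected = np.array([0, 2, 4], dtype="int32")

# np.searchsorted returns integer indices in the platform's index dtype.
indices = np.searchsorted(a, v)

# Cast both sides to a common dtype so the comparison does not depend on
# whether a backend promotes mixed integer dtypes.
matches = np.array_equal(indices.astype("int64"), expected.astype("int64"))
```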

fchollet avatar Jun 26 '24 14:06 fchollet

@fchollet Thank you for the review!

> we could simply not implement N-D support for this op.

You make a good point. In that case, should we raise an error if the user passes an N-D sorted_sequence, or let the backend handle it if it is incompatible? Raising would make the function truly backend-agnostic, but it would prevent users from relying on built-in N-D functionality in backends that do support it.

If we raise an error: Should we do this in keras.ops or in the respective keras.backend functions?

> A TF test seems to be failing

We can drop the part that is failing if we only support 1-D.

LarsKue avatar Jun 27 '24 09:06 LarsKue

> In that case, should we raise an error if the user passes an N-D sorted_sequence, or let the backend handle it if it is incompatible?

Better to do it in each backend function I think!
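A minimal sketch of what such a per-backend check could look like, written against plain NumPy. The function body and the error message are illustrative assumptions, not the code that was merged:

```python
import numpy as np


def searchsorted(sorted_sequence, values, side="left"):
    """Hypothetical NumPy-backend wrapper that rejects N-D sorted sequences."""
    sorted_sequence = np.asarray(sorted_sequence)
    if sorted_sequence.ndim != 1:
        # Fail early with a clear message instead of relying on whatever
        # error (or silent N-D behavior) the backend would produce.
        raise ValueError(
            "`searchsorted` only supports 1-D sorted sequences. "
            f"Received: sorted_sequence.shape={sorted_sequence.shape}"
        )
    return np.searchsorted(sorted_sequence, values, side=side)


result = searchsorted([1, 3, 5], [0, 3, 6])
```

Repeating the same check in every backend keeps the behavior identical across backends, at the cost of some duplication.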

fchollet avatar Jun 27 '24 17:06 fchollet

@fchollet I implemented your requested changes, raising the error in each backend function. However, I currently do not know of a way to raise the error appropriately in the SearchSorted.symbolic_call method. This is for two reasons:

  1. ndim returns different output for symbolic vs. non-symbolic tensors
  2. symbolic_call is called with both symbolic and non-symbolic input (see e.g. the NumpyDtypeTest tests)

I could add a case distinction in symbolic_call using any_symbolic_tensors, but no other op does this. What is the recommended way to deal with this kind of issue?

LarsKue avatar Jul 03 '24 10:07 LarsKue

Rather than using the function ndim() you could use len(shape(x)), which works the same everywhere?
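A minimal sketch of that idea, assuming every input (eager or symbolic) exposes a `shape` attribute that is a sequence with one entry per axis. `FakeSymbolicTensor` is a hypothetical stand-in for a symbolic tensor and `rank` is an illustrative helper, neither is part of Keras:

```python
import numpy as np


class FakeSymbolicTensor:
    """Hypothetical stand-in for a symbolic tensor: a shape but no data."""

    def __init__(self, shape):
        # Unknown dimensions are represented as None, as in symbolic shapes.
        self.shape = tuple(shape)


def rank(x):
    # The length of the shape is the number of axes, even when some
    # individual dimensions are unknown (None).
    return len(x.shape)


def check_1d(x, name="sorted_sequence"):
    if rank(x) != 1:
        raise ValueError(f"`{name}` must be 1-D. Received shape {x.shape}")


eager_rank = rank(np.zeros((5,)))
symbolic_rank = rank(FakeSymbolicTensor((None, 4)))
```

Because only the length of the shape is inspected, the same check works whether or not the tensor holds concrete values.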

fchollet avatar Jul 04 '24 03:07 fchollet