
add svd to jax numpy linalg

Open • PatricYan opened this pull request 2 years ago • 9 comments

Close #4702

PatricYan, Sep 20 '22 13:09

It has already been done. See #4466

karalleyna, Sep 20 '22 14:09

> It has already been done. See #4466

Hi @karalleyna, this subtask comes from https://github.com/unifyai/ivy/issues/4614, so it is not the same as what you have done.

PatricYan, Sep 20 '22 14:09

Ah, I'm really sorry. I didn't see your issue :(

karalleyna, Sep 20 '22 14:09

Ah, no need to be sorry, you're welcome! We all contribute to Ivy.

PatricYan, Sep 20 '22 14:09

Hi!

Thanks for the PR! I've had a look and what's there seems quite good, but it is missing a few of the arguments from the JAX function it replicates.

At https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.svd.html it says the function signature is jax.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False) which means that the function added to the Ivy JAX frontend must have the same signature, that is, the same name and parameters.
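You can check the "same name and parameters" requirement mechanically. Python's standard inspect module can compare the frontend function's parameters against the list in the JAX docs. This is a minimal sketch. The stub svd and the JAX_SVD_PARAMS table are illustrative assumptions, not Ivy code. The check compares only names and defaults, not parameter kinds (positional vs keyword-only).

```python
import inspect

# Names and defaults as listed in the JAX docs for
# jax.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False).
# `inspect.Parameter.empty` marks a parameter with no default.
JAX_SVD_PARAMS = [
    ("a", inspect.Parameter.empty),
    ("full_matrices", True),
    ("compute_uv", True),
    ("hermitian", False),
]


def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
    """Hypothetical frontend stub; only its signature matters here."""
    raise NotImplementedError


def signature_matches(func, expected):
    """Return True if func's parameter names and defaults equal `expected`, in order."""
    params = inspect.signature(func).parameters.values()
    return [(p.name, p.default) for p in params] == expected


print(signature_matches(svd, JAX_SVD_PARAMS))  # True
```

A two-parameter function such as `def svd(x, full_matrices=True)` would fail this check. That is exactly the gap described above.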

This may require some function composition to recreate the functionality for each value of compute_uv. You don't need to worry about using the input value of hermitian, because it doesn't affect input-output behaviour. You'll still need to accept a value so that outside code can specify it without crashing, but the function's behaviour doesn't need to depend on it. See pinv in ivy/functional/frontends/numpy/linalg/solving_equations_and_inverting_matrices.py for an example.
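Here is a minimal sketch of that composition. It assumes the Ivy function always returns U, S and Vh, as the backend discussed in this thread does. With compute_uv=False, JAX returns only the singular values, so the frontend can return a subset of the full decomposition. NumPy's np.linalg.svd stands in for ivy.svd so the snippet runs on its own. `_ivy_svd_stand_in` is a hypothetical name, and this is not the actual Ivy frontend code.

```python
import numpy as np


def _ivy_svd_stand_in(x, full_matrices=True):
    # Hypothetical stand-in for ivy.svd: like the two-parameter Ivy API,
    # it always computes the full decomposition (U, S, Vh).
    u, s, vh = np.linalg.svd(x, full_matrices=full_matrices)
    return u, s, vh


def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
    # Same name and parameters as jax.numpy.linalg.svd.
    # `hermitian` is accepted so callers can pass it without crashing,
    # but its value is intentionally ignored.
    del hermitian
    u, s, vh = _ivy_svd_stand_in(a, full_matrices=full_matrices)
    if compute_uv:
        return u, s, vh
    # compute_uv=False: return only the singular values.
    return s


x = np.array([[3.0, 0.0], [0.0, 4.0]])
print(svd(x, compute_uv=False))  # [4. 3.]
u, s, vh = svd(x, hermitian=True)
print(np.allclose(u @ np.diag(s) @ vh, x))  # True
```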

Any questions or comments feel free to let me know!

JamieLine, Sep 22 '22 01:09

Hi, if you take a look at ivy/functional/backends/jax/linear_algebra.py, you will see how the backend svd function differs from the official docs:

  • Screenshot: svd in ivy/functional/backends/jax/linear_algebra.py

  • Screenshot: official doc

PatricYan, Sep 22 '22 02:09

Hi, I see what you're saying about the backends. This difference occurs because the backends are there to implement Ivy API functions, e.g. ivy.svd(x, /, *, full_matrices=True). A backend only has to handle 2 parameters because the Ivy API only specifies 2 parameters for this function.

However, the frontends are used for translating programs written for other frameworks, in this case JAX, and those programs may use any of the parameters specified in the original framework's docs. To make this clearer, here is an example: the frontend implementation of argmax in the JAX frontend, found at ivy/functional/frontends/jax/lax/operators.py.

Screenshot: argmax in ivy/functional/frontends/jax/lax/operators.py

We see that the underlying Ivy function only takes 2 parameters, operand and axis, but the corresponding JAX function takes one additional argument, index_dtype. The frontend function argmax must therefore also support the index_dtype parameter so that JAX code using it can be translated to Ivy code more easily.
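Here is the same pattern in miniature. The sketch assumes a hypothetical two-parameter `_ivy_argmax_stand_in` (NumPy's np.argmax here) in place of the Ivy function. The frontend mirrors jax.lax.argmax(operand, axis, index_dtype) and honours the extra parameter by casting the result. The real Ivy frontend may implement this differently.

```python
import numpy as np


def _ivy_argmax_stand_in(operand, axis):
    # Hypothetical stand-in for the two-parameter Ivy function.
    return np.argmax(operand, axis=axis)


def argmax(operand, axis, index_dtype):
    # Frontend mirrors jax.lax.argmax(operand, axis, index_dtype):
    # the extra JAX parameter is applied on top of the backend result.
    return _ivy_argmax_stand_in(operand, axis).astype(index_dtype)


x = np.array([[1, 5, 2], [7, 0, 3]])
idx = argmax(x, 1, np.int32)
print(idx, idx.dtype)  # [1 0] int32
```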

To return to your implementation of svd for the JAX frontend: it must also take the parameters compute_uv=True and hermitian=False, since JAX programs can legally pass them. It must also be tested with both of these parameters.
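Here is a sketch of what testing both parameters could look like. It assumes a symmetric input, so that hermitian=True is a valid hint. NumPy's np.linalg.svd shares JAX's signature (a, full_matrices=True, compute_uv=True, hermitian=False), so it serves here as both the function under test and the reference. In Ivy's own test suite the ground truth would presumably come from JAX. `check_svd_frontend` is a hypothetical helper, not one of Ivy's test helpers.

```python
import itertools

import numpy as np


def check_svd_frontend(svd_fn, a):
    """Exercise every combination of compute_uv and hermitian (a must be symmetric)."""
    expected_s = np.linalg.svd(a, compute_uv=False)  # reference singular values
    for compute_uv, hermitian in itertools.product((True, False), repeat=2):
        out = svd_fn(a, full_matrices=False, compute_uv=compute_uv, hermitian=hermitian)
        if compute_uv:
            u, s, vh = out
            # U and Vh are only unique up to sign, so check the reconstruction
            # instead of comparing them element-wise.
            assert np.allclose(u @ np.diag(s) @ vh, a)
        else:
            s = out
        assert np.allclose(s, expected_s)
    return True


rng = np.random.default_rng(0)
m = rng.standard_normal((4, 4))
sym = m + m.T  # symmetric, so hermitian=True is a valid hint
print(check_svd_frontend(np.linalg.svd, sym))  # True
```

Comparing singular values and the reconstruction keeps the test robust to sign flips in U and Vh, which differ legitimately between implementations.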

Other than this detail, the PR looks very good, and most importantly the test looks good. Once the 2 new parameters are added, I'll run the test, and if it passes, it should be ready to be merged into Ivy.

Once again, any questions or comments feel free to let me know!

JamieLine, Sep 22 '22 13:09

Hi @JamieLine, done.

PatricYan, Sep 23 '22 08:09

Hi @PatricYan, sorry this has taken a while.

I've had a look through, and here is what I spotted:

  • Unfortunately, I think the changes to the JAX backend will need to be undone. I believe all of the backends need to have the same function signature, so that the code is stable regardless of which backend is selected. I do understand that these changes provide a partial solution to the problem, though.
  • The frontend function should take the output from ivy.svd and adjust it where needed to match the output of the real JAX svd function. That shouldn't be too difficult: the Ivy function appears to compute everything needed every time, so it should just be a matter of returning a subset of its output depending on the inputs.
  • In the frontend, you don't need to pass compute_uv or hermitian into ivy.svd as ivy.svd doesn't accept these parameters, and therefore passing them causes a test failure.
  • I think you can simplify _svd_get_dtype_and_data by removing the ret_shape=True flag, as you don't use the shape it returns later on.
  • Try running the test locally to ensure the function doesn't crash, and then I can queue it to be tested on our CI to ensure it works on all backends.
  • You might be able to use something along the lines of shape=helpers.ints(min_value=2, max_value=20).map(lambda x: tuple([x, x])) when generating test data with helpers.dtype_and_values, to make it a bit sleeker.
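Here is an illustration of the third point. Consider a function declared like the Ivy signature quoted earlier in this thread, (x, /, *, full_matrices=True). It rejects any keyword it does not declare, so forwarding compute_uv or hermitian fails before any computation happens. `ivy_svd_like` is a hypothetical stand-in with that signature, backed by NumPy.

```python
import numpy as np


def ivy_svd_like(x, /, *, full_matrices=True):
    # Hypothetical stand-in with the same signature as ivy.svd in this thread.
    return np.linalg.svd(x, full_matrices=full_matrices)


x = np.eye(3)

try:
    # Forwarding a JAX-only keyword breaks the call.
    ivy_svd_like(x, full_matrices=True, compute_uv=False)
except TypeError as err:
    print("TypeError:", err)

# Forward only what the Ivy function accepts; handle compute_uv in the frontend.
u, s, vh = ivy_svd_like(x, full_matrices=True)
print(s)  # [1. 1. 1.]
```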

The rest is good though, and again I'm very sorry I've taken so long to look through this.

JamieLine, Sep 26 '22 22:09

This PR has been labelled as stale because it has been inactive for more than 7 days. If you would like to continue working on this PR, then please add another comment or this PR will be closed in 7 days.

ivy-seed, Nov 08 '22 06:11

This PR has been closed because it has been marked as stale for more than 7 days with no activity.

ivy-seed, Nov 16 '22 06:11

This PR has been labelled as stale because it has been inactive for more than 7 days. If you would like to continue working on this PR, then please add another comment or this PR will be closed in 7 days.

ivy-seed, Nov 16 '22 06:11