ivy
add svd to jax numpy linalg
Close #4702
It has already been done. See #4466
Hi @karalleyna, the subtask comes from https://github.com/unifyai/ivy/issues/4614, which is not the same as what you have done.
Ah, I'm really sorry. I didn't see your issue :(
Ah, sorry! You're welcome, we all contribute to Ivy.
Hi!
Thanks for the PR! I've had a look and what's there seems quite good, but it is missing a few of the arguments from the JAX function it replicates.
At https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.svd.html it says the function signature is jax.numpy.linalg.svd(a, full_matrices=True, compute_uv=True, hermitian=False), which means that the function added to the Ivy JAX frontend must have the same signature, that is, the same name and parameters.
This may require some function composition in order to recreate the behaviour controlled by compute_uv (when it is False, JAX returns only the singular values), but you don't need to worry about using the input value of hermitian, because it doesn't affect input-output behaviour. You'll still need to take in a value, so outside code can specify it without crashing, but the function doesn't need to be affected by its value. See pinv in ivy/functional/frontends/numpy/linalg/solving_equations_and_inverting_matrices.py for an example.
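A minimal sketch of such a frontend function, assuming NumPy's numpy.linalg.svd stands in for ivy.svd (the helper _ivy_svd_stand_in is hypothetical and not part of Ivy): the frontend keeps the full JAX signature, forwards only full_matrices to the backend call, accepts hermitian without using it, and returns just the singular values when compute_uv=False.

```python
import numpy as np


def _ivy_svd_stand_in(x, *, full_matrices=True):
    # Hypothetical stand-in for ivy.svd: it only accepts full_matrices
    # and always returns the full decomposition (U, S, Vh).
    u, s, vh = np.linalg.svd(x, full_matrices=full_matrices)
    return u, s, vh


def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
    # Same parameters as jax.numpy.linalg.svd. `hermitian` is accepted so
    # JAX-style callers don't crash, but it is deliberately unused, and
    # neither it nor compute_uv is forwarded to the backend call.
    u, s, vh = _ivy_svd_stand_in(a, full_matrices=full_matrices)
    if compute_uv:
        return u, s, vh
    # With compute_uv=False, JAX returns only the singular values,
    # so the frontend just selects that part of the full result.
    return s
```

Because the backend call always computes everything, the frontend logic reduces to choosing which part of the result to return.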
If you have any questions or comments, feel free to let me know!
Hi, if you take a look at ivy/functional/backends/jax/linear_algebra.py, you will see how the backend function svd differs from the official doc.
Hi, I see what you're saying about the backends. This difference occurs because the backends are there to implement Ivy API functions, e.g. ivy.svd(x, /, *, full_matrices=True). The backend only has to worry about 2 parameters because the Ivy API only specifies 2 parameters for this function.
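To illustrate what it means for every backend to implement the same Ivy API signature, here is a small sketch, assuming NumPy in place of real backends. The two stubs are hypothetical (not Ivy's actual backend code); both mirror ivy.svd(x, /, *, full_matrices=True), and inspect.signature is used to confirm that they match exactly, while a stub with an extra compute_uv parameter would not.

```python
import inspect

import numpy as np


def svd_backend_a(x, /, *, full_matrices=True):
    # Hypothetical backend stub: mirrors the Ivy API, so only `x` and
    # `full_matrices` are accepted.
    return np.linalg.svd(x, full_matrices=full_matrices)


def svd_backend_b(x, /, *, full_matrices=True):
    # A second hypothetical backend: the body may differ, the signature may not.
    u, s, vh = np.linalg.svd(x, full_matrices=full_matrices)
    return u, s, vh


def signatures_match(*fns):
    # True when all functions share parameter names, kinds and defaults.
    sigs = [inspect.signature(f) for f in fns]
    return all(sig == sigs[0] for sig in sigs[1:])
```

Extra JAX-only parameters therefore belong in the frontend, not in any single backend.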
However, the frontends are used for translating programs written for other frameworks, in this case JAX, and those programs may make use of any parameter specified in the original framework's docs. To make what I'm about to say clearer, I've added an example: the frontend implementation of argmax in the JAX frontend, found at ivy/functional/frontends/jax/lax/operators.py.
We see that the underlying Ivy function only takes in 2 parameters, operand and axis, but the corresponding JAX function takes in one additional argument, index_dtype. Therefore, the frontend function argmax must also support the index_dtype parameter, so that this translation to Ivy code can be done more easily.
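The argmax pattern can be sketched as follows, assuming NumPy stands in for Ivy. _ivy_argmax_stand_in is a hypothetical two-parameter stand-in for the Ivy function, and casting with astype is one plausible way to honour index_dtype, not necessarily how Ivy's frontend does it.

```python
import numpy as np


def _ivy_argmax_stand_in(x, axis):
    # Hypothetical stand-in for the Ivy function: only `x` and `axis`.
    return np.argmax(x, axis=axis)


def argmax(operand, axis, index_dtype):
    # jax.lax.argmax also takes `index_dtype`; the frontend supports it
    # by casting the two-parameter result to the requested integer dtype.
    return _ivy_argmax_stand_in(operand, axis).astype(index_dtype)
```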
To return to your implementation of svd for the JAX frontend: it must also take in the parameters compute_uv=True and hermitian=False, as those are legal parameters to be passed in by JAX programs. Additionally, it must be tested across the values of both parameters.
Other than this detail, the PR looks very good, and most importantly the test looks good. Once the 2 new parameters are added, I'll run the test, and if it passes it should be ready to be merged into Ivy.
Once again, if you have any questions or comments, feel free to let me know!
Hi @JamieLine, done.
Hi @PatricYan, sorry this has taken a while.
I've gone through and had a look, and here is what I spotted:
- I think the changes to the JAX backend will unfortunately need to be undone, as I believe all of the backends need to have the same function signature (to ensure the code is stable regardless of which backend has been selected), but I do understand that this provides a partial solution to the problem.
- The frontend function should take the output from ivy.svd and change it, if needed, to match the output of the real JAX svd function. I think that shouldn't be too difficult, as the Ivy function appears to compute everything needed all the time, so it looks like it will just be a matter of returning a subset of the Ivy function's output, depending on the inputs.
- In the frontend, you don't need to pass compute_uv or hermitian into ivy.svd, as ivy.svd doesn't accept these parameters, and therefore passing them causes a test failure.
- I think you can simplify _svd_get_dtype_and_data by removing the ret_shape=True flag, as you don't use the shape it returns later on.
- Try running the test locally to ensure the function doesn't crash, and then I can queue it to be tested on our CI to ensure it works on all backends.
- You might be able to use something along the lines of
  shape=helpers.ints(min_value=2, max_value=20).map(lambda x: tuple([x, x])),
  when generating test data in helpers.dtype_and_values to make it a bit sleeker.
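As a rough illustration of the square-shape suggestion and of testing both compute_uv values, here is a sketch that uses NumPy's random generator instead of Ivy's hypothesis-based test helpers (draw_square_shape and singular_values_are_a_subset are hypothetical names). It draws one integer n and maps it to (n, n), then checks that the compute_uv=False result matches the S component of the full decomposition, which is why returning a subset of the output is enough. numpy.linalg.svd serves as the reference here, standing in for JAX.

```python
import numpy as np


def draw_square_shape(rng, min_value=2, max_value=20):
    # Same idea as ints(...).map(lambda x: tuple([x, x])): draw one
    # integer n in [min_value, max_value] and turn it into shape (n, n).
    n = int(rng.integers(min_value, max_value + 1))
    return (n, n)


def singular_values_are_a_subset(rng, trials=10):
    # For random square matrices, compare S from the full decomposition
    # (compute_uv=True) with the values-only result (compute_uv=False).
    for _ in range(trials):
        shape = draw_square_shape(rng)
        x = rng.standard_normal(shape)
        _, s_full, _ = np.linalg.svd(x, compute_uv=True)
        s_only = np.linalg.svd(x, compute_uv=False)
        if not np.allclose(s_full, s_only):
            return False
    return True
```

In the real test suite the shape would come from a hypothesis strategy, so failing cases get shrunk automatically; the plain random loop here only keeps the sketch dependency-free.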
The rest is good though, and again I'm very sorry I've taken so long to look through this.
This PR has been labelled as stale because it has been inactive for more than 7 days. If you would like to continue working on this PR, then please add another comment or this PR will be closed in 7 days.
This PR has been closed because it has been marked as stale for more than 7 days with no activity.