
added solve_sylvester and accompanying tests

Open Dekermanjian opened this issue 6 months ago • 29 comments

Hello, JAX team!

I am new to working with JAX and would appreciate feedback (and guidance) on my implementation of a solver for the Sylvester equation with accompanying tests.

This is related to #6089 and #669, and would be a prerequisite for #19109.
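For reference, the closed-form eigendecomposition approach I started from can be sketched roughly like this (a simplified NumPy illustration with my own names, not the exact code in this PR):

```python
import numpy as np

def solve_sylvester_eig(a, b, c):
    """Solve A @ X + X @ B = C via eigendecomposition (closed form).

    Illustrative sketch only; assumes A and B are diagonalizable and
    that no eigenvalue of A plus an eigenvalue of B is (near) zero.
    """
    la, va = np.linalg.eig(a)              # A = Va diag(la) Va^{-1}
    lb, vb = np.linalg.eig(b)              # B = Vb diag(lb) Vb^{-1}
    c_tilde = np.linalg.solve(va, c) @ vb  # Va^{-1} C Vb
    # Elementwise divide by the pairwise sums of eigenvalues,
    # then transform back to the original basis.
    y = c_tilde / (la[:, None] + lb[None, :])
    return np.real_if_close(va @ y @ np.linalg.inv(vb))

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
b = rng.standard_normal((3, 3))
c = rng.standard_normal((4, 3))
x = solve_sylvester_eig(a, b, c)
assert np.allclose(a @ x + x @ b, c)
```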

Thanks in advance!

Dekermanjian avatar May 17 '25 15:05 Dekermanjian

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot] avatar May 17 '25 15:05 google-cla[bot]

Thank you, @jakevdp for your constructive feedback! I will make changes and modify the implementation in accordance with your suggestions.

Dekermanjian avatar May 20 '25 10:05 Dekermanjian

Hi @jakevdp, I made the changes you suggested above. A quick question: in the tests, after inheriting from jtu.JaxTestCase, numpy_dtype_promotion seems to be set to strict, which was causing my tests to fail. Inside my tests I set numpy_dtype_promotion=standard, which is most likely incorrect. Could you point me to some documentation about this so that I can fix it properly?
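In case it helps anyone hitting the same thing, here is a small sketch of what I believe strict vs. standard promotion does, using JAX's `jax.numpy_dtype_promotion` context manager (my understanding of the behavior, not taken from the PR):

```python
import jax
import jax.numpy as jnp

x32 = jnp.ones(2, dtype='float32')
i32 = jnp.ones(2, dtype='int32')

# Under strict promotion (what jtu.JaxTestCase enables), mixing
# strongly-typed dtypes raises instead of silently promoting.
with jax.numpy_dtype_promotion('strict'):
    try:
        _ = x32 + i32
        strict_raised = False
    except Exception:
        strict_raised = True

# Under standard promotion, the usual JAX promotion lattice applies.
with jax.numpy_dtype_promotion('standard'):
    out = x32 + i32

assert strict_raised
assert out.dtype == jnp.float32
```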

Dekermanjian avatar May 21 '25 23:05 Dekermanjian

Hey @jakevdp, thank you for your patience with this PR.

I made the changes you suggested above. There is one part I wasn't sure about: solve_sylvester() loses some precision, especially for large matrices, when using float32, so in the tests I conditionally loosen the tolerances when the dtype is float32. Please let me know if this is not the right way to handle the issue and I will change it.

Dekermanjian avatar May 25 '25 11:05 Dekermanjian

Hey @jakevdp, thank you for your patience with this PR. I made most of the changes you suggested above, and left replies on the suggestions where I took a different approach. Let me know if my approach is incorrect and I will go back and change it.

Dekermanjian avatar May 29 '25 10:05 Dekermanjian

Hey @jakevdp, I made the above changes and modified solve_sylvester() so that the output dtype matches the input dtype, allowing for complex or float outputs.

Dekermanjian avatar May 29 '25 22:05 Dekermanjian

On h100 the large test on float32 is failing with an even larger discrepancy:

Mismatched elements: 8895 / 30000 (29.6%)
Max absolute difference among violations: 1.5408611
Max relative difference among violations: 18452.156
 ACTUAL: array([[ 1.800829,  2.928643, -0.344496, ...,  4.774832, -2.07168 ,
        -5.008069],
       [ 1.854761, -3.498744, -1.515197, ...,  0.164751,  2.064151,...
 DESIRED: array([[ 1.785693,  2.994908, -0.187475, ...,  4.250907, -2.186043,
        -4.966815],
       [ 1.842348, -2.990138, -0.931781, ...,  0.104629,  2.06976 ,...

I'm not sure the best approach here – is there some fundamental issue with the algorithm used here?

jakevdp avatar May 31 '25 19:05 jakevdp

> On h100 the large test on float32 is failing with an even larger discrepancy … is there some fundamental issue with the algorithm used here?

Hey @jakevdp, there is a difference between the algorithm here and the one that SciPy uses. SciPy first performs a Schur decomposition of the matrices, solves the resulting simpler (quasi-triangular) system, and then transforms the solution back to the original basis. I had a working implementation of this, but it was a little slow because of some nested for loops.

Would you give me a couple of days to try and take that solution and implement it using JAX scans? I think that it will be worth it for the algorithm to be more robust.
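For reference, the structure I'm aiming for looks roughly like this in NumPy/SciPy terms (a simplified sketch using the complex Schur form, with my own helper name; the per-column loop is the part I'd express as a jax.lax.scan):

```python
import numpy as np
from scipy import linalg

def solve_sylvester_bs(a, b, c):
    """Bartels-Stewart sketch: solve A @ X + X @ B = C.

    Illustrative only, not this PR's code. Reduces A and B to upper
    triangular (complex) Schur form, then back-substitutes column by
    column.
    """
    r, u = linalg.schur(a, output='complex')  # A = U R U^H
    s, v = linalg.schur(b, output='complex')  # B = V S V^H
    f = u.conj().T @ c @ v                    # transformed right-hand side
    y = np.zeros_like(f, dtype=complex)
    eye = np.eye(r.shape[0])
    for k in range(s.shape[0]):               # sweep over columns of Y
        # (Y S)[:, k] = S[k, k] y_k + sum_{j<k} S[j, k] y_j, so each
        # column solve only depends on previously computed columns.
        rhs = f[:, k] - y[:, :k] @ s[:k, k]
        y[:, k] = linalg.solve_triangular(r + s[k, k] * eye, rhs)
    return np.real_if_close(u @ y @ v.conj().T)

rng = np.random.default_rng(1)
a = rng.standard_normal((5, 5))
b = rng.standard_normal((4, 4))
c = rng.standard_normal((5, 4))
x = solve_sylvester_bs(a, b, c)
assert np.allclose(a @ x + x @ b, c)
```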

Dekermanjian avatar May 31 '25 22:05 Dekermanjian

Hey @jakevdp, I was able to implement the same algorithm that SciPy uses to solve the Sylvester equation (the Bartels-Stewart algorithm).

I believe it now produces results much closer to SciPy's (when you force SciPy to use float32 instead of casting to float64). However, in my testing I found that when the eigenvalues of A and B have pairwise sums close to zero, the results become poor, especially for float32. With float64, depending on how close to zero the sums are, you may still get fairly good approximations (the same applies to SciPy). Because of this, I added a tol argument to solve_sylvester that lets the user specify how close to zero the eigenvalue sums may be before the function returns a matrix of NaNs.

Since this affects float32 more drastically than float64, I wonder if it warrants adding the warning back to the docstring?
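As a concrete toy example of the failure mode (a NumPy illustration; the NaN guard shown only mimics what the tol argument is meant to do, it isn't the PR's code):

```python
import numpy as np

# A and B chosen so one eigenvalue of A plus one of B is exactly zero:
# the Sylvester operator is singular and no unique solution exists.
a = np.diag([1.0, 2.0])
b = np.diag([-1.0, 5.0])
la = np.linalg.eigvals(a)
lb = np.linalg.eigvals(b)
sums = la[:, None] + lb[None, :]   # pairwise eigenvalue sums

tol = 1e-8
if np.min(np.abs(sums)) < tol:
    # Mimic of the proposed tol behaviour: bail out with NaNs rather
    # than dividing by a near-zero eigenvalue sum.
    x = np.full((2, 2), np.nan)
```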

Dekermanjian avatar Jun 02 '25 11:06 Dekermanjian

Hmm, a scan with a length equal to the size of the matrix is going to be very slow, especially on accelerators where each scan iteration requires a kernel launch. The previous closed-form solution seemed to work in most cases – can we go back to that with some sort of qualification about its accuracy?

jakevdp avatar Jun 02 '25 16:06 jakevdp

Yes, I can certainly go back to the closed-form solution and include the same tol argument. Just note that the Schur decomposition method handles cases where the eigenvalue sums are nearly zero better than the eigenvalue decomposition method does. Would you be opposed to having both methods available and leaving it up to the user to select which one to use?

Dekermanjian avatar Jun 02 '25 16:06 Dekermanjian

Would you be opposed to having both methods available and leaving it up to the user to select which one they would rather use?

That would be a great solution I think. We'd have to make sure to provide enough documentation to let the user make an informed choice of the setting.

jakevdp avatar Jun 02 '25 17:06 jakevdp

Okay, I will go ahead and bring back the eigenvalue decomposition method and I will beef up the docstrings to reflect what each method is doing behind the scenes.

Dekermanjian avatar Jun 02 '25 17:06 Dekermanjian

Hey @jakevdp, I am experiencing something weird when testing. I am getting unexpected results only during tests.

Below you can see the results of pytest -n auto tests/linalg_test.py -k test_no_solution_sylvester_eigen:

[Screenshot: pytest output showing unexpected results for test_no_solution_sylvester_eigen]

So I take the generated A, B, and C matrices and run them through the exact same function outside of the testing environment, and it produces what I expect: a matrix of NaNs.

[Screenshot: running the same function outside the tests returns a matrix of NaNs]

Am I missing something obvious here?

Dekermanjian avatar Jun 03 '25 22:06 Dekermanjian

Are you reconstructing the matrices from their printed representation? That may not result in bitwise-identical floating point values.
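For example, a value printed with a typical truncated format does not round-trip, while the full repr does (a small NumPy illustration):

```python
import numpy as np

x = np.float32(np.pi)
printed = f"{x:.6g}"               # a typical truncated print: '3.14159'
assert np.float32(printed) != x    # the text round-trip changed the bits
assert np.float32(repr(x)) == x    # the full repr round-trips exactly
```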

jakevdp avatar Jun 03 '25 22:06 jakevdp

That is a good point let me try to serialize them to disk and load them that way.

Dekermanjian avatar Jun 03 '25 23:06 Dekermanjian

Okay, @jakevdp, I saved the matrices to disk using np.save() and loaded them outside the testing environment with np.load(), and, as expected, the solver returns a matrix of NaNs only outside the testing environment. During the test I am still getting odd results. Any idea what might be happening?

Dekermanjian avatar Jun 03 '25 23:06 Dekermanjian

Is it possible that it's a float64 precision issue? Do you have the same X64 flag set in both contexts?
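One quick way to check (a sketch using the standard jax_enable_x64 flag; without it, 64-bit inputs are silently downcast):

```python
import jax
import jax.numpy as jnp
import numpy as np

data = np.arange(3, dtype=np.float64)

# Without the flag, float64 inputs are silently downcast to float32.
jax.config.update("jax_enable_x64", False)
assert jnp.asarray(data).dtype == jnp.float32

# With it, full double precision is preserved.
jax.config.update("jax_enable_x64", True)
assert jnp.asarray(data).dtype == jnp.float64
```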

jakevdp avatar Jun 03 '25 23:06 jakevdp

Yeah, I believe so: I haven't enabled float64 in either setting. The failing tests are for complex64, and I checked the dtype of the matrices after np.load() and can confirm they were loaded with the complex64 dtype.

Dekermanjian avatar Jun 03 '25 23:06 Dekermanjian

Hey, I might have found the issue. In the other environment I am running JAX version 0.6.1, while in the testing environment I am running JAX 0.6.1.dev20250602+f067b6586. I can confirm that on the dev version I am not getting the expected results, whereas on 0.6.1 I am. Any idea why that could be?

Dekermanjian avatar Jun 03 '25 23:06 Dekermanjian

Okay, @jakevdp, after digging into the two versions of JAX, I found that jnp.linalg.eig() in the dev version returns different results from JAX 0.6.1, and that the results from JAX 0.6.1 match np.linalg.eig() more closely than those from the dev version. Moreover, since the matrices A, B, and C are constructed to have eigenvalue sums of zero, the results from JAX 0.6.1 and NumPy are correct, as they recover this property, while the dev version of JAX does not.

I tried looking at recent commits that may have changed jnp.linalg.eig(), but I could not pinpoint where the issue arises. Any ideas?
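One sanity check that sidesteps LAPACK differences: eigenvectors are only defined up to scale and ordering, so two implementations can disagree elementwise while both being valid, and the eigen relation itself can be verified instead (a NumPy/SciPy sketch):

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(2)
a = rng.standard_normal((4, 4)).astype(np.float32)

# Different LAPACK drivers may order/scale eigenvectors differently,
# yet each result can still satisfy A @ V = V @ diag(w).
for eig in (np.linalg.eig, linalg.eig):
    w, v = eig(a)
    assert np.allclose(a @ v, v * w, atol=1e-4)
```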

Dekermanjian avatar Jun 04 '25 11:06 Dekermanjian

I don't think eig has changed between 0.6.1 and the current version. I suspect that somehow the jaxlib you're using in each environment is targeting a different LAPACK under the hood, and this is leading to different numerics between the environments.

jakevdp avatar Jun 04 '25 16:06 jakevdp

@jakevdp I checked the jaxlib version: in my dev environment I was running jaxlib 0.6.0, so I updated it to 0.6.1 to match the environment that returns the expected jnp.linalg.eig() results, but it is still returning different values. Is there a way for me to verify which LAPACK is being called by jnp.linalg.eig()?

Dekermanjian avatar Jun 04 '25 21:06 Dekermanjian

I also printed out jax.config.values and checked all the values across both environments and they were all the same.

Dekermanjian avatar Jun 04 '25 21:06 Dekermanjian

A quick update on the jnp.linalg.eig() issue: something in my dev environment makes jnp.linalg.eig() produce results exactly matching scipy.linalg.eig(). After researching this further, it is just as you hypothesized: the difference between NumPy's and SciPy's results comes down to the LAPACK each one calls.

Dekermanjian avatar Jun 04 '25 23:06 Dekermanjian

Hi @jakevdp, is there anything you'd like me to add/remove/modify with the latest commit?

Dekermanjian avatar Jun 11 '25 23:06 Dekermanjian

Hi - it wasn't clear to me whether you'd solved the eig issue. Is that resolved now?

jakevdp avatar Jun 11 '25 23:06 jakevdp

Sorry about that! I should have communicated better. In the last commit I brought back the eigendecomposition method as another option, with a tolerance argument letting the user decide how to handle ill-conditioned problems. I added descriptions of both methods (eigendecomposition and Bartels-Stewart) to the docstrings, along with the strengths and weaknesses of each.

As for the issue where different environments return different eig results: one environment returns the same results as numpy.linalg and the other the same as scipy.linalg. The differences arise from calling different LAPACKs, and they are specific to low-precision dtypes combined with ill-conditioned matrices. All implementations (NumPy, SciPy, and our JAX implementation) struggle with low precision plus ill-conditioned matrices.

I think it might be better to repurpose the second test, test_no_solution, as test_tol, focusing on testing the tolerance argument against ill-conditioned matrices.

Does this sound like a viable path forward?

Dekermanjian avatar Jun 12 '25 11:06 Dekermanjian

Thanks again for your patience @jakevdp. I made the suggested changes and re-pushed.

Dekermanjian avatar Jun 18 '25 11:06 Dekermanjian

Hi @jakevdp, I resolved the merge conflicts by rebasing onto main. Anything else I can do here to improve this addition?

Dekermanjian avatar Jun 26 '25 12:06 Dekermanjian