
feat: rotation matrices using s2fft

Open lgrcia opened this issue 1 year ago • 6 comments

Use the s2fft Python package to compute the Wigner D-matrices used to rotate the spherical harmonics. See also #140.

lgrcia avatar Sep 09 '24 19:09 lgrcia

To solve a numpy import issue in s2fft (see s2fft#206) I had to impose numpy<2.0. Not sure we want to stick to that.

lgrcia avatar Sep 12 '24 21:09 lgrcia

I can't figure out why the macos-python3.11 test hangs until it times out... Any help would be welcome!

lgrcia avatar Sep 17 '24 20:09 lgrcia

> I can't figure why the macos-python3.11 test is hanging on until being timed out

Strange! You could try removing the -n auto from the pytest command? Does it work ok locally on your mac?

dfm avatar Sep 21 '24 20:09 dfm

I think maybe GitHub Actions was actually just having issues that day. I've tried re-running the job. Let's see if that works!

dfm avatar Sep 21 '24 20:09 dfm

Re-running the job failed. I can reproduce the problem locally, but I can't pin down which test is causing it. Individual tests pass locally but freeze when run together. Could it be a memory issue? I'm clueless...

The major change is enforcing numpy<2.0.

lgrcia avatar Sep 26 '24 18:09 lgrcia

@lgrcia — I tried turning off parallel execution of the tests and that changed the behavior to just hang indefinitely when running the test_light_curves_orders test (I think). This isn't failing on the main branch so it must have something to do with s2fft. I can't imagine the issue has to do with numpy (although pinning it <2.0 seems like a Bad Idea™ - hopefully they can fix the issues soon!). It might be worth trying some different versions of JAX (e.g. <=0.4.31) and s2fft to see if you can narrow down to a combination that works.

dfm avatar Sep 29 '24 14:09 dfm

The issue here must be related to s2fft! @lgrcia, can you try to work out which inputs we're passing in the test that hangs to isolate exactly what s2fft call we're executing. It would be interesting to know if we can get the same hang using just s2fft and no jaxoplanet. In that case we can report upstream and see what they say.
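One way to narrow down which call hangs, without freezing the whole session, is to run each candidate snippet in a fresh interpreter with a timeout. The following is only a minimal sketch using the standard library; `run_with_timeout` is a hypothetical helper, not part of jaxoplanet or s2fft, and the timeout value is an arbitrary assumption.

```python
import subprocess
import sys


def run_with_timeout(snippet, timeout_s=60.0):
    """Run `snippet` in a fresh Python process (hypothetical helper).

    Returns "ok" if the process exits cleanly, "error" if it exits with a
    non-zero status, and "timeout" if it is still running after `timeout_s`.
    """
    try:
        result = subprocess.run(
            [sys.executable, "-c", snippet],
            capture_output=True,
            timeout=timeout_s,  # the child process is killed on expiry
        )
    except subprocess.TimeoutExpired:
        return "timeout"
    return "ok" if result.returncode == 0 else "error"
```

Passing the body of a suspect test as `snippet`, once per parameter combination, would report which inputs end in "timeout" rather than hanging the test session.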

dfm avatar Oct 05 '24 14:10 dfm

I think I identified where the problem comes from.

Here is a way to reproduce it on macOS:
import numpy as np
from jaxoplanet.experimental.starry.rotation import dot_rotation_matrix

l_max = 5
theta = 0
ident = np.eye(l_max**2 + 2 * l_max + 1)
expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident)
calc = dot_rotation_matrix(l_max, None, None, 1.0, theta)(ident)

This runs ok. But then, when l_max is changed in the same runtime:

l_max = 6
ident = np.eye(l_max**2 + 2 * l_max + 1)
expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident)

it freezes.

I think it has to do with how I (or s2fft) combine the static arguments in the jitted functions from s2fft.utils.rotation and s2fft.utils (see also the jitted signature of compute_rotation_matrices in jaxoplanet.experimental.starry.s2fft_rotation.py). But I honestly don't understand why it is only a problem on macOS.

I don't really understand why it behaves like this, but a workaround for me is to avoid decorating the s2fft rotation functions with jit. So I copied all the required functions (we only need 100 lines of Python from s2fft) and removed the s2fft dependency, for now.

I'm down to understand the problem better before reintroducing s2fft as a dependency.

lgrcia avatar Oct 08 '24 19:10 lgrcia

> The issue here must be related to s2fft! @lgrcia, can you try to work out which inputs we're passing in the test that hangs to isolate exactly what s2fft call we're executing. It would be interesting to know if we can get the same hang using just s2fft and no jaxoplanet. In that case we can report upstream and see what they say.

@dfm, here is a way to reproduce only with s2fft:

import jax
from functools import partial
from s2fft.utils.rotation import generate_rotate_dls


@partial(jax.jit, static_argnums=(0,))
def f(deg, alpha):
    return generate_rotate_dls(deg, alpha)


_ = f(5, 0.0)  # this executes fine
_ = f(10, 0.0)  # this freezes

I might open an issue but I think this is not a proper use of this function given the static args (see https://github.com/astro-informatics/s2fft/blob/main/s2fft/utils/rotation.py#L75). Any idea why this would happen?
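For context on what a static argument does, here is a small sketch that uses only JAX and does not involve s2fft. It does not reproduce the freeze; it only illustrates that every new value of a static argument triggers a fresh trace and compilation, while new values of a traced argument reuse the cached one. The function `scaled_sum` and the counter `trace_count` are made up for this illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Counts how many times JAX traces the function body (illustration only).
trace_count = {"n": 0}


@partial(jax.jit, static_argnums=(0,))
def scaled_sum(deg, alpha):
    # This Python side effect runs only while tracing, not on cached calls.
    trace_count["n"] += 1
    # `deg` is static, so it is a plain Python int and can define a shape.
    return jnp.arange(deg + 1).sum() * alpha


scaled_sum(5, 0.0)   # first call: traced and compiled for deg=5
scaled_sum(5, 1.0)   # same static value: the cached compilation is reused
scaled_sum(10, 0.0)  # new static value: traced and compiled again
```

After these three calls, `trace_count["n"]` is 2, one trace per distinct value of `deg`.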

As I understand it, each test should pass when run in a separate Python instance. So could the issue be due to how pytest runners are dispatched on macOS?

lgrcia avatar Oct 08 '24 20:10 lgrcia

It's fascinating to me that that happens and that your solution works! I don't see any reason why the beta parameter should be labelled as static, but it also seems like it should crash to nest these static_argnums incompatibly like this...

Regardless: I think this is a good "fix"!

dfm avatar Oct 10 '24 00:10 dfm

@dfm, do you think we are ready to merge?

lgrcia avatar Oct 15 '24 20:10 lgrcia

Just for reference, the s2fft addition was merged by mistake in #225. Thanks for all the reviews on this!!

lgrcia avatar Oct 16 '24 14:10 lgrcia