
feat: s2fft wigner matrices

Open lgrcia opened this issue 1 year ago • 6 comments

I copied some unreleased code from s2fft (only temporarily, I opened an issue there) to test building the Wigner D-matrices by recurrence. I modified their utils.rotation.generate_rotate_dls so that beta can be non-static.

Compilation is orders of magnitude faster:

  • 8 seconds for the new version:

    from jaxoplanet.experimental.starry.s2fft_rotation import compute_rotation_matrices as R1
    
    deg = 20
    R = R1(deg, 0.0, 1.0, 0.0, 1.0)
    
  • 3 minutes for the current version:

    from jaxoplanet.experimental.starry.rotation import compute_rotation_matrices as R2
    
    R = R2(deg, 0.0, 1.0, 0.0, 1.0)
    

Of course this is preliminary and should be tested more extensively. Just a note that for spherical harmonics with a degree L, the output of s2fft.utils.rotation.generate_rotate_dls is (L, 2*L+1, 2*L+1) (whereas we currently have a list of matrices with different shapes [(1,), (3,3), (5,5), ..., (2*L+1, 2*L+1)]).

To exploit that, I tried padding and unpadding the Ylm coefficients to perform actual matrix multiplications. It doesn't seem to make things faster, which isn't that surprising. Anyway, I'll push that bit to the Ylm class in case it's useful.
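For reference, the pad/unpad round trip might look something like this in plain NumPy. This is only a sketch of the idea (one row per degree, coefficients centered in rows of width 2*deg + 1); the actual `to_dense_pad`/`from_dense_pad` helpers may differ in detail:

```python
import numpy as np

deg = 2
# Dense Y_lm coefficients in l-major order: degree l occupies indices l**2 .. l**2 + 2*l
y = np.arange(1, (deg + 1) ** 2 + 1, dtype=float)

# Pad: one row per degree l, each block centered in a row of width 2*deg + 1
padded = np.zeros((deg + 1, 2 * deg + 1))
for l in range(deg + 1):
    start = l * l
    padded[l, deg - l : deg + l + 1] = y[start : start + 2 * l + 1]

# Unpad: slice each row back out and concatenate to recover the dense vector
dense = np.concatenate([padded[l, deg - l : deg + l + 1] for l in range(deg + 1)])
assert np.allclose(dense, y)
```

With this layout, a rotation becomes a single batched matrix-vector product over the stacked (deg + 1, 2*deg + 1, 2*deg + 1) matrices, at the cost of carrying the zero padding around.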

Looking forward to doing more testing!

lgrcia avatar Feb 27 '24 01:02 lgrcia

This is great @lgrcia!! I'll take a closer look later.

I did a modification so that beta in their utils.rotation.generate_rotate_dls can be non-static.

Is this something where you could contribute upstream or is it easier to keep it local?

dfm avatar Feb 27 '24 15:02 dfm

I opened this issue on the s2fft repo to discuss it there, mainly to understand the motives behind the static beta.

lgrcia avatar Feb 27 '24 18:02 lgrcia

Just to keep a reference somewhere, here is the version with the padded Ylm:

import jax
import jax.numpy as jnp
import numpy as np

from jaxoplanet.experimental.starry import Ylm, rotation

deg = 10
y = Ylm.from_dense(np.hstack([1, np.random.rand((deg + 1) ** 2 - 1)]))


def dot_rotation_matrix2(ydeg, x, y, z, theta):
    # Stacked (ydeg + 1, 2 * ydeg + 1, 2 * ydeg + 1) rotation matrices from s2fft
    rotation_matrices = rotation.compute_rotation_matrices_s2fft(ydeg, x, y, z, theta)

    def dot(y_padded):
        # One batched matrix-vector product per degree
        return jnp.einsum("ij,ijk->ik", y_padded, rotation_matrices)

    return dot


inc, obl = 0.5, 0.8
values = rotation.right_project_axis_angle(inc, obl, 0.0, 0.0)
f = jax.jit(dot_rotation_matrix2(deg, *values))
rotated_y = Ylm.from_dense_pad(f(y.to_dense_pad())).todense()

This is not faster than working with a non-homogeneous set of rotation matrices.

lgrcia avatar Feb 27 '24 19:02 lgrcia

Based on the answer from @CosmoMatt in s2fft's issue #191, I tried the following implementation based on equation 8 of this paper:

from s2fft.utils.rotation import generate_rotate_dls
import jax.numpy as jnp


def new_dls(deg):
    # Precompute Delta = d^l(pi / 2) once; beta then enters only through a phase
    delta = generate_rotate_dls(deg, jnp.pi / 2)
    idxs = jnp.indices(delta[0].shape)
    i = idxs[1][0] - deg + 1  # m' indices, running from -(deg - 1) to deg - 1
    inm = 1j ** (idxs[1] - idxs[0])  # i^(n - m) prefactor

    def impl(beta):
        # d^l_{mn}(beta) = Re[i^(n - m) sum_{m'} Delta_{m'm} Delta_{m'n} e^{i m' beta}]
        sum_term = jnp.einsum("nij,nik,i->njk", delta, delta, jnp.exp(1j * i * beta))
        return (sum_term * inm).real

    return impl

which seems correct and passes all sorts of tests with different degrees and $\beta$. However, the performance I get is less than what I would have expected. The original generate_rotate_dls gives

import jax
from functools import partial

deg = 20
beta = 1.254
betas = jnp.linspace(0, 2 * jnp.pi, 4000)

f = jax.jit(jax.vmap(partial(generate_rotate_dls, deg)))
%timeit jax.block_until_ready(f(betas))
257 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

whereas the version based on a single call of generate_rotate_dls gives

f = jax.jit(jax.vmap(new_dls(deg)))
%timeit jax.block_until_ready(f(betas))
366 ms ± 6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@dfm, @CosmoMatt, do you have any intuition why vmapping over $\beta$ in generate_rotate_dls still provides very good performance compared to the implementation I tested (which might involve lots of costly matrix multiplications at the end)? I would have expected the recursive construction of the dls to be expensive. I might be missing some jax behavior here.
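As a side note, the $l = 1$ case of the $\pi/2$ decomposition can be checked by hand against the analytic $d^1$ matrix. This is a standalone NumPy sketch using the index ordering $m = -l, \dots, l$ from the implementation above; sign conventions for $d^1$ vary between references, so treat the exact matrix below as one particular choice:

```python
import numpy as np


def d1(beta):
    # Analytic Wigner small-d matrix for l = 1, rows/cols ordered m = -1, 0, 1
    c, s = np.cos(beta), np.sin(beta)
    r2 = np.sqrt(2.0)
    return np.array([
        [(1 + c) / 2,  s / r2, (1 - c) / 2],
        [-s / r2,      c,       s / r2],
        [(1 - c) / 2, -s / r2, (1 + c) / 2],
    ])


delta = d1(np.pi / 2)  # Delta = d^1(pi / 2)
beta = 1.254
m = np.arange(-1, 2)   # m' = -1, 0, 1
j, k = np.indices((3, 3))

# d^1_{jk}(beta) = Re[i^(k - j) sum_{m'} Delta_{m'j} Delta_{m'k} e^{i m' beta}]
sum_term = np.einsum("ij,ik,i->jk", delta, delta, np.exp(1j * m * beta))
reconstructed = np.real(1j ** (k - j) * sum_term)
assert np.allclose(reconstructed, d1(beta))
```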

lgrcia avatar Mar 22 '24 21:03 lgrcia

Hi @lgrcia, sorry I couldn't look at this sooner. Unfortunately, I can't see any obvious reason why the three-term einsum in impl should be slower than the full recursion. I wonder if this may be a scaling issue that I overlooked in my previous response here.

  • The brute force approach requires $\mathcal{O}(L^3)$ work per $\beta$, for a total complexity of $\mathcal{O}(N_\beta L^3)$ to compute all elements.

  • The FFT approach requires, for each $\beta$, computing all $(\ell, m, n)$ entries, each involving a summation of length $L$ with complexity $\mathcal{O}(L)$. So the total complexity for the FFT approach would presumably be $\mathcal{O}(N_\beta L^4)$.
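Plugging in the sizes benchmarked above ($L = 20$, 4000 values of $\beta$), this back-of-envelope count suggests a factor-of-$L$ gap between the two approaches (constants ignored throughout, so this is only a rough scaling check):

```python
# Rough operation counts for the two approaches at the benchmarked sizes
# (assumption: constant factors ignored, L used loosely for the band limit)
L, n_beta = 20, 4000
recursive = n_beta * L**3       # recursion: O(L^3) entries per beta, O(1) work each
decomposition = n_beta * L**4   # pi/2 decomposition: O(L^3) entries, each an O(L) sum
print(recursive, decomposition, decomposition // recursive)
# → 32000000 640000000 20
```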

I may be missing something here though, perhaps @jasonmcewen has a better answer here.

CosmoMatt avatar Apr 01 '24 09:04 CosmoMatt

Thanks a lot for your answer @CosmoMatt! That is a really interesting avenue to consider anyway!

lgrcia avatar Apr 02 '24 15:04 lgrcia

Superseded by #212

lgrcia avatar Sep 09 '24 19:09 lgrcia