jaxoplanet
feat: s2fft wigner matrices
I copied some unreleased code from s2fft (only temporary; I opened an issue there) to test the recurrence-based construction of the Wigner D-matrices. I made a modification so that `beta` in their `utils.rotation.generate_rotate_dls` can be non-static.
Compilation is orders of magnitude faster (a sketch of how one might time this explicitly follows the list):

- 8 seconds for (new):

```python
from jaxoplanet.experimental.starry.s2fft_rotation import compute_rotation_matrices as R1

deg = 20
R = R1(deg, 0.0, 1.0, 0.0, 1.0)
```

- 3 minutes for (current):

```python
from jaxoplanet.experimental.starry.rotation import compute_rotation_matrices as R2

R = R2(deg, 0.0, 1.0, 0.0, 1.0)
```
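For reference, a minimal sketch of how compilation could be timed explicitly with JAX's ahead-of-time API (this assumes `deg` has to be a static argument, and reuses `R1` and `deg` from above):

```python
# Sketch: time lowering + compilation explicitly via JAX's AOT API.
# Assumes deg is static (it controls array shapes); R1 and deg as above.
import time

import jax

start = time.perf_counter()
compiled = jax.jit(R1, static_argnums=0).lower(deg, 0.0, 1.0, 0.0, 1.0).compile()
print(f"compile time: {time.perf_counter() - start:.1f} s")
```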
Of course this is preliminary and should be tested more extensively. Just a note: for spherical harmonics with degree `L`, the output of `s2fft.utils.rotation.generate_rotate_dls` has shape `(L, 2*L+1, 2*L+1)`, whereas we currently have a list of matrices with different shapes `[(1,), (3, 3), (5, 5), ..., (2*L+1, 2*L+1)]`.
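To make the two layouts concrete, here is a toy sketch (illustration only, not jaxoplanet code) of padding such a ragged list of per-degree matrices into one homogeneous stack like the s2fft output:

```python
import jax.numpy as jnp

# Toy example: pad each (2*l + 1, 2*l + 1) block into a common
# (2*L + 1, 2*L + 1) frame so all degrees stack into a single array.
L = 2
blocks = [jnp.eye(2 * l + 1) for l in range(L + 1)]  # shapes (1, 1), (3, 3), (5, 5)


def pad_block(d):
    pad = (2 * L + 1 - d.shape[0]) // 2
    return jnp.pad(d, ((pad, pad), (pad, pad)))


stacked = jnp.stack([pad_block(d) for d in blocks])  # (L + 1, 2*L + 1, 2*L + 1)
```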
To exploit that, I tried padding and unpadding the `Ylm` to perform the rotations as actual matrix multiplications. It doesn't seem to make things faster, which isn't that surprising. Anyway, I'll push that bit in the `Ylm` class in case it's useful.
Looking forward to doing more testing!
This is great @lgrcia!! I'll take a closer look later.
> I made a modification so that `beta` in their `utils.rotation.generate_rotate_dls` can be non-static.
Is this something where you could contribute upstream or is it easier to keep it local?
I opened this issue on the s2fft repo to discuss that there! Just to be sure about the motivation behind the static `beta`.
Just to keep a reference somewhere, here is the version with the padded `Ylm`:
```python
import jax
import jax.numpy as jnp
import numpy as np

from jaxoplanet.experimental.starry import Ylm, rotation

deg = 10
y = Ylm.from_dense(np.hstack([1, np.random.rand((deg + 1) ** 2 - 1)]))


def dot_rotation_matrix2(ydeg, x, y, z, theta):
    # note: the argument y (rotation axis component) shadows the outer Ylm object
    # homogeneous (ydeg + 1, 2*ydeg + 1, 2*ydeg + 1) stack of rotation matrices
    rotation_matrices = rotation.compute_rotation_matrices_s2fft(ydeg, x, y, z, theta)

    def dot(y_padded):
        # one batched matrix-vector product per degree
        padded_y = jnp.einsum("ij,ijk->ik", y_padded, rotation_matrices)
        return padded_y

    return dot


inc, obl = 0.5, 0.8
values = rotation.right_project_axis_angle(inc, obl, 0.0, 0.0)
f = jax.jit(dot_rotation_matrix2(deg, *values))
rotated_y = Ylm.from_dense_pad(f(y.to_dense_pad())).todense()
```
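Here, `to_dense_pad` / `from_dense_pad` are the padding helpers mentioned above. They aren't shown in this snippet, but a hypothetical NumPy sketch of the kind of layout involved (assumed, not the actual implementation) would be:

```python
import numpy as np


def to_dense_pad_sketch(dense, deg):
    # Hypothetical sketch (not the actual Ylm method): place each degree's
    # 2*l + 1 coefficients in a row of length 2*deg + 1, centered at m = 0.
    padded = np.zeros((deg + 1, 2 * deg + 1))
    k = 0
    for l in range(deg + 1):
        padded[l, deg - l : deg + l + 1] = dense[k : k + 2 * l + 1]
        k += 2 * l + 1
    return padded
```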
This is not faster than working with a non-homogeneous set of rotation matrices.
Based on the answer from @CosmoMatt in s2fft's issue #191, I tried the following implementation, based on equation 8 of this paper:
```python
import jax.numpy as jnp

from s2fft.utils.rotation import generate_rotate_dls


def new_dls(deg):
    # Wigner d-matrices evaluated once at beta = pi / 2
    delta = generate_rotate_dls(deg, jnp.pi / 2)
    idxs = jnp.indices(delta[0].shape)
    i = idxs[1][0] - deg + 1  # index running from -(deg - 1) to deg - 1
    inm = 1j ** (idxs[1] - idxs[0])

    def impl(beta):
        sum_term = jnp.einsum("nij,nik,i->njk", delta, delta, jnp.exp(1j * i * beta))
        m = sum_term * inm
        return m.real

    return impl
```
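For context, the relation this builds on (up to sign and index conventions, which I have not re-checked against the paper) expresses the Wigner d-matrix at any $\beta$ through its value at $\pi/2$:

$$d^{\ell}_{mn}(\beta) = i^{\,n-m} \sum_{k=-\ell}^{\ell} \Delta^{\ell}_{km}\, \Delta^{\ell}_{kn}\, e^{\,i k \beta}, \qquad \Delta^{\ell} \equiv d^{\ell}(\pi/2),$$

which is what the `einsum` over the first matrix axis of `delta` computes, followed by the `inm` phase factor and taking the real part.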
which seems correct and passes all sorts of tests with different degrees and $\beta$. However, the performance I get is worse than what I would have expected. The original `generate_rotate_dls` gives:
```python
import jax
from functools import partial

deg = 20
beta = 1.254
betas = jnp.linspace(0, 2 * jnp.pi, 4000)

f = jax.jit(jax.vmap(partial(generate_rotate_dls, deg)))
%timeit jax.block_until_ready(f(betas))
```

```
257 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
whereas the version based on a single call of `generate_rotate_dls` gives:
```python
f = jax.jit(jax.vmap(new_dls(deg)))
%timeit jax.block_until_ready(f(betas))
```

```
366 ms ± 6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
@dfm, @CosmoMatt, do you have any intuition for why the `vmap` over $\beta$ in `generate_rotate_dls` still provides very good performance compared to the implementation I tested (which involves lots of potentially costly matrix multiplications at the end)? I would have expected the recursive construction of the `dls` to be expensive. I might be missing some `jax` behavior here.
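One thing I'll try in order to dig into this (a sketch; `cost_analysis` is part of JAX's ahead-of-time API in recent versions and its output format varies by backend) is comparing XLA's flop estimates for the two compiled functions:

```python
# Sketch: compare XLA cost estimates for the two versions (recent JAX).
from functools import partial

for fn in (jax.vmap(partial(generate_rotate_dls, deg)), jax.vmap(new_dls(deg))):
    compiled = jax.jit(fn).lower(betas).compile()
    print(compiled.cost_analysis())  # flops and memory estimates
```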
Hi @lgrcia, sorry I couldn't look at this sooner. Unfortunately, I can't see any obvious reason why the three-term `einsum` in `impl` should be slower than the full recursion. I wonder if this may be a scaling issue that I overlooked in my previous response here.

- The brute-force approach requires $O(L^3)$ operations per $\beta$, for a full complexity of $O(\beta \cdot L^3)$ to compute all elements.
- The FFT approach, for each $\beta$, requires computing all $(\ell, m, n)$ entries, each of which involves a summation of length $L$ with complexity $O(L)$. So the complexity for the FFT approach would presumably be $O(\beta \cdot L^4)$ (a quick numerical check of what this predicts follows below).
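As a back-of-the-envelope check of what these scalings would predict for the benchmark above (constants are unknown, so only the ratio is meaningful):

```python
# Rough operation counts implied by the scalings above, for the earlier
# benchmark (deg = 20, 4000 betas); constants unknown, only the ratio matters.
L = 20
n_beta = 4000
brute_force = n_beta * L**3  # O(beta * L^3)
fft_style = n_beta * L**4    # O(beta * L^4)
print(fft_style / brute_force)  # 20.0 -> a factor-of-L gap in raw counts
```

That factor-of-$L$ gap is much larger than the roughly 1.4x measured above, so if this scaling argument is right, constant factors must be doing a lot of work here.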
I may be missing something though; perhaps @jasonmcewen has a better answer here.
Thanks a lot for your answer @CosmoMatt! That is a really interesting avenue to consider anyway!
Superseded by #212