jax
Fix _sph_harm to agree with scipy results.
Before this fix, the code below, which computes spherical harmonics, produced an array filled mostly with zeros (see sph.pdf) that did not agree with the SciPy results. After the change, the two agree (see sph_fixed.pdf).
To reproduce:
import jax.numpy as jnp
import jax.scipy.special as jp
import matplotlib.pyplot as plt
import numpy as np
import scipy.special as sp

# Degree l = 2 and every order m in [-l, l].
l_max = jnp.array([2])
m_values = jnp.array([-2, -1, 0, 1, 2])
# SciPy convention: theta is the azimuthal angle in [0, 2*pi],
# phi is the polar angle in [0, pi].
theta = jnp.linspace(0, 2 * np.pi, 1000)
phi = jnp.linspace(0, np.pi, 1000)

# Calculate the spherical harmonics with both libraries.
Y_lm = []
Y_lm_scipy = []
for m in m_values:
    Y_lm.append(jp.sph_harm(m, l_max, theta, phi))
    Y_lm_scipy.append(sp.sph_harm(m, l_max, theta, phi))
Ylm = jnp.array(Y_lm)
Ysci = jnp.array(Y_lm_scipy)

# Overlay the real parts: dashed = JAX, solid = SciPy, one color per m.
fig, ax = plt.subplots()
for idx in range(Ylm.shape[0]):
    m = int(m_values[idx])
    ax.plot(Ylm.real[idx], "--", label=f"JAX, m={m}", c=f"C{idx}", alpha=0.5)
    ax.plot(Ysci.real[idx], "-", label=f"SciPy, m={m}", c=f"C{idx}", alpha=0.5)
ax.legend()
fig.savefig("sph_fixed.pdf")
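The plot compares the two libraries only visually, and only their real parts. A library-independent reference for the degree used here (l = 2) makes it possible to tell which output is correct. The sketch below is not part of this PR: `sph_harm_l2` is a hypothetical helper that writes out the textbook closed forms of Y_2^m in SciPy's argument order (theta azimuthal, phi polar), including the Condon-Shortley phase that SciPy uses.

```python
import numpy as np


def sph_harm_l2(m, theta, phi):
    """Closed-form Y_2^m (hypothetical reference helper, not a library API).

    theta: azimuthal angle in [0, 2*pi]; phi: polar angle in [0, pi].
    Includes the Condon-Shortley phase (-1)^m, matching SciPy's convention.
    """
    m = int(m)  # accept Python ints or 0-d NumPy/JAX arrays
    theta = np.asarray(theta, dtype=float)
    phi = np.asarray(phi, dtype=float)
    c, s = np.cos(phi), np.sin(phi)
    if m == 0:
        return (0.25 * np.sqrt(5 / np.pi) * (3 * c**2 - 1)).astype(complex)
    if abs(m) == 1:
        # Y_2^{+1} carries a minus sign from the Condon-Shortley phase.
        sign = -1.0 if m == 1 else 1.0
        return sign * 0.5 * np.sqrt(15 / (2 * np.pi)) * s * c * np.exp(1j * m * theta)
    if abs(m) == 2:
        return 0.25 * np.sqrt(15 / (2 * np.pi)) * s**2 * np.exp(1j * m * theta)
    raise ValueError(f"|m| must be <= 2, got m={m}")
```

With the script's `theta`, `phi`, and `m_values`, each row could then be checked numerically, e.g. `np.testing.assert_allclose(Ysci[idx], sph_harm_l2(m_values[idx], theta, phi), atol=1e-5)`, and the same call with `Ylm[idx]` checks the JAX result. The tolerance is an assumption chosen to accommodate JAX's default float32 precision.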
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
Thanks for the fix! Could you also add a regression test for this new behavior?
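One way to write such a regression test without depending on a second library is to check orthonormality: for a correct implementation, the inner products of Y_l^m and Y_l^m' over the sphere form the identity matrix, while an output filled mostly with zeros gives a near-zero matrix. The sketch below is a hypothetical illustration, not the test added in this PR; `gram_matrix` and `assert_orthonormal` are made-up names, and the implementation under test is passed in as a wrapper `sph_harm_fn(m, l, theta, phi)` using SciPy's convention (theta azimuthal, phi polar).

```python
import numpy as np


def gram_matrix(sph_harm_fn, l, n_theta=200, n_phi=100):
    """Approximate <Y_l^m, Y_l^m'> over the unit sphere for m, m' in [-l, l].

    Rows and columns are ordered m = -l, ..., l. A correct, orthonormal
    implementation yields approximately the identity matrix.
    """
    # Midpoint rule on a regular grid; the surface element is sin(phi) dphi dtheta.
    d_theta = 2 * np.pi / n_theta
    d_phi = np.pi / n_phi
    theta_1d = (np.arange(n_theta) + 0.5) * d_theta
    phi_1d = (np.arange(n_phi) + 0.5) * d_phi
    theta, phi = np.meshgrid(theta_1d, phi_1d, indexing="ij")
    weights = np.sin(phi) * d_theta * d_phi

    # Evaluate each order on flattened 1-D inputs, then restore the grid shape.
    values = [
        np.asarray(sph_harm_fn(m, l, theta.ravel(), phi.ravel())).reshape(theta.shape)
        for m in range(-l, l + 1)
    ]
    size = len(values)
    gram = np.empty((size, size), dtype=complex)
    for i in range(size):
        for j in range(size):
            gram[i, j] = np.sum(np.conj(values[i]) * values[j] * weights)
    return gram


def assert_orthonormal(sph_harm_fn, l, atol=1e-3):
    """Raise AssertionError if sph_harm_fn is not orthonormal at degree l."""
    gram = gram_matrix(sph_harm_fn, l)
    np.testing.assert_allclose(gram, np.eye(2 * l + 1), atol=atol)
```

To exercise JAX, and assuming the installed version exposes `jax.scipy.special.sph_harm(m, n, theta, phi, n_max=None)`, the wrapper could broadcast `m` and `l` to the length of the angle arrays, e.g. `assert_orthonormal(lambda m, l, t, p: jax.scipy.special.sph_harm(jnp.full(t.shape, m), jnp.full(t.shape, l), jnp.asarray(t), jnp.asarray(p), n_max=l), 2)`. The midpoint rule is exact in the azimuthal direction for these integrands, so the tolerance mainly absorbs the polar quadrature error and float32 rounding.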
@tlu7, could you take a look?
@jakevdp Just fixed the test!
So I'm not totally familiar with this code: could you explain the fix and why the updated test addresses the problem?
Is there anything blocking this PR from merging? It seems the spherical harmonics implementation is still broken in the latest version of JAX.