jax
Wrong output of `jax.scipy.special.sph_harm`
Description
`jax.scipy.special.sph_harm(m, n, theta, phi, n_max=None)` returns wrong values when the degree of the harmonic is $n \neq 0$.
Here is an example:
```python
import jax
import jax.numpy as jnp
from jax.scipy.special import sph_harm as jnp_sph
from scipy.special import sph_harm

# Generate 200 random 3D points
seed = 23
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key, 2)
data = jax.random.normal(subkey, shape=(200, 3))
r = jnp.linalg.norm(data, ord=2, axis=1)
phi = jnp.arccos(data[:, 2] / r)             # polar (colatitude) angle
theta = jnp.arctan2(data[:, 1], data[:, 0])  # azimuthal angle

# Compute sph_harm with JAX and SciPy
m = 0
n = 1
scipy_result = sph_harm(jnp.array([m]), jnp.array([n]), theta, phi)
jax_result = jnp_sph(jnp.array([m]), jnp.array([n]), theta, phi, n_max=n)
print(jnp.max(jnp.abs(scipy_result - jax_result)))
```
The printed value should be close to zero, but it is 0.8381599. When $m = 0, n = 0$, the result is 2.9802322e-08, which is correct.
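For reference, SciPy's `sph_harm` documentation defines the harmonic with $\theta$ as the azimuthal angle and $\phi$ as the polar angle. For $m = 0, n = 1$ the expected value therefore depends on $\phi$ at every point, while $Y_0^0$ is the same constant everywhere:

$$Y_n^m(\theta, \phi) = \sqrt{\frac{2n+1}{4\pi}\,\frac{(n-m)!}{(n+m)!}}\; e^{i m \theta}\, P_n^m(\cos\phi)$$

$$Y_0^0 = \frac{1}{2\sqrt{\pi}}, \qquad Y_1^0(\theta, \phi) = \frac{1}{2}\sqrt{\frac{3}{\pi}}\,\cos\phi$$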
I checked the source code of `jax.scipy.special.sph_harm`, and I think the bug is here:
```python
@partial(jit, static_argnums=(4,))
def _sph_harm(m: Array,
              n: Array,
              theta: Array,
              phi: Array,
              n_max: int) -> Array:
  """Computes the spherical harmonics."""
  cos_colatitude = jnp.cos(phi)
  legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
  legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")
  angle = abs(m) * theta
  vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
  harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
                          legendre_val * jnp.imag(vandermonde))
  # Negative order.
  harmonics = jnp.where(m < 0,
                        (-1.0)**abs(m) * jnp.conjugate(harmonics),
                        harmonics)
  return harmonics
```
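The `legendre_val` statement indexes the last axis of `legendre`, which runs over the input points, with `jnp.arange(len(n))` instead of one index per point. It should be changed to:

```python
legendre_val = legendre.at[abs(m), n, jnp.arange(len(phi))].get(mode="clip")
```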
The result is correct for degree $n = 0$ because at that degree the associated Legendre function is a constant, so every entry of `legendre.at[abs(m), n, jnp.arange(len(phi))]` has the same value as `legendre.at[abs(m), n, jnp.arange(len(n))]`. When $n \neq 0$, however, `jnp.arange(len(n))` selects only the first `len(n)` points (here just one), and that single Legendre value is then broadcast against all 200 `theta` values, so the output is wrong.
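The indexing difference can be reproduced on a small hand-built table. This is a minimal sketch, not JAX's actual code path: `legendre` below is a hypothetical stand-in for the output of `_gen_associated_legendre`, assumed to have shape `(n_max + 1, n_max + 1, num_points)` and to be indexed as `[|m|, n, point]`, with only the unnormalized $P_0^0(x) = 1$ and $P_1^0(x) = x$ filled in. The helper `gather` is also hypothetical.

```python
import jax.numpy as jnp

# Hypothetical stand-in for the table built by _gen_associated_legendre:
# shape (n_max + 1, n_max + 1, num_points), indexed as [|m|, n, point].
# Only the unnormalized P_0^0(x) = 1 and P_1^0(x) = x are filled in.
n_max, num_points = 1, 5
x = jnp.linspace(-1.0, 1.0, num_points)  # plays the role of cos(phi)
legendre = jnp.zeros((n_max + 1, n_max + 1, num_points))
legendre = legendre.at[0, 0, :].set(1.0)  # degree 0: constant in x
legendre = legendre.at[0, 1, :].set(x)    # degree 1: varies with x


def gather(m, n, num_idx):
    """Gather legendre[|m|, n, point] for points 0..num_idx-1 (clipped)."""
    return legendre.at[abs(m), n, jnp.arange(num_idx)].get(mode="clip")


m = jnp.array([0])
n1 = jnp.array([1])

# Original indexing: the point index runs over len(n) == 1,
# so only the value at point 0 is read.
buggy = gather(m, n1, len(n1))
# Proposed fix: one index per point; m and n broadcast against it.
fixed = gather(m, n1, num_points)
print(buggy.shape, fixed.shape)  # (1,) (5,)

# For degree 0 the single value read by the original indexing equals the
# value at every point, which is why n = 0 looked correct.
n0 = jnp.array([0])
print(jnp.all(gather(m, n0, len(n0)) == gather(m, n0, num_points)))

# Possible caller-side workaround (an assumption, only checked on this toy
# table): pass m and n with one entry per point so len(n) == len(phi).
m_b = jnp.full(num_points, 0)
n_b = jnp.full(num_points, 1)
print(jnp.allclose(gather(m_b, n_b, len(n_b)), fixed))
```

`fixed` has one Legendre value per point and can be combined elementwise with `theta`, whereas `buggy` holds a single value that ends up broadcast to every point.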
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.7 jaxlib: 0.4.7 numpy: 1.22.4 python: 3.8.16 (default, Mar 1 2023, 21:19:10) [Clang 14.0.6 ] jax.devices (1 total, 1 local): [CpuDevice(id=0)] process_count: 1
Hi @SGENZO
Thanks for reporting the bug. I have opened PR #20772 for this. The issue will be closed once the PR is merged.
Thanks! The `sph_harm` code is unfortunately very poorly implemented, and we're considering removing it entirely (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-special). We can probably fix this bug, but long-term I'd suggest finding a different implementation to rely on.
@jakevdp Got it. I'll consider other implementation, thanks.