Wrong output of `jax.scipy.special.sph_harm`

Open SGENZO opened this issue 10 months ago • 3 comments

Description

jax.scipy.special.sph_harm(m, n, theta, phi, n_max=None) returns wrong values when the degree of the harmonic is nonzero ($n \neq 0$). Here is an example:

import jax 
import jax.numpy as jnp
from jax.scipy.special import sph_harm as jnp_sph
from scipy.special import sph_harm

# Generate 200 3D points
seed = 23
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key, 2)
data = jax.random.normal(subkey, shape=(200,3))
r = jnp.linalg.norm(data, ord=2, axis=1)
phi = jnp.array(jnp.arccos(data[:,2]/r))
theta = jnp.array(jnp.arctan2(data[:,1],data[:,0]))

# Compute sph_harm with JAX and with SciPy, then compare
m = 0
n = 1
scipy_result = sph_harm(jnp.array([m]), jnp.array([n]), theta, phi)
jax_result = jnp_sph(jnp.array([m]), jnp.array([n]), theta, phi, n_max=n)
print(jnp.max(jnp.abs(scipy_result - jax_result)))

The printed value should be close to zero, but the actual output is 0.8381599. When $m = 0, n = 0$, the output is 2.9802322e-08, which is correct.
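As an independent sanity check on which library is right, the two harmonics used above have well-known closed forms: $Y_0^0 = \frac{1}{2}\sqrt{1/\pi}$ and $Y_1^0 = \frac{1}{2}\sqrt{3/\pi}\,\cos\phi$, with $\phi$ the polar angle as in the SciPy argument order. A minimal sketch using only the standard library follows; the helper names `y00` and `y10` are made up for illustration and are not part of JAX or SciPy.

```python
import math

def y00(theta, phi):
    """Closed form of Y_0^0: constant over the sphere."""
    return complex(0.5 * math.sqrt(1.0 / math.pi), 0.0)

def y10(theta, phi):
    """Closed form of Y_1^0: depends only on the polar angle phi."""
    return complex(0.5 * math.sqrt(3.0 / math.pi) * math.cos(phi), 0.0)

# Evaluate at a few (theta, phi) pairs; theta is azimuthal, phi is polar.
for theta, phi in [(0.0, 0.0), (1.0, math.pi / 2), (2.5, 2.0)]:
    print(theta, phi, y00(theta, phi), y10(theta, phi))
```

Comparing each library's per-point output against `y10` should show which of the two deviates for $n = 1$.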

I checked the source code of jax.scipy.special.sph_harm, and I think the error is here:

@partial(jit, static_argnums=(4,))
def _sph_harm(m: Array,
              n: Array,
              theta: Array,
              phi: Array,
              n_max: int) -> Array:
  """Computes the spherical harmonics."""

  cos_colatitude = jnp.cos(phi)

  legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
  legendre_val = legendre.at[abs(m), n, jnp.arange(len(n))].get(mode="clip")

  angle = abs(m) * theta
  vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
  harmonics = lax.complex(legendre_val * jnp.real(vandermonde),
                          legendre_val * jnp.imag(vandermonde))

  # Negative order.
  harmonics = jnp.where(m < 0,
                        (-1.0)**abs(m) * jnp.conjugate(harmonics),
                        harmonics)

  return harmonics

The statement that computes legendre_val indexes with the wrong array. It should be changed to legendre_val = legendre.at[abs(m), n, jnp.arange(len(phi))].get(mode="clip")

The result is correct for degree $n = 0$ because the Legendre polynomial of that degree is a constant, so every value in legendre.at[abs(m), n, jnp.arange(len(phi))] is the same as the value in legendre.at[abs(m), n, jnp.arange(len(n))]. For degree $n \neq 0$ the polynomial varies from point to point, so the values are wrong.
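The effect described above can be reproduced with a small pure-Python model of the indexing. This is only a sketch: `toy_table` and `gather` are hypothetical stand-ins, the table is assumed to be laid out as [order, degree, point], and only order 0 with degrees 0 and 1 is filled in.

```python
import math

def toy_table(cos_colat):
    """Hypothetical stand-in for the Legendre table: table[order][degree][point]."""
    p0 = [math.sqrt(1.0 / (4.0 * math.pi)) for _ in cos_colat]      # degree 0: constant
    p1 = [math.sqrt(3.0 / (4.0 * math.pi)) * x for x in cos_colat]  # degree 1: varies per point
    return [[p0, p1]]  # order 0 only

def gather(table, m, n, point_ids):
    """Mimics table[abs(m), n, point_ids] for length-1 m and n."""
    return [table[abs(m[0])][n[0]][j] for j in point_ids]

phi = [0.3, 1.1, 2.0, 2.8]  # polar angles of four sample points
table = toy_table([math.cos(p) for p in phi])

def max_error(degree):
    m, n = [0], [degree]
    buggy = gather(table, m, n, range(len(n)))    # reads point 0 only
    fixed = gather(table, m, n, range(len(phi)))  # reads one value per point
    # Downstream, the single buggy value is broadcast across all points.
    buggy_broadcast = buggy * len(phi)
    return max(abs(b - f) for b, f in zip(buggy_broadcast, fixed))

print("degree 0:", max_error(0))
print("degree 1:", max_error(1))
```

In this model the two indexings agree for degree 0 and disagree for degree 1, which matches the behaviour reported in the issue.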

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.7
jaxlib: 0.4.7
numpy: 1.22.4
python: 3.8.16 (default, Mar 1 2023, 21:19:10) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

SGENZO, Apr 16 '24 05:04

Hi @SGENZO

Thanks for reporting the bug. I have opened PR #20772 for this. This issue will be closed once the PR is merged.

rajasekharporeddy, Apr 16 '24 13:04

Thanks – the sph_harm code is unfortunately very poorly implemented and we're considering removing it entirely (see https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-special). We can probably fix this bug, but long-term I'd suggest finding a different implementation to rely on.

jakevdp, Apr 16 '24 14:04

@jakevdp Got it. I'll consider other implementations, thanks.

SGENZO, Apr 17 '24 00:04