
jnp.linalg.norm() underflows

Status: Open. hawkinsp opened this issue 4 years ago (4 comments).

Simple example:

```python
jnp.linalg.norm(jnp.array([1e-20], jnp.float32))
```

produces 0.

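Denormals are flushed to zero, and (1e-20)**2 = 1e-40 is smaller than the smallest normal float32 (about 1.18e-38). The squared element therefore becomes 0 before the sum and square root are taken.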
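The arithmetic can be checked directly. The following is a minimal sketch using NumPy, which does not flush denormals unless the CPU FTZ flag has been set. The helper name `square_below_normal_range` is hypothetical and only illustrates the threshold.

```python
import numpy as np

# Smallest *normal* float32 value, 2**-126 (about 1.18e-38). Nonzero values
# below it can only be stored as denormals, which XLA flushes to zero.
TINY = np.finfo(np.float32).tiny


def square_below_normal_range(v):
    """Hypothetical helper: True if v*v, computed in float32, is below TINY."""
    x = np.float32(v)
    sq = x * x  # float32 product; NumPy keeps the denormal unless FTZ is enabled
    return bool(abs(sq) < TINY)


print(square_below_normal_range(1e-20))  # True: 1e-40 is below the normal range
print(square_below_normal_range(1e-10))  # False: 1e-20 is still a normal float32
```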

It's possible to define a non-underflowing Euclidean norm using variadic reductions:

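```python
import jax.numpy as jnp
from jax import lax

def norm2(x):
  """Euclidean norm of a 1-D float32 array that never squares tiny values directly."""
  # Each partial result is a pair (ssq, scale) representing scale * sqrt(ssq).
  def reducer(xs, ys):
    (ssq_x, scale_x) = xs
    (ssq_y, scale_y) = ys
    scale = lax.max(scale_x, scale_y)
    scale_is_zero = lax.eq(scale, jnp.float32(0))
    # Rescale both partial sums of squares to the common (larger) scale.
    ssq = ssq_x * (scale_x / scale)**2 + ssq_y * (scale_y / scale)**2
    # If both scales are 0, the divisions above are 0/0; fall back to ssq = 1.
    return (lax.select(scale_is_zero, jnp.ones_like(ssq), ssq), scale)

  ssq, scale = lax.reduce(
    (jnp.ones_like(x), jnp.abs(x)),    # every element starts as (1, |x_i|)
    (jnp.float32(1), jnp.float32(0)),  # identity element: ssq=1, scale=0
    reducer, dimensions=(0,))
  return scale * jnp.sqrt(ssq)
```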
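To see why the `reducer` above never squares a tiny value, here is a hypothetical pure-NumPy reference that folds the same (ssq, scale) merge rule over the elements one at a time. The names `combine` and `norm2_reference` are illustrative only. A Python `if` stands in for `lax.select`.

```python
import numpy as np


def combine(a, b):
    """Merge two (ssq, scale) pairs; each pair represents scale * sqrt(ssq).

    Same rule as the `reducer` above: rescale both sums of squares to the
    larger scale, then add them.
    """
    ssq_a, scale_a = a
    ssq_b, scale_b = b
    scale = max(scale_a, scale_b)
    if scale == 0:
        return (np.float32(1), scale)  # both inputs are zero; keep the identity's ssq
    ssq = ssq_a * (scale_a / scale) ** 2 + ssq_b * (scale_b / scale) ** 2
    return (np.float32(ssq), scale)


def norm2_reference(x):
    """Hypothetical NumPy reference: a left fold of `combine` over a 1-D float32 array."""
    acc = (np.float32(1), np.float32(0))  # identity element: ssq=1, scale=0
    for v in np.abs(np.asarray(x, dtype=np.float32)):
        acc = combine(acc, (np.float32(1), v))  # each element starts as (1, |v|)
    ssq, scale = acc
    return scale * np.sqrt(ssq)


print(norm2_reference([3e-20, 4e-20]))  # close to 5e-20
```

XLA may apply the reducer in any order and tree shape. This only works because the merge rule is (up to rounding) associative and commutative, with `(1, 0)` as its identity.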

We should determine whether we want to do this; if so, it probably needs to be implemented as a primitive.

Interestingly, `numpy.linalg.norm` also produces 0 if the CPU flush-to-zero (FTZ) flag is set (e.g., via the `daz` package), but `scipy.linalg.norm` does not.

hawkinsp, Feb 23 '21 16:02

Rasmus pointed me to https://hal.archives-ouvertes.fr/hal-01511120/document, which looks interesting.

hawkinsp, Feb 23 '21 19:02

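Another possibility is to rescale the input by the element with the largest absolute value, i.e. compute m * norm(x / m) with m = max|x_i|. The largest squared term is then exactly 1, so the sum cannot underflow to 0. This would require a second reduction (one to find m, one to sum the squares), which perhaps rules it out for performance reasons.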
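A minimal JAX sketch of the two-pass rescaling idea follows. It assumes a non-empty 1-D float input. `norm_two_pass` is a hypothetical name and is not how `jnp.linalg.norm` is implemented.

```python
import jax
import jax.numpy as jnp


@jax.jit
def norm_two_pass(x):
    """Hypothetical two-pass Euclidean norm of a 1-D array (a sketch, not JAX's code).

    Pass 1 finds m = max|x_i|. Pass 2 sums the squares of |x| / m: the largest
    term is exactly 1, so the sum cannot underflow to 0, and no term exceeds 1,
    so it cannot overflow.
    """
    ax = jnp.abs(x)
    m = jnp.max(ax)                        # first reduction
    safe_m = jnp.where(m == 0, 1.0, m)     # avoid 0/0 for an all-zero input
    y = ax / safe_m                        # every entry of y lies in [0, 1]
    return m * jnp.sqrt(jnp.sum(y * y))    # second reduction
```

For example, `norm_two_pass(jnp.array([3e30, 4e30], jnp.float32))` should give about 5e30, whereas squaring those values directly in float32 would overflow to inf. Inputs containing inf or NaN would need separate handling, since inf / inf is NaN.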

hawkinsp, Apr 20 '22 14:04

@tlu7

tlu7, May 05 '22 05:05

Hi @hawkinsp

I tested the provided reproducible code on Google Colab with JAX version 0.4.26: `jnp.linalg.norm` no longer produces 0 for 1e-20.

```python
import jax.numpy as jnp
import numpy as np
import scipy


print("JAX  :", jnp.linalg.norm(jnp.array([1e-20], jnp.float32)))
print("NumPy:", np.linalg.norm(np.array([1e-20], np.float32)))
print("SciPy:", scipy.linalg.norm(np.array([1e-20], np.float32)))
```

Output:

```
JAX  : 1e-20
NumPy: 9.999973e-21
SciPy: 9.999999682655225e-21
```

With 1e-20**2 (i.e. 1e-40) as the input value:

```
JAX  : 1e-40
NumPy: 0.0
SciPy: 9.99994610111476e-41
```

Please find the gist for reference.

Thank you.

rajasekharporeddy, Apr 29 '24 04:04

Thanks for the followup!

jakevdp, May 15 '24 21:05