jnp.linalg.norm() underflows
Simple example:
jnp.linalg.norm(jnp.array([1e-20], jnp.float32))
produces 0.
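Denormals are flushed to zero, and 1e-20**2 (about 1e-40) is smaller than the smallest normal float32 (about 1.18e-38), so the squared term underflows to 0 before the square root is taken.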
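To make the arithmetic concrete, here is a small check of where 1e-20 squared lands relative to the float32 normal range. It is only a sketch: the variable names are mine, and whether the product prints as 0 or as a denormal near 1e-40 depends on the backend and its flush-to-zero behavior, so the only thing the snippet relies on is the comparison against the smallest normal value.

```python
import jax.numpy as jnp

# Smallest positive *normal* float32, roughly 1.18e-38.
tiny = float(jnp.finfo(jnp.float32).tiny)

x = jnp.float32(1e-20)
# The exact product is about 1e-40, which is below the normal range.
# Depending on the backend it is stored as a denormal or flushed to 0.
squared = float(x * x)

print("smallest normal float32:", tiny)
print("1e-20 squared in float32:", squared)
print("below the normal range:", squared < tiny)
```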
It's possible to define a non-underflowing Euclidean norm using variadic reductions:
import jax.numpy as jnp
from jax import lax

def norm2(x):
  # Each partial result is a pair (ssq, scale) representing scale**2 * ssq.
  def reducer(xs, ys):
    (ssq_x, scale_x) = xs
    (ssq_y, scale_y) = ys
    scale = lax.max(scale_x, scale_y)
    scale_is_zero = lax.eq(scale, jnp.float32(0))
    ssq = ssq_x * (scale_x / scale)**2 + ssq_y * (scale_y / scale)**2
    # If both scales are 0, ssq is NaN (0/0); substitute 1 so the result is 0.
    return (lax.select(scale_is_zero, jnp.ones_like(ssq), ssq), scale)
  ssq_y, scale_y = lax.reduce(
    (jnp.ones_like(x), jnp.abs(x)),
    (jnp.float32(1), jnp.float32(0)),
    reducer, dimensions=(0,))
  return scale_y * jnp.sqrt(ssq_y)
We should determine whether we want to do this; if so, it probably needs to be implemented as a primitive.
Interestingly numpy.linalg.norm also produces 0 if the CPU ftz flag is set (e.g., via the daz package), but scipy.linalg.norm does not.
Rasmus pointed me to https://hal.archives-ouvertes.fr/hal-01511120/document which looks interesting.
Another possibility is to rescale the norm by the element with the largest absolute value. This would require a second reduction, which perhaps rules it out for performance reasons.
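For reference, a minimal sketch of that rescaling idea, written with plain jax.numpy operations rather than as a primitive. The function name rescaled_norm is hypothetical, it only handles the full Euclidean norm of an array, and it makes the cost visible: one reduction for the maximum and a second for the sum of squares.

```python
import jax.numpy as jnp

def rescaled_norm(x):
    """Euclidean norm of x, rescaled by the largest absolute value (sketch)."""
    # First reduction: the largest magnitude in the input.
    scale = jnp.max(jnp.abs(x))
    # Guard against 0/0 when the input is all zeros.
    safe_scale = jnp.where(scale == 0, jnp.ones_like(scale), scale)
    # Second reduction: sum of squares of values that now lie in [-1, 1],
    # so the largest term is exactly 1 and cannot underflow.
    return scale * jnp.sqrt(jnp.sum((x / safe_scale) ** 2))

x = jnp.array([3e-20, 4e-20], jnp.float32)
print(rescaled_norm(x))  # should be close to 5e-20 rather than 0
```

Elements much smaller than the maximum can still underflow when squared, but their contribution to the result is negligible in that case.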
@tlu7
Hi @hawkinsp
I tested the provided reproducer on Google Colab with JAX version 0.4.26. jnp.linalg.norm no longer produces 0 for 1e-20.
import jax.numpy as jnp
import numpy as np
import scipy.linalg
print("JAX :", jnp.linalg.norm(jnp.array([1e-20], jnp.float32)))
print("NumPy:", np.linalg.norm(np.array([1e-20], np.float32)))
print("SciPy:", scipy.linalg.norm(np.array([1e-20], np.float32)))
output:
JAX : 1e-20
NumPy: 9.999973e-21
SciPy: 9.999999682655225e-21
With 1e-20**2 as the input instead:
JAX : 1e-40
NumPy: 0.0
SciPy: 9.99994610111476e-41
Please find the gist for reference.
Thank you.
Thanks for the followup!