
`jnp.frexp` not matching NumPy on subnormals


Description

The function jnp.frexp does not match the NumPy result on subnormal inputs (I would personally consider the NumPy result to be the more accurate one).

import numpy as np
import jax.numpy as jnp

v = np.finfo(np.float16).smallest_subnormal  # 2**-24
np.frexp(v)     # returns (0.5, -23), i.e. 0.5 * 2**-23 == 2**-24
jnp.frexp(v)    # returns (Array(0.5005, dtype=float16), Array(-14, dtype=int32))

A similar issue exists with FP32 too.
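For reference, the analogous float32 check looks like this (a sketch; the exact jnp.frexp output may vary by version and platform):

import numpy as np
import jax.numpy as jnp

v32 = np.finfo(np.float32).smallest_subnormal  # 2**-149
np.frexp(v32)   # (0.5, -148), since 2**-149 == 0.5 * 2**-148
jnp.frexp(v32)  # disagrees on subnormal inputs, as in the float16 case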

What jax/jaxlib version are you using?

jax 0.4.13, jaxlib 0.4.13

Which accelerator(s) are you using?

CPU

Additional system info?

Python 3.8

NVIDIA GPU info

No response

balancap avatar Feb 07 '24 15:02 balancap

Thanks for the report – I think this is expected: XLA flushes subnormal values during operations, since hardware such as TPUs and some GPUs does not support them.
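A quick way to probe whether a given backend preserves subnormals through an operation (a sketch, not a guaranteed diagnostic):

import jax
import numpy as np

# Multiply the smallest float32 subnormal by 2 under jit; a backend that
# flushes subnormals will print 0.0 instead of 2**-148.
tiny = np.finfo(np.float32).smallest_subnormal  # 2**-149
print(jax.jit(lambda x: x * 2)(tiny))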

jakevdp avatar Feb 07 '24 18:02 jakevdp

I had a quick check yesterday: the following piece of code runs fine and returns the proper result on recent NVIDIA GPUs (at least the ML-oriented ones):

import jax
import numpy as np

val_f16 = np.finfo(np.float16).smallest_subnormal
val_f32 = np.finfo(np.float32).smallest_subnormal

@jax.jit
def fn(v):
  return v * 2

out_f16 = fn(val_f16)
out_f32 = fn(val_f32)

# Compare against the NumPy reference; True means the subnormal survived.
print(out_f16 == (val_f16 * 2), val_f16 * 2, out_f16, out_f16.dtype)
print(out_f32 == (val_f32 * 2), val_f32 * 2, out_f32, out_f32.dtype)

On TPU, it indeed flushes to zero.

It's not necessarily an easy decision, but my take would be that frexp is just a mantissa + exponent split, i.e. closer to a bitwise masking manipulation than an arithmetic op (add, mul, ...). Hence I would expect the result to be the same across platforms.
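For illustration, a bit-level frexp along those lines could look roughly like the sketch below for float32 (illustrative only, not the actual JAX or XLA implementation; frexp_bitwise is a hypothetical name):

import jax
import jax.numpy as jnp
import numpy as np

def frexp_bitwise(x):
  # Hypothetical sketch of frexp via bit manipulation for float32 inputs,
  # handling subnormals explicitly. Not the actual JAX/XLA implementation.
  x = jnp.asarray(x, dtype=jnp.float32)
  bits = jax.lax.bitcast_convert_type(x, jnp.uint32)
  exp_field = (bits >> 23) & 0xFF
  is_subnormal = (exp_field == 0) & (x != 0)
  # Rescale subnormals into the normal range; the exponent shift of 24 is
  # undone below.
  x_scaled = jnp.where(is_subnormal, x * jnp.float32(2.0**24), x)
  bits_s = jax.lax.bitcast_convert_type(x_scaled, jnp.uint32)
  exp_s = ((bits_s >> 23) & 0xFF).astype(jnp.int32)
  exponent = exp_s - 126 - jnp.where(is_subnormal, 24, 0)
  # Force the biased exponent field to 126 (the encoding of 0.5) so the
  # mantissa lands in [0.5, 1), keeping the sign and fraction bits.
  m_bits = (bits_s & jnp.uint32(0x807FFFFF)) | jnp.uint32(126 << 23)
  mantissa = jax.lax.bitcast_convert_type(m_bits, jnp.float32)
  zero = x == 0
  return jnp.where(zero, 0.0, mantissa), jnp.where(zero, 0, exponent)

print(frexp_bitwise(np.finfo(np.float32).smallest_subnormal))  # (0.5, -148)

Whether such bit-level semantics can be implemented efficiently on backends that flush subnormals in hardware is of course a separate question.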

balancap avatar Feb 14 '24 09:02 balancap