nan in hessian of linear function

Open qbadolle opened this issue 4 months ago • 4 comments

Description

I am trying to vectorize a procedure, where I need to do the element-wise power of a large vector before getting its Hessian.

If I manually compute the element-wise power of the vector when defining the function of interest, the following code gives the expected output:

import jax.numpy as jnp
from jax import jacfwd

def fun(x):
    # x ** 1 written out by hand for a length-1 vector
    x = jnp.array([x[0]])
    return x

hess = jacfwd(jacfwd(fun))
x = jnp.array([0.])
hess_val = hess(x)
print(hess_val)  # output: [[[0.]]]

However, if my function of interest takes the element-wise power as part of the definition, I get nan:

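import jax.numpy as jnp
from jax import jacfwd

def fun(x):
    # element-wise power with an integer-dtype array exponent
    x = x ** jnp.array([1])
    return x

hess = jacfwd(jacfwd(fun))
x = jnp.array([0.])
hess_val = hess(x)
print(hess_val)  # output: [[[nan]]]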

This is related to issue #18640 but differs from it in that the problem only arises with the Hessian (the Jacobian gives the expected output with both function definitions).

Would you have any hint why this happens? Thank you!

System info (python version, jaxlib version, accelerator, etc.)

jax version: 0.4.25 jaxlib version: 0.4.23

qbadolle avatar Apr 03 '24 21:04 qbadolle

Thanks for the report! I think you're correct that this is related to #18640. It only affects the Hessian and not the Jacobian because the Hessian is a second-order derivative. Your function looks like this:

$$ f(x) = x^1 $$

Differentiating once, you get this:

$$ f^\prime(x) = x^0 $$

And differentiating a second time, you are differentiating $x^0$ at $x = 0$, i.e. $0^0$, which is ill-defined for a floating-point exponent, and so we return NaN.
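One way to make this step explicit, assuming the exponent is treated as a real number $y$ and the general power rule is applied (a sketch of the arithmetic only, not necessarily how JAX implements its derivative rule):

$$ \frac{d}{dx} x^y = y\,x^{y-1}, \qquad \frac{d^2}{dx^2} x^y = y\,(y-1)\,x^{y-2} $$

At $y = 1$ and $x = 0$ the second expression becomes $1 \cdot 0 \cdot 0^{-1}$. In floating point, $0^{-1}$ evaluates to $\infty$, and $0 \cdot \infty$ is NaN.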

Note that if you use a scalar integer exponent, the derivative is well-defined and the fix referenced in #18640 will apply to your hessian as well:

jax.hessian(lambda x: x ** 1)(x)
# Array([[[0.]]], dtype=float32)

But for non-scalar, non-integer exponentiation, nan is the correct output because the value of the derivative is ambiguous (see the discussion in https://github.com/google/jax/issues/14397#issuecomment-1426386290).
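As a rough sketch of how the scalar-integer case carries over to a larger vector: a Python integer exponent is already applied element-wise to the whole array, so no per-element loop should be needed when every element shares the same integer exponent. The function name `elementwise_power` and the length-5 input are illustrative assumptions, not taken from the report.

```python
import jax
import jax.numpy as jnp

def elementwise_power(x):
    # A plain Python int exponent is applied element-wise to the whole vector.
    return x ** 1

# Illustrative input: every entry sits at the problematic point x = 0.
x = jnp.zeros(5)
hess_val = jax.hessian(elementwise_power)(x)

print(hess_val.shape)                   # (5, 5, 5)
print(bool(jnp.isnan(hess_val).any()))  # checks whether any entry is NaN
```

This only covers the case where all elements are raised to one shared integer exponent; it does not address per-element exponents passed as an array.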

jakevdp avatar Apr 04 '24 14:04 jakevdp

Thank you for your reply, @jakevdp.

I am afraid I do not understand why, in the second code snippet above, jnp.array([1]) is treated as a floating point exponent. I just checked and jnp.array([1]) has dtype=int32 so I would have expected JAX to treat it as an integer exponent accordingly.

As the end goal is to take the element-wise power of a large vector, I could of course raise each element to a scalar integer exponent individually instead of using ** with an array exponent, but I would expect looping through all the elements of the vector to be rather inefficient. Is there no better way?

Thank you again.

qbadolle avatar Apr 04 '24 16:04 qbadolle

x ** jnp.array([1]) lowers to lax.pow, which has float-power semantics, while x ** 1 lowers to lax.integer_pow, which has integer-power semantics. We could maybe specialize the autodiff rules of pow_p to be dtype-dependent, but we haven't done that. @mattjj, what do you think?
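A minimal way to check which primitive each expression uses is to print the jaxpr with `jax.make_jaxpr`. The variable names `scalar_jaxpr` and `array_jaxpr` are illustrative, and the exact printed form may differ between JAX versions:

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.])

# Python int exponent: expected to show the integer_pow primitive.
scalar_jaxpr = str(jax.make_jaxpr(lambda x: x ** 1)(x))

# Integer-dtype array exponent, as in the report: expected to show pow.
array_jaxpr = str(jax.make_jaxpr(lambda x: x ** jnp.array([1]))(x))

print(scalar_jaxpr)
print(array_jaxpr)
```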

jakevdp avatar Apr 04 '24 16:04 jakevdp

Thank you for following up, @jakevdp.

@mattjj: I would be happy to make it a pull request but I would most likely need some guidance given that I am not familiar with all the details of JAX yet. What would you suggest?

qbadolle avatar May 02 '24 13:05 qbadolle