jax

Copy nn.{softmax,log_softmax} to scipy.special

Open • NeilGirdhar opened this issue 1 year ago • 9 comments

Fixes #20700

NeilGirdhar • Apr 12 '24 13:04

As requested by @jakevdp

NeilGirdhar • Apr 12 '24 13:04

It looks like these functions can't be imported directly, because jax.nn.softmax defaults to axis=-1, while scipy.special.softmax defaults to axis=None.

We'll have to create wrappers for these in jax/_src/scipy/special.py.
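A minimal sketch of such a wrapper, assuming it only needs to change the default axis and otherwise delegate to jax.nn.softmax and jax.nn.log_softmax. The names scipy_style_softmax and scipy_style_log_softmax are hypothetical, not the names used in jax/_src/scipy/special.py. With axis=None the whole array is treated as one distribution, so the sketch expands None into a tuple of every axis, which jax.nn.softmax accepts.

```python
import jax
import jax.numpy as jnp


def scipy_style_softmax(x, axis=None):
    """Hypothetical wrapper: like jax.nn.softmax, but with SciPy's axis=None default."""
    x = jnp.asarray(x)
    if axis is None:
        # axis=None means "normalize over all elements", as in scipy.special.softmax.
        axis = tuple(range(x.ndim))
    return jax.nn.softmax(x, axis=axis)


def scipy_style_log_softmax(x, axis=None):
    """Hypothetical wrapper: like jax.nn.log_softmax, but with SciPy's axis=None default."""
    x = jnp.asarray(x)
    if axis is None:
        axis = tuple(range(x.ndim))
    return jax.nn.log_softmax(x, axis=axis)


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(scipy_style_softmax(x).sum())  # one distribution over all 4 entries
print(jax.nn.softmax(x).sum())       # one distribution per row (axis=-1)
```

Passing an explicit axis makes both functions behave exactly like their jax.nn counterparts; only the default differs.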

jakevdp • Apr 12 '24 18:04

> It looks like these functions can't be imported directly, because jax.nn.softmax defaults to axis=-1, while scipy.special.softmax defaults to axis=None.
>
> We'll have to create wrappers for these in jax/_src/scipy/special.py.

Ah, got it, so this is trickier than I thought.

NeilGirdhar • Apr 12 '24 19:04

Hello @NeilGirdhar, @jakevdp,

I have noticed a (serious) bug in jax.nn.softmax that might be worth fixing in this PR.

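```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    # Softmax over the leading axis of a (3, 1, 1) array.
    return jax.nn.softmax(x, axis=0)

print(f(x))           # eager execution
print(jax.jit(f)(x))  # jit-compiled execution
```

Output (eager, then jit):

```
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[[[1.]]

 [[1.]]

 [[1.]]]
```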

Note that the bug under jax.jit(f) seems to depend on the shape of x, here (3, 1, 1).

Computing softmax in a "numerically safe" way, as the exponential of log_softmax, works around the issue:

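```python
def f(x):
    # Exponentiate log_softmax instead of calling softmax directly.
    return jnp.exp(jax.nn.log_softmax(x, axis=0))

print(f(x))           # eager execution
print(jax.jit(f)(x))  # jit-compiled execution
```

Output (eager, then jit):

```
[[[0.75250584]]

 [[0.07554279]]

 [[0.17195135]]]
[[[0.75250584]]

 [[0.07554279]]

 [[0.17195135]]]
```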
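The workaround above can be packaged as a small helper. This is a sketch: softmax_via_log_softmax is a hypothetical name, and it relies only on the identity softmax(x) = exp(log_softmax(x)), where log_softmax(x) = x − logsumexp(x) along the chosen axis. The axis argument is marked static under jax.jit because jax.nn.log_softmax needs a concrete Python axis, not a traced value.

```python
import jax
import jax.numpy as jnp


def softmax_via_log_softmax(x, axis=-1):
    """Hypothetical helper: softmax computed as exp(log_softmax(x))."""
    # log_softmax subtracts logsumexp along `axis`, so exponentiating
    # yields exp(x) / sum(exp(x)) along that axis.
    return jnp.exp(jax.nn.log_softmax(x, axis=axis))


x = jax.random.normal(jax.random.key(0), (3, 1, 1))
eager = softmax_via_log_softmax(x, axis=0)
jitted = jax.jit(softmax_via_log_softmax, static_argnames="axis")(x, axis=0)

# Eager and jit-compiled results should agree, and each slice along axis 0 should sum to 1.
print(jnp.allclose(eager, jitted), eager.sum(axis=0))
```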

francois-rozet • Apr 21 '24 12:04

@francois-rozet I think you should file a separate bug? Looks like the axis is being misinterpreted, but that's just a guess.

NeilGirdhar • Apr 21 '24 13:04

@NeilGirdhar OK, I filed a bug report (#20856).

francois-rozet • Apr 21 '24 13:04

> We'll have to create wrappers for these in jax/_src/scipy/special.py.

Also, we might want to make the jax.scipy.special wrappers use the non-deprecated softmax version by default. The deprecated version has some poor autodiff behavior, but we haven't been able to make the non-deprecated version the default in jax.nn.softmax because it changes the numerics and breaks some downstream projects that are sensitive to the change.
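The thread does not spell out what distinguishes the two softmax versions. Purely as an illustration, and assuming the difference lies in how the derivative is defined, the sketch below shows one common way to give softmax a hand-written derivative: a jax.custom_jvp rule based on the identity d softmax(x) = s * (ẋ − sum(s * ẋ)) along the reduction axis, where s = softmax(x). The name softmax_custom_jvp is hypothetical, and this is not JAX's own implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def softmax_custom_jvp(x, axis):
    """Hypothetical softmax whose derivative comes from the rule below."""
    # Subtract the max along `axis` for numerical stability.
    unnormalized = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)


@softmax_custom_jvp.defjvp
def _softmax_custom_jvp_rule(axis, primals, tangents):
    # With nondiff_argnums=(1,), `axis` arrives first and only x is differentiated.
    (x,), (x_dot,) = primals, tangents
    s = softmax_custom_jvp(x, axis)
    # Tangent of softmax: s * (x_dot - <s, x_dot>) along `axis`.
    s_dot = s * (x_dot - jnp.sum(s * x_dot, axis=axis, keepdims=True))
    return s, s_dot


x = jnp.array([[0.5, -1.0, 2.0], [1.0, 1.0, 1.0]])
print(softmax_custom_jvp(x, -1))
print(jax.jacfwd(lambda z: softmax_custom_jvp(z, -1))(x).shape)
```

Because the rule is linear in the tangent, JAX can transpose it, so jax.grad works on this function as well as jax.jvp and jax.jacfwd.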

jakevdp • Apr 21 '24 14:04

Are you still interested in contributing this?

jakevdp • May 16 '24 18:05

@jakevdp Yes, sorry for the delay. I've been trying to find time to do it.

NeilGirdhar • May 22 '24 04:05

@jakevdp should be ready for your review.

NeilGirdhar • Jun 22 '24 05:06

Thanks for the speedy and patient review!

NeilGirdhar • Jun 22 '24 14:06