jax
Copy nn.{softmax,log_softmax} to scipy.special
Fixes #20700
As requested by @jakevdp
It looks like these functions can't be imported directly, because jax.nn.softmax defaults to axis=-1, while scipy.special.softmax defaults to axis=None.
We'll have to create wrappers for these in jax/_src/scipy/special.py.
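Here is a minimal sketch of such wrappers. It assumes they simply forward to jax.nn with SciPy's `axis=None` default, which normalizes over every element. The function names are illustrative, and this is not necessarily the code that was eventually merged.

```python
import jax
import jax.numpy as jnp


def softmax(x, axis=None):
    """SciPy-style softmax: axis=None normalizes over every element of x."""
    # jax.nn.softmax defaults to axis=-1, so pass SciPy's default explicitly.
    return jax.nn.softmax(jnp.asarray(x), axis=axis)


def log_softmax(x, axis=None):
    """SciPy-style log_softmax with the same axis=None default."""
    return jax.nn.log_softmax(jnp.asarray(x), axis=axis)


x = jnp.arange(6.0).reshape(2, 3)
print(softmax(x).sum())                  # all six entries together sum to ~1
print(softmax(x, axis=-1).sum(axis=-1))  # each row sums to ~1
```

Only the default `axis` differs from jax.nn. Everything else is passed through unchanged.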
Ah, got it, so this is more tricky than I thought.
Hello @NeilGirdhar, @jakevdp,
I have noticed a (serious) bug with jax.nn.softmax which might be worth fixing within this PR.
```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))
```
```text
[[[0.75250584]]
 [[0.0755428 ]]
 [[0.17195134]]]
[[[1.]]
 [[1.]]
 [[1.]]]
```
Note that the shape of x seems to be what triggers the bug under jax.jit(f).
Using a "numerically safe" version of softmax based on log_softmax solves the issue.
```python
def f(x):
    return jnp.exp(jax.nn.log_softmax(x, axis=0))

print(f(x))
print(jax.jit(f)(x))
```
```text
[[[0.75250584]]
 [[0.07554279]]
 [[0.17195135]]]
[[[0.75250584]]
 [[0.07554279]]
 [[0.17195135]]]
```
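The workaround is numerically safe because log-softmax can be computed in the shifted form log_softmax(x) = (x - max x) - log(sum(exp(x - max x))), so exp never overflows. Below is a minimal hand-rolled sketch of that formula. The function name is hypothetical, and it is not necessarily how jax.nn.log_softmax is implemented internally. It can also be checked under jax.jit for the same (3, 1, 1) input.

```python
import jax
import jax.numpy as jnp


def stable_log_softmax(x, axis=0):
    """log-softmax via the max-shift trick; `axis` must be a static int."""
    # Subtracting the max leaves the result mathematically unchanged but keeps
    # exp() from overflowing for large inputs.
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True))


x = jax.random.normal(jax.random.key(0), (3, 1, 1))
eager = jnp.exp(stable_log_softmax(x))
jitted = jnp.exp(jax.jit(stable_log_softmax, static_argnames="axis")(x))
print(eager.sum(axis=0), jitted.sum(axis=0))  # both should be ~1 along axis 0
```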
@francois-rozet I think you should file a separate bug? Looks like the axis is being misinterpreted, but that's just a guess.
@NeilGirdhar OK, I filed a bug report (#20856).
Also, we might want to make the jax.scipy.special wrappers use the non-deprecated softmax version by default. The deprecated version has some poor autodiff behavior, but we haven't been able to make the non-deprecated version the default in jax.nn.softmax because it changes the numerics and breaks some downstream projects that are sensitive to the change.
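For context on the autodiff point, here is an assumption not stated in this thread: the non-deprecated path attaches a hand-written JVP to softmax via jax.custom_jvp. Derivatives are then expressed through the softmax output y as dy = y * (dx - sum(y * dx)). Below is a minimal sketch of that technique with a hypothetical function name and a single static integer axis. It is not the library's actual implementation.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def softmax_custom_jvp(x, axis):
    """Max-shifted softmax over one static integer `axis` (hypothetical helper)."""
    unnormalized = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)


@softmax_custom_jvp.defjvp
def _softmax_jvp(axis, primals, tangents):
    # Non-differentiable arguments (here `axis`) come first in the JVP rule.
    (x,), (x_dot,) = primals, tangents
    y = softmax_custom_jvp(x, axis)
    # Tangent of softmax written in terms of its own output y, so the max
    # shift in the forward pass is never differentiated.
    y_dot = y * (x_dot - jnp.sum(y * x_dot, axis=axis, keepdims=True))
    return y, y_dot


x = jnp.array([1.0, 2.0, 3.0])
print(jax.jacfwd(lambda v: softmax_custom_jvp(v, 0))(x))
```

The printed Jacobian should match the closed form diag(y) - y yᵀ.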
Are you still interested in contributing this?
@jakevdp Yes, sorry for the delay. I've been trying to find time to do it.
@jakevdp should be ready for your review.
Thanks for the speedy and patient review!