jax.nn.softmax is inconsistent under jax.jit
Description
In the following code, jax.nn.softmax returns different results with and without jax.jit.
import jax
import jax.numpy as jnp
x = jax.random.normal(jax.random.key(0), (3, 1, 1))
def f(x):
    return jax.nn.softmax(x, axis=0)
print(f(x))
print(jax.jit(f)(x))
[[[0.75250584]]
[[0.0755428 ]]
[[0.17195134]]]
[[[1.]]
[[1.]]
[[1.]]]
Using a "numerically safe" version of softmax based on log_softmax solves the issue.
def f(x):
    return jnp.exp(jax.nn.log_softmax(x, axis=0))
print(f(x))
print(jax.jit(f)(x))
[[[0.75250584]]
[[0.07554279]]
[[0.17195135]]]
[[[0.75250584]]
[[0.07554279]]
[[0.17195135]]]
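For context, the two formulations compute the same value mathematically; they only differ in the operations they lower to, which is what matters for the compiler rewrite discussed below. A rough sketch of the two (simplified, not the exact library implementations):

import jax.numpy as jnp

def softmax_direct(x, axis):
    # exp(x - max) / sum(exp(x - max)): the sub/exp/reduce_sum/div pattern visible in the HLO below
    unnormalized = jnp.exp(x - jnp.max(x, axis=axis, keepdims=True))
    return unnormalized / jnp.sum(unnormalized, axis=axis, keepdims=True)

def softmax_via_log(x, axis):
    # exp(log_softmax(x)): same value, but goes through a log-sum-exp instead of a plain divide
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    return jnp.exp(shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True)))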
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.2
python: 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 03:49:32) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='thinkpad', release='6.5.0-0.deb12.4-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.5.10-1~bpo12+1 (2023-11-23)', machine='x86_64')
Note that printing within f weirdly fixes the issue.
import jax
import jax.numpy as jnp
x = jax.random.normal(jax.random.key(0), (3, 1, 1))
def f(x):
    jax.debug.print('{}', x[0, 0])
    return jax.nn.softmax(x, axis=0)
print(f(x))
print(jax.jit(f)(x))
[1.8160863]
[[[0.75250584]]
[[0.0755428 ]]
[[0.17195134]]]
[1.8160863]
[[[0.75250584]]
[[0.0755428 ]]
[[0.17195134]]]
Thanks for the report – it looks like this is a CPU-only issue; I get the expected result when running on GPU.
It also seems like this was introduced in JAX v0.4.22; JAX v0.4.21 and earlier return the expected results for the handful of versions I've tried.
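For anyone checking other versions or backends, a quick sanity test (just a sketch) is to compare the eager and jitted results directly:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))
f = lambda x: jax.nn.softmax(x, axis=0)

# True on unaffected versions/backends, False where the bug is present.
print(jnp.allclose(f(x), jax.jit(f)(x)))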
"it looks like this is a CPU-only issue"
Oh I should have checked. Also, it could be related to this comment:
https://github.com/google/jax/blob/e498bca2233fa427153f921f66aebd4df488aa9a/jax/_src/nn/functions.py#L536-L538
It looks like this bug appears only when the XLA softmax rewriter is enabled, so it's likely related to this XLA change: https://github.com/openxla/xla/pull/7540
Determined this by running the following on JAX v0.4.21 and 0.4.22:
import jax
import jax.numpy as jnp
print(jax.__version__)
x = jax.random.normal(jax.random.key(0), (3, 1, 1))
def f(x):
    return jax.nn.softmax(x, axis=0)
print(f(x))
print(jax.jit(f)(x))
print(jax.jit(f).lower(x).compile().as_text())
Outputs:
0.4.22
[[[0.75250584]]
[[0.0755428 ]]
[[0.17195134]]]
[[[1.]]
[[1.]]
[[1.]]]
HloModule jit_f, entry_computation_layout={(f32[3,1,1]{2,1,0})->f32[3,1,1]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.25 (Arg_0.1: f32[3,1,1]) -> f32[3,1,1] {
%Arg_0.1 = f32[3,1,1]{2,1,0} parameter(0), sharding={replicated}
ROOT %custom-call = f32[3,1,1]{2,1,0} custom-call(f32[3,1,1]{2,1,0} %Arg_0.1), custom_call_target="__onednn$softmax", metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-4-b7a605546166>" source_line=10}
}
0.4.21
[[[0.75250584]]
[[0.0755428 ]]
[[0.17195134]]]
[[[0.75250584]]
[[0.0755428 ]]
[[0.17195134]]]
HloModule jit_f, entry_computation_layout={(f32[3,1,1]{2,1,0})->f32[3,1,1]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}
%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> f32[] {
%Arg_0.5 = f32[] parameter(0)
%Arg_1.6 = f32[] parameter(1)
ROOT %maximum.7 = f32[] maximum(f32[] %Arg_0.5, f32[] %Arg_1.6), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}
%region_1.15 (Arg_0.16: f32[], Arg_1.17: f32[]) -> f32[] {
%Arg_0.16 = f32[] parameter(0)
%Arg_1.17 = f32[] parameter(1)
ROOT %add.18 = f32[] add(f32[] %Arg_0.16, f32[] %Arg_1.17), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}
%fused_computation (param_0: f32[3,1,1], param_1.2: f32[1,1]) -> f32[3,1,1] {
%param_0 = f32[3,1,1]{2,1,0} parameter(0)
%param_1.2 = f32[1,1]{1,0} parameter(1)
%bitcast.2 = f32[] bitcast(f32[1,1]{1,0} %param_1.2), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
%broadcast.2 = f32[3,1,1]{2,1,0} broadcast(f32[] %bitcast.2), dimensions={}, metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
ROOT %divide.0 = f32[3,1,1]{2,1,0} divide(f32[3,1,1]{2,1,0} %param_0, f32[3,1,1]{2,1,0} %broadcast.2), metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}
%fused_computation.1 (param_0.2: f32[3,1,1], param_1.5: f32[1,1]) -> f32[3,1,1] {
%param_0.2 = f32[3,1,1]{2,1,0} parameter(0)
%param_1.5 = f32[1,1]{1,0} parameter(1)
%bitcast.3 = f32[] bitcast(f32[1,1]{1,0} %param_1.5), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
%broadcast.3 = f32[3,1,1]{2,1,0} broadcast(f32[] %bitcast.3), dimensions={}, metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
%subtract.0 = f32[3,1,1]{2,1,0} subtract(f32[3,1,1]{2,1,0} %param_0.2, f32[3,1,1]{2,1,0} %broadcast.3), metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
ROOT %exponential.0 = f32[3,1,1]{2,1,0} exponential(f32[3,1,1]{2,1,0} %subtract.0), metadata={op_name="jit(f)/jit(main)/exp" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}
ENTRY %main.25 (Arg_0.1: f32[3,1,1]) -> f32[3,1,1] {
%Arg_0.1 = f32[3,1,1]{2,1,0} parameter(0), sharding={replicated}
%constant.3 = f32[] constant(-inf)
%reduce.8 = f32[1,1]{1,0} reduce(f32[3,1,1]{2,1,0} %Arg_0.1, f32[] %constant.3), dimensions={0}, to_apply=%region_0.4, metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
%fusion.1 = f32[3,1,1]{2,1,0} fusion(f32[3,1,1]{2,1,0} %Arg_0.1, f32[1,1]{1,0} %reduce.8), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(f)/jit(main)/exp" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
%constant.2 = f32[] constant(0)
%reduce.19 = f32[1,1]{1,0} reduce(f32[3,1,1]{2,1,0} %fusion.1, f32[] %constant.2), dimensions={0}, to_apply=%region_1.15, metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
ROOT %fusion = f32[3,1,1]{2,1,0} fusion(f32[3,1,1]{2,1,0} %fusion.1, f32[1,1]{1,0} %reduce.19), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}
And it looks like this is "fixed" on JAX's main branch, probably because #20643 changed the raw HLO so that the XLA rewrite logic no longer recognizes it as a softmax. Yeesh, what a mess.
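A quick way to confirm whether that rewrite is being applied (a sketch, reusing f and x from the repro above) is to look for the oneDNN custom call in the compiled HLO:

hlo = jax.jit(f).lower(x).compile().as_text()
print("__onednn$softmax" in hlo)  # True when the XLA softmax rewriter has kicked in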
Whoa, that is a mess. Just looking at the oneDNN softmax custom call, it only supports axis=-1 (it's hard-coded).
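That would be consistent with the output above: for the (3, 1, 1) input, a softmax over the last (size-1) axis is identically 1, which is exactly what the jitted call returns. A quick illustration (assuming x from the repro above):

print(jax.nn.softmax(x, axis=-1))
# [[[1.]]
#  [[1.]]
#  [[1.]]]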
@jakevdp I think this issue can be closed as solved, since 0.4.30 gives the same results for the repro code on CPU:
0.4.30
[[[0.75250584]]
[[0.0755428 ]]
[[0.17195134]]] {CpuDevice(id=0)}
[[[0.75250584]]
[[0.07554279]]
[[0.17195134]]] {CpuDevice(id=0)}