
jax.nn.softmax is inconsistent under jax.jit

Open · francois-rozet opened this issue 1 year ago · 7 comments

Description

In the following code, jax.nn.softmax returns different results with and without jax.jit.

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[[[1.]]

 [[1.]]

 [[1.]]]

Using a "numerically safe" version of softmax based on log_softmax solves the issue.

def f(x):
    return jnp.exp(jax.nn.log_softmax(x, axis=0))

print(f(x))
print(jax.jit(f)(x))
[[[0.75250584]]

 [[0.07554279]]

 [[0.17195135]]]
[[[0.75250584]]

 [[0.07554279]]

 [[0.17195135]]]
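
For reference, here is a minimal sketch of what the log-sum-exp formulation computes (an illustrative reimplementation for clarity, not the library's actual code):

import jax.numpy as jnp

def manual_softmax(x, axis=0):
    # Shift by the max along `axis` for numerical stability,
    # then normalize in log space before exponentiating.
    x_max = jnp.max(x, axis=axis, keepdims=True)
    shifted = x - x_max
    log_z = jnp.log(jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True))
    return jnp.exp(shifted - log_z)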

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.2
python: 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 03:49:32)  [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='thinkpad', release='6.5.0-0.deb12.4-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.5.10-1~bpo12+1 (2023-11-23)', machine='x86_64')

francois-rozet avatar Apr 21 '24 13:04 francois-rozet

Note that, oddly, printing within f fixes the issue.

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    jax.debug.print('{}', x[0, 0])
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))
[1.8160863]
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[1.8160863]
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
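
One way to dig into this (a sketch; it just compares the compiled programs of the two variants) is to inspect the lowered HLO with and without the print:

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f_plain(x):
    return jax.nn.softmax(x, axis=0)

def f_print(x):
    jax.debug.print('{}', x[0, 0])
    return jax.nn.softmax(x, axis=0)

# If the two compiled programs differ, the print is changing what XLA emits.
print(jax.jit(f_plain).lower(x).compile().as_text())
print(jax.jit(f_print).lower(x).compile().as_text())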

francois-rozet avatar Apr 21 '24 14:04 francois-rozet

Thanks for the report – it looks like this is a CPU-only issue; I get the expected result when running on GPU.

jakevdp avatar Apr 21 '24 14:04 jakevdp
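
For anyone reproducing this on a machine that also has an accelerator, here is a minimal sketch of pinning the repro to the CPU backend (using jax.device_put to commit the input to a CPU device):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))
x_cpu = jax.device_put(x, jax.devices('cpu')[0])  # force placement on CPU

def f(x):
    return jax.nn.softmax(x, axis=0)

print(f(x_cpu))
print(jax.jit(f)(x_cpu))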

It also seems like this was introduced in JAX v0.4.22; JAX v0.4.21 and earlier return the expected results for the handful of versions I've tried.

jakevdp avatar Apr 21 '24 14:04 jakevdp

it looks like this is a CPU-only issue

Oh I should have checked. Also, it could be related to this comment:

https://github.com/google/jax/blob/e498bca2233fa427153f921f66aebd4df488aa9a/jax/_src/nn/functions.py#L536-L538

francois-rozet avatar Apr 21 '24 14:04 francois-rozet

It looks like this bug appears only when the XLA softmax rewriter is enabled, so it's likely related to this XLA change: https://github.com/openxla/xla/pull/7540

I determined this by running the following on JAX v0.4.21 and v0.4.22:

import jax
import jax.numpy as jnp
print(jax.__version__)

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

print(f(x))
print(jax.jit(f)(x))

print(jax.jit(f).lower(x).compile().as_text())

Outputs:

0.4.22
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[[[1.]]

 [[1.]]

 [[1.]]]
HloModule jit_f, entry_computation_layout={(f32[3,1,1]{2,1,0})->f32[3,1,1]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.25 (Arg_0.1: f32[3,1,1]) -> f32[3,1,1] {
  %Arg_0.1 = f32[3,1,1]{2,1,0} parameter(0), sharding={replicated}
  ROOT %custom-call = f32[3,1,1]{2,1,0} custom-call(f32[3,1,1]{2,1,0} %Arg_0.1), custom_call_target="__onednn$softmax", metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-4-b7a605546166>" source_line=10}
}
0.4.21
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
[[[0.75250584]]

 [[0.0755428 ]]

 [[0.17195134]]]
HloModule jit_f, entry_computation_layout={(f32[3,1,1]{2,1,0})->f32[3,1,1]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

%region_0.4 (Arg_0.5: f32[], Arg_1.6: f32[]) -> f32[] {
  %Arg_0.5 = f32[] parameter(0)
  %Arg_1.6 = f32[] parameter(1)
  ROOT %maximum.7 = f32[] maximum(f32[] %Arg_0.5, f32[] %Arg_1.6), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

%region_1.15 (Arg_0.16: f32[], Arg_1.17: f32[]) -> f32[] {
  %Arg_0.16 = f32[] parameter(0)
  %Arg_1.17 = f32[] parameter(1)
  ROOT %add.18 = f32[] add(f32[] %Arg_0.16, f32[] %Arg_1.17), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

%fused_computation (param_0: f32[3,1,1], param_1.2: f32[1,1]) -> f32[3,1,1] {
  %param_0 = f32[3,1,1]{2,1,0} parameter(0)
  %param_1.2 = f32[1,1]{1,0} parameter(1)
  %bitcast.2 = f32[] bitcast(f32[1,1]{1,0} %param_1.2), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %broadcast.2 = f32[3,1,1]{2,1,0} broadcast(f32[] %bitcast.2), dimensions={}, metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  ROOT %divide.0 = f32[3,1,1]{2,1,0} divide(f32[3,1,1]{2,1,0} %param_0, f32[3,1,1]{2,1,0} %broadcast.2), metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

%fused_computation.1 (param_0.2: f32[3,1,1], param_1.5: f32[1,1]) -> f32[3,1,1] {
  %param_0.2 = f32[3,1,1]{2,1,0} parameter(0)
  %param_1.5 = f32[1,1]{1,0} parameter(1)
  %bitcast.3 = f32[] bitcast(f32[1,1]{1,0} %param_1.5), metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %broadcast.3 = f32[3,1,1]{2,1,0} broadcast(f32[] %bitcast.3), dimensions={}, metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %subtract.0 = f32[3,1,1]{2,1,0} subtract(f32[3,1,1]{2,1,0} %param_0.2, f32[3,1,1]{2,1,0} %broadcast.3), metadata={op_name="jit(f)/jit(main)/sub" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  ROOT %exponential.0 = f32[3,1,1]{2,1,0} exponential(f32[3,1,1]{2,1,0} %subtract.0), metadata={op_name="jit(f)/jit(main)/exp" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

ENTRY %main.25 (Arg_0.1: f32[3,1,1]) -> f32[3,1,1] {
  %Arg_0.1 = f32[3,1,1]{2,1,0} parameter(0), sharding={replicated}
  %constant.3 = f32[] constant(-inf)
  %reduce.8 = f32[1,1]{1,0} reduce(f32[3,1,1]{2,1,0} %Arg_0.1, f32[] %constant.3), dimensions={0}, to_apply=%region_0.4, metadata={op_name="jit(f)/jit(main)/reduce_max[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %fusion.1 = f32[3,1,1]{2,1,0} fusion(f32[3,1,1]{2,1,0} %Arg_0.1, f32[1,1]{1,0} %reduce.8), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(f)/jit(main)/exp" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  %constant.2 = f32[] constant(0)
  %reduce.19 = f32[1,1]{1,0} reduce(f32[3,1,1]{2,1,0} %fusion.1, f32[] %constant.2), dimensions={0}, to_apply=%region_1.15, metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0,)]" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
  ROOT %fusion = f32[3,1,1]{2,1,0} fusion(f32[3,1,1]{2,1,0} %fusion.1, f32[1,1]{1,0} %reduce.19), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/div" source_file="<ipython-input-1-0271a21b44ca>" source_line=7}
}

jakevdp avatar Apr 21 '24 14:04 jakevdp
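
For completeness, a quick way to check whether the oneDNN rewrite kicked in on a given install is to search the compiled HLO for the custom call seen in the 0.4.22 dump above (a sketch reusing f and x from the snippet above):

compiled_hlo = jax.jit(f).lower(x).compile().as_text()
print('__onednn$softmax' in compiled_hlo)  # True when the oneDNN softmax rewrite is active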

And it looks like this is "fixed" on JAX's main branch, probably because #20643 changed the raw HLO so that the XLA rewrite logic no longer recognizes it as a softmax. Yeesh, what a mess.

jakevdp avatar Apr 21 '24 14:04 jakevdp

Whoa, that is a mess. Just looking at the oneDNN softmax implementation, it only supports axis=-1 (it's hard-coded).

NeilGirdhar avatar Apr 22 '24 00:04 NeilGirdhar
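
If the hard-coded axis is indeed the culprit, one possible workaround until a fix lands would be to move the softmax axis to the last position and back (a hypothetical sketch, not verified against the rewrite):

import jax
import jax.numpy as jnp

def softmax_over_axis0(x):
    # Move axis 0 to the last position, apply softmax over axis -1,
    # then restore the original layout.
    y = jnp.moveaxis(x, 0, -1)
    y = jax.nn.softmax(y, axis=-1)
    return jnp.moveaxis(y, -1, 0)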

@jakevdp I think this issue can be closed as solved, since 0.4.30 gives the same results for the repro code on CPU:

0.4.30
[[[0.75250584]]
 [[0.0755428 ]]
 [[0.17195134]]] {CpuDevice(id=0)}

[[[0.75250584]]
 [[0.07554279]]
 [[0.17195134]]] {CpuDevice(id=0)}

vfdev-5 avatar Jul 09 '24 07:07 vfdev-5
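
As a minimal self-check on any given version, the jitted and non-jitted outputs of the repro can be compared directly (a sketch reusing the original reproduction):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.key(0), (3, 1, 1))

def f(x):
    return jax.nn.softmax(x, axis=0)

# This should hold on fixed versions (e.g. 0.4.30 per the report above).
assert jnp.allclose(f(x), jax.jit(f)(x), atol=1e-6)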