Wrong values in tensors for JAX function executed in C++ as xla::HloModule
I am following the guide published at https://github.com/google/jax/issues/5337 to run a JAX program in C++.
The published example (https://github.com/google/jax/tree/main/examples/jax_cpp) works as expected, but if in the exported function I define a new tensor whose size exceeds 10, then I get gibberish numbers.
I have added the following test functions to prog.py:

import jax.numpy as jnp

def return_10_almost_zeros(x, y):
    z = jnp.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    return z

def return_11_zeros(x, y):
    z = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    return z

def return_11_ones(x, y):
    z = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    return z

def return_11_almost_zeros(x, y):
    z = jnp.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    return z
The functions take two unused arguments only to reduce the number of modifications needed in main.cc, but they can be defined with no arguments and the results are the same.
The JAX functions are exported using the prescribed command
python3 jax/tools/jax_to_hlo.py \
--fn examples.jax_cpp.prog.return_11_almost_zeros \
--input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
--hlo_text_dest /tmp/fn_hlo.txt \
--hlo_proto_dest /tmp/fn_hlo.pb
and then main.cc is compiled and executed with bazel run -c opt examples/jax_cpp:main.
The outputs are
// Exported return_10_almost_zeros
result = (
f32[10] {0, 1, 0, 0, 0, 0, 0, 0, 0, 0}
)
// Exported return_11_zeros
result = (
f32[11] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
// Exported return_11_ones
result = (
f32[11] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
)
// Exported return_11_almost_zeros
result = (
f32[11] {0, 1.4013e-45, 2.8026e-45, 4.2039e-45, 5.60519e-45, 7.00649e-45, 8.40779e-45, 9.80909e-45, 1.12104e-44, 1.26117e-44, 1.4013e-44}
)
As you can see from the outputs, creating a constant tensor with 11 elements whose values are not all identical yields a tensor of linearly increasing numbers equal to 1.4013e-45 * jnp.arange(11). The output tensor of return_11_almost_zeros is the same regardless of the tensor z defined in the function.
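The pattern in the garbage output is recognizable: 1.4013e-45 is the smallest positive denormal float32 (2**-149), so the eleven values are exactly the integers 0 through 10 reinterpreted as float32 bit patterns. A quick stdlib check (my own diagnostic, not part of the example):

```python
import struct

# Reinterpret the int32 bit patterns 0..10 as little-endian float32 values.
# Bit pattern i decodes to the denormal i * 2**-149, which matches the
# garbage output of return_11_almost_zeros above.
values = [struct.unpack('<f', struct.pack('<i', i))[0] for i in range(11)]
assert values[1] == 2 ** -149        # ~1.4013e-45
assert values[10] == 10 * 2 ** -149  # ~1.4013e-44
```

This suggests the constant's real payload is missing and is being replaced by placeholder index data somewhere in the pipeline.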
@zhangqiaorjc maybe you can explain to me what is happening here, since you wrote the example.
The export of return_11_almost_zeros generates the textual IR
ENTRY main.5 {
Arg_0.1 = f32[2,2]{1,0} parameter(0)
Arg_1.2 = f32[2,2]{1,0} parameter(1)
constant.3 = f32[11]{0} constant({...})
ROOT tuple.4 = (f32[11]{0}) tuple(constant.3)
}
so it's clear that the program cannot know the values in constant.3.
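For comparison, if the exporter did not elide large constants, the same entry computation would presumably carry the full payload of return_11_almost_zeros, roughly like this (a hand-written sketch, not actual tool output):

```
ENTRY main.5 {
  Arg_0.1 = f32[2,2]{1,0} parameter(0)
  Arg_1.2 = f32[2,2]{1,0} parameter(1)
  constant.3 = f32[11]{0} constant({0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0})
  ROOT tuple.4 = (f32[11]{0}) tuple(constant.3)
}
```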
Reading the source code of https://github.com/google/jax/blob/main/examples/jax_cpp/main.cc, I discovered that the textual representation is there only for debugging purposes, so I tried the binary pb output instead, and the results are the expected ones. Therefore, I suppose the textual HLO representation should never be used in a production environment.
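More generally, a debug-oriented text dump may silently drop data that a binary serialization preserves. As a loose stdlib analogy (this illustrates the pitfall, not XLA's actual mechanism):

```python
import json
import reprlib

data = [0.0] * 11
data[1] = 1.0

# reprlib truncates long containers in its output (6 list items by
# default), so the debug text no longer contains every element.
assert '...' in reprlib.repr(data)

# A real serialization round-trips every value exactly.
assert json.loads(json.dumps(data)) == data
```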
This is mostly about the code that generates the HLO. By default, the textual representation of HLO elides large constants. It can be changed not to do that: it's an option when printing the HLO as a string.
I'm not sure whether that's appropriate in this particular case or not: here the option is explicitly labeled as debug output. But perhaps it would be less confusing if it didn't elide constants. What do you think?
Yes, I think that https://github.com/google/jax/blob/main/jax/tools/jax_to_ir.py#L133 should be modified to print large constants, producing a fully functional textual IR.
@lucagrementieri let me ask: how can I extract just the values from the tupled result tensor?