Wrong values in tensors for JAX function executed in C++ as xla::HloModule

Open lucagrementieri opened this issue 2 years ago • 4 comments

I am following the guide published at https://github.com/google/jax/issues/5337 to run a JAX program in C++.

The published example (https://github.com/google/jax/tree/main/examples/jax_cpp) works as expected, but if the exported function defines a new tensor with more than 10 elements, I get gibberish numbers.

I have added the following test functions to prog.py.

def return_10_almost_zeros(x, y):
    z = jnp.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    return z

def return_11_zeros(x, y):
    z = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    return z

def return_11_ones(x, y):
    z = jnp.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    return z

def return_11_almost_zeros(x, y):
    z = jnp.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    return z

The functions take two unused arguments only to reduce the number of modifications needed in main.cc; they can also be defined with no arguments, and the results are the same.

The JAX functions are exported using the prescribed command

python3 jax/tools/jax_to_hlo.py \
   --fn examples.jax_cpp.prog.return_11_almost_zeros \
   --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
   --hlo_text_dest /tmp/fn_hlo.txt \
   --hlo_proto_dest /tmp/fn_hlo.pb

and then main.cc is compiled and executed using bazel run -c opt examples/jax_cpp:main.

The outputs are

// Exported return_10_almost_zeros
result = (
    f32[10] {0, 1, 0, 0, 0, 0, 0, 0, 0, 0}
)

// Exported return_11_zeros
result = (
    f32[11] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)

// Exported return_11_ones
result = (
    f32[11] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
)

// Exported return_11_almost_zeros
result = (
    f32[11] {0, 1.4013e-45, 2.8026e-45, 4.2039e-45, 5.60519e-45, 7.00649e-45, 8.40779e-45, 9.80909e-45, 1.12104e-44, 1.26117e-44, 1.4013e-44}
)

As you can see from the outputs, creating a tensor with 11 elements that are not all equal produces a tensor of linearly increasing numbers equivalent to 1.4013e-45 * jnp.arange(11). The output of return_11_almost_zeros is the same regardless of the values defined in the tensor z.

@zhangqiaorjc maybe you can explain to me what is happening here, since you wrote the example.
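One observation about these numbers: 1.4013e-45 is the smallest positive subnormal float32 (2^-149), so the reported values are exactly the float32 numbers whose raw bit patterns are the integers 0 through 10. The sketch below reproduces them with the standard library only. The helper name float32_from_bits is made up for illustration, and the sketch says nothing about why the buffer ends up with these contents; it only shows that the output is consistent with consecutive integers being reinterpreted as floats.

```python
import struct


def float32_from_bits(bits):
    """Reinterpret a 32-bit unsigned integer as an IEEE 754 float32.

    Hypothetical helper for illustration; not part of JAX or XLA.
    """
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Bit patterns 0, 1, ..., 10 read back as float32 values.
values = [float32_from_bits(i) for i in range(11)]

for i, v in enumerate(values):
    print(i, v)
```

Each printed value is i * 2^-149, which matches the 11 entries of the f32[11] result up to the rounding used when the C++ program prints them.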

lucagrementieri avatar May 05 '22 14:05 lucagrementieri

The export of return_11_almost_zeros generates the textual IR

ENTRY main.5 {
  Arg_0.1 = f32[2,2]{1,0} parameter(0)
  Arg_1.2 = f32[2,2]{1,0} parameter(1)
  constant.3 = f32[11]{0} constant({...})
  ROOT tuple.4 = (f32[11]{0}) tuple(constant.3)
}

The values of constant.3 are elided as {...}, so it's clear that a program loading this text cannot know them.

Reading the source code of https://github.com/google/jax/blob/main/examples/jax_cpp/main.cc, I discovered that the textual representation is there only for debugging purposes, so I tried the binary pb output, and the results are the expected ones.

Therefore, I suppose the HLO textual representation should never be used in a production environment.
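As a guard against this pitfall, the HLO text can be scanned for the elision marker before it is loaded anywhere. The following is a minimal sketch, not part of the JAX tooling: the function name find_elided_constants and the regular expression are assumptions based only on the constant({...}) form shown in the IR above, and the pattern handles simple array shapes only.

```python
import re

# Matches instructions such as:
#   constant.3 = f32[11]{0} constant({...})
# Assumption: the shape contains no spaces (true for plain array shapes).
ELIDED_CONSTANT = re.compile(
    r"^\s*(?:ROOT\s+)?(\S+)\s*=\s*(\S+)\s+constant\(\{\.\.\.\}\)",
    re.MULTILINE,
)


def find_elided_constants(hlo_text):
    """Return (name, shape) pairs for constants printed as `{...}`."""
    return [(m.group(1), m.group(2)) for m in ELIDED_CONSTANT.finditer(hlo_text)]


hlo_text = """ENTRY main.5 {
  Arg_0.1 = f32[2,2]{1,0} parameter(0)
  Arg_1.2 = f32[2,2]{1,0} parameter(1)
  constant.3 = f32[11]{0} constant({...})
  ROOT tuple.4 = (f32[11]{0}) tuple(constant.3)
}"""

elided = find_elided_constants(hlo_text)
if elided:
    print("HLO text has elided constants; use the binary proto instead:", elided)
```

If the list is non-empty, the text file does not carry the full program, and the file written by --hlo_proto_dest is the one to load.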

lucagrementieri avatar May 23 '22 09:05 lucagrementieri

This is mostly about the code that generates the HLO. By default, the textual representation of HLO elides large constants. It can be changed not to do that: it's an option when printing the HLO as a string.

I'm not sure whether that's appropriate in this particular case or not: here the option is explicitly labeled as debug output. But perhaps it would be less confusing if it didn't elide constants. What do you think?

hawkinsp avatar May 23 '22 12:05 hawkinsp

Yes, I think that https://github.com/google/jax/blob/main/jax/tools/jax_to_ir.py#L133 should be modified to print large constants, so that it returns a fully functional textual IR.

lucagrementieri avatar May 23 '22 12:05 lucagrementieri

@lucagrementieri may I ask how to extract only the values from the tupled result tensor?

Rian-Jo avatar Apr 19 '24 09:04 Rian-Jo