Binary op compare with different element types in sharded, jitted function call
Description
I have been struggling with this bug for a very long time, and while I have failed to reproduce it in a trivial example, I still want to report it and provide as much detail as possible here. Maybe someone has an idea just based on the context and error message.
I am in the process of adapting a large simulation codebase to work with multiple GPUs and sharded data arrays. It contains a function that does something similar to distributing particle positions of shape (2097152, 1, 3) onto a (256, 256, 256) mesh, into each particle's own cell and the neighbouring ones. This is a small part of a large simulation code that is wrapped in a jit.
Without sharding this gives correct results, and removing the jit also makes the function work with sharding (but of course incredibly slowly). Calling just this part of the code jitted on sharded data also works.
But the combination of sharded data and jitting the whole codebase makes compilation fail with the following error:
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: during context [hlo verifier]: Binary op compare with different element types: s64[] and s32[].
, for instruction %compare.162 = pred[] compare(s64[] %select.102, s32[] %multiply.67), direction=GE, metadata={op_name="jit(step_fct_with_args)/jit(main)/while/body/dynamic_update_slice" source_file="/home/fs72085/lwinkler/DISCO-DJ/src/discodj/core/painting.py" source_line=228}
I am struggling to debug this, as in theory no 64-bit object should show up anywhere in the simulation (jax_enable_x64 is not set). Also, I don't see any mention of "Binary op compare with different element types" in other jax/XLA bug reports, so it doesn't seem to be a common issue.
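(For context, a quick sanity check that without jax_enable_x64 everything defaults to 32-bit types; a trivial sketch, not from the actual simulation:)

import jax.numpy as jnp

print(jnp.arange(3).dtype)  # int32
print(jnp.ones(3).dtype)    # float32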
painting.py:228 contains a scan over another function, so I assume the compilation of the inner function fails.
return np.reshape(jax.lax.scan(loop_body,
                               mesh,  # initial carry
                               (np.reshape(positions, (n_chunks, -1, 1, dim)),  # xs[0]: positions
                                np.split(weight, n_chunks, axis=0)
                                if weights_provided else np.ones(n_chunks, dtype=dtype),  # xs[1]: weights
                                np.arange(n_chunks)),  # xs[2]: chunk indices
                               )[1], -1)  # return ys, reshaped from n_chunk x chunk_size to npart_tot
Based on the dynamic_update_slice and the fact that this only happens when one specific part of the code is included, I am pretty sure the error boils down to this line of code inside loop_body:
return carry, (carry[tuple(split)][..., 0] * kernel).sum(axis=-1)
Still, none of the variables here are of a 64-bit dtype, so I am really unsure what causes this (apart from possibly the usage of tuple, which I don't know how to avoid here):
dtype = np.int32
split = [np.zeros((65536, 8, 1), dtype=dtype)] * 3
╭────── split[0] ──────╮
│ shape: (65536, 8, 1) │
│ dtype: int32 │
│ size: 2.0 MiB │
│ called in jit │
│ NamedSharding: P() │
╰──────────────────────╯
╭────── kernel ──────╮
│ shape: (65536, 8) │
│ dtype: float32 │
│ size: 2.0 MiB │
│ called in jit │
│ NamedSharding: P() │
╰────────────────────╯
╭───────────────────── carry ─────────────────────╮
│ shape: (256, 256, 256) │
│ dtype: float32 │
│ size: 64.0 MiB │
│ called in jit │
│ NamedSharding: P(None, 'gpus') │
│ axis 1 is sharded: GPU 0 contains 0:32 (of 256) │
╰─────────────────────────────────────────────────╯
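For reference, here is how the shapes from the boxes above compose in that line; a minimal sketch with dummy stand-in data (loop_body and the real kernel values are not reproduced), which runs fine in isolation:

import jax.numpy as jnp

# dummy arrays matching the shapes/dtypes shown above
carry = jnp.zeros((256, 256, 256), dtype=jnp.float32)
split = [jnp.zeros((65536, 8, 1), dtype=jnp.int32)] * 3
kernel = jnp.ones((65536, 8), dtype=jnp.float32)

# advanced indexing with three (65536, 8, 1) index arrays gives (65536, 8, 1)
out = (carry[tuple(split)][..., 0] * kernel).sum(axis=-1)
print(out.shape)  # (65536,)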
Also, as a slightly related question: is there any way to read the compiler output when an HLO verifier error occurs? It would be interesting to see the lines before and after the part that fails. Even the output of XLA_FLAGS=--xla_dump_to= doesn't contain the line that is mentioned in the error message:
%compare.162 = pred[] compare(s64[] %select.102, s32[] %multiply.67)
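(For reference, this is how I requested the dumps; a sketch with a hypothetical dump directory. The per-pass flag exists, but the module that fails verification still never showed up, presumably because compilation aborts first:)

import os
# must be set before jax initializes its backends
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_pass_re=.*'
import jax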
Sorry that this is a bit vague, but I feel like this might be an XLA compiler issue that only occurs under these specific circumstances, so I don't know how to break it down into reproducible code I can share. But maybe someone has an idea how to fix this.
What jax/jaxlib version are you using?
0.4.23 (cuda11_pip) Also tested with jax==0.4.24 and jaxlib==0.4.24.dev20240206+cuda11.cudnn86 with identical output
Which accelerator(s) are you using?
GPU (8x NVIDIA A40 on 4 hosts)
Additional system info?
Python 3.11.3, Linux
NVIDIA GPU info
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A40 Off | 00000000:41:00.0 Off | 0 |
| 0% 37C P0 74W / 300W | 4MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA A40 Off | 00000000:A1:00.0 Off | 0 |
| 0% 35C P0 74W / 300W | 4MiB / 46068MiB | 4% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
Hi - sorry it's hard to say much here without a more complete reproduction. Given that the error comes from dynamic_update_slice, I'd like to see where that operation is coming from (the line you pasted, carry[tuple(split)][..., 0], will lower to gather, not to dynamic_update_slice).
Hi @Findus23 and @jakevdp.
I ran into this same bug and was able to put together the following MRE. It's for CPUs because I don't have easy access to multiple GPUs, but maybe @Findus23 can check that the code errors out on GPUs too?
This is with jax and jaxlib version 0.4.26. I've confirmed the error on Linux, Windows, and WSL. Also of note is that I only get the error when setting jax_enable_x64 to True. Not sure how to square that with the above error where it wasn't enabled.
# use two devices
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
# use 64-bit on CPU
import jax
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')
# configure a simple sharding
from jax.sharding import Mesh, PartitionSpec, NamedSharding
mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))
def f(y):
    return y - jax.lax.map(g, y)

def g(y):
    return y
x = jax.numpy.ones(2)
print(f(x)) # [0. 0.]
print(jax.jit(f)(x)) # [0. 0.]
print(f(jax.device_put(x, sharding))) # [0. 0.]
print(jax.jit(f)(jax.device_put(x, sharding))) # error
The error seems the same as the above one.
Traceback (most recent call last):
File "sharding_bug_mre.py", line 25, in <module>
print(jax.jit(f)(jax.device_put(x, sharding))) # error
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: during context [hlo verifier]: Binary op compare with different element types: s64[] and s32[].
, for instruction %compare.3 = pred[] compare(s64[] %select.1, s32[] %multiply), direction=GE, metadata={op_name="jit(f)/jit(main)/while/body/dynamic_update_slice" source_file="sharding_bug_mre.py" source_line=16}
Failed after spmd-partitioning
Hi,
Thanks to @jeffgortmaker for coming up with the reproducible issue (I tested it with 2 NVIDIA A40 GPUs and can reproduce it there as well). Originally I was unsure whether you had simply found another bug that happens to produce the same error; after all, the example doesn't even contain integer data, while the part where things break in my code does some explicit integer dtype changes. But (take everything from here on with a grain of salt, as I don't have much time for testing right now) maybe I misunderstood the original issue and made some mistakes when debugging and explicitly setting dtypes everywhere.
It turns out that when I run my code with single precision everywhere and without jax_enable_x64, it succeeds without this issue. With jax_enable_x64, one would expect that explicitly making all integers np.int64 would then also work, but it doesn't. So my theory is that my bug is mostly unrelated to my actual code and is indeed exactly the thing @jeffgortmaker discovered: that the combination of jax_enable_x64, jit, some map, and sharding breaks inside the implicit resharding.
This would also explain the great comment by @jakevdp: my code doesn't do a dynamic_update_slice, but the sharding code created by jit does. And it maybe also explains why the error message points to the lax.scan/lax.map line instead of the actual code.
So I guess the workaround, until the jit-compiler bug from the example is found, is just using single precision everywhere, or explicitly setting the dtype of every single np array created in the code to np.int64/np.float64 instead of using jax_enable_x64 (https://github.com/google/jax/issues/8178 style).
Thanks @Findus23 for confirming on GPUs and checking on jax_enable_x64!
For what it's worth, in the above example adding lots of .astype(jnp.float64) calls didn't seem to eliminate the error for me. So for code needing 64-bit precision, I'm not sure there's an unblocking workaround yet, other than just not doing sharding in this way.
@jeffgortmaker I mistakenly thought one could still use explicit 64-bit types without jax_enable_x64 and that the flag would just change the default. But as https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision explains, that doesn't work. So you are right: there doesn't seem to be any workaround for this issue (apart from not using sharding, not using 64-bit anywhere, or potentially not using scan/map this way).
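To spell out the gotcha (a minimal sketch of what I got wrong):

import jax.numpy as jnp

# without jax_enable_x64, an explicitly requested 64-bit dtype is
# silently truncated to 32 bits (jax emits a UserWarning)
x = jnp.array([1, 2, 3], dtype=jnp.int64)
print(x.dtype)  # int32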
Hi,
has anybody found a workaround? I am sharding my arrays across multiple GPUs and then applying jax.lax.map, but this seems to break when double precision is enabled.
I was working on another project for a while, but now I am back to the original project and can confirm that this issue still exists in the latest jax version and is still reproducible with the code from above (https://github.com/jax-ml/jax/issues/19691#issuecomment-2181170116).
Some additional thoughts:
- It seems like this bug is completely independent of the dtype of the array passed to f: it happens with both 32- and 64-bit floats/ints.
- return y - jax.lax.scan(lambda c, x: (c, x), 0, y)[1] also triggers the same issue (see the sketch after this list).
- Without the y -, everything works in this example.
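Spelled out as a complete script, the scan variant from the list above looks like this (a sketch analogous to the earlier MRE, assuming two CPU devices):

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2'
import jax
jax.config.update('jax_enable_x64', True)
from jax.sharding import Mesh, PartitionSpec, NamedSharding

mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))

def f(y):
    # subtracting the scan output from the input is what triggers the error;
    # returning only the scan output compiles fine
    return y - jax.lax.scan(lambda c, x: (c, x), 0, y)[1]

x = jax.device_put(jax.numpy.ones(2), sharding)
jax.jit(f)(x)  # XlaRuntimeError: Binary op compare with different element types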
As the warning relates to the HLO, I thought I would dive deeper into it to see whether that can explain this issue further (hoping that the HLO output of tracing and lowering would turn out to be incorrect and could be fixed).
I wrote a very hacky function that writes the StableHLO output after lowering to a file and reads a modified version back in (by using export.export and replacing mlir_module_serialized).
The output of lowered = jax.jit(f).lower(abstract_input) with float64 input:
input: ShapeDtypeStruct(shape=(128,), dtype=float64, sharding=NamedSharding(mesh=Mesh('a': 4), spec=PartitionSpec('a',), memory_kind=unpinned_host))
module @jit_f attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<128xf64> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<128xf64> {jax.result_info = ""}) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f64>) -> tensor<128xf64>
    %c = stablehlo.constant dense<0> : tensor<i64>
    %1:3 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %c, %iterArg_1 = %0) : tensor<128xf64>, tensor<i64>, tensor<128xf64>
    cond {
      %c_2 = stablehlo.constant dense<128> : tensor<i64>
      %3 = stablehlo.compare LT, %iterArg_0, %c_2, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %3 : tensor<i1>
    } do {
      %c_2 = stablehlo.constant dense<0> : tensor<i64>
      %3 = stablehlo.compare LT, %iterArg_0, %c_2, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
      %4 = stablehlo.convert %iterArg_0 : tensor<i64>
      %c_3 = stablehlo.constant dense<128> : tensor<i64>
      %5 = stablehlo.add %4, %c_3 : tensor<i64>
      %6 = stablehlo.select %3, %5, %iterArg_0 : tensor<i1>, tensor<i64>
      %7 = stablehlo.dynamic_slice %iterArg, %6, sizes = [1] : (tensor<128xf64>, tensor<i64>) -> tensor<1xf64>
      %8 = stablehlo.reshape %7 : (tensor<1xf64>) -> tensor<f64>
      %9 = func.call @None(%8) : (tensor<f64>) -> tensor<f64>
      %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor<f64>) -> tensor<1xf64>
      %11 = stablehlo.compare LT, %iterArg_0, %c_2, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
      %12 = stablehlo.convert %iterArg_0 : tensor<i64>
      %13 = stablehlo.add %12, %c_3 : tensor<i64>
      %14 = stablehlo.select %11, %13, %iterArg_0 : tensor<i1>, tensor<i64>
      %15 = stablehlo.dynamic_update_slice %iterArg_1, %10, %14 : (tensor<128xf64>, tensor<1xf64>, tensor<i64>) -> tensor<128xf64>
      %c_4 = stablehlo.constant dense<1> : tensor<i64>
      %16 = stablehlo.add %iterArg_0, %c_4 : tensor<i64>
      stablehlo.return %iterArg, %16, %15 : tensor<128xf64>, tensor<i64>, tensor<128xf64>
    }
    %2 = stablehlo.subtract %arg0, %1#2 : tensor<128xf64>
    return %2 : tensor<128xf64>
  }
  func.func private @None(%arg0: tensor<f64>) -> tensor<f64> {
    return %arg0 : tensor<f64>
  }
}
My first guess was that the fact that num_partitions and num_replicas are i32 might be causing the issue. But switching out those types doesn't change anything.
But interestingly enough, replacing
%2 = stablehlo.subtract %arg0, %1#2 : tensor<128xf64>
return %2 : tensor<128xf64>
with either return %1#2 : tensor<128xf64>, or replacing stablehlo.subtract %arg0, %1#2 with stablehlo.subtract %1#2, %1#2, or with anything else that doesn't apply an operation to both the input and the output of the map, avoids the issue.
So I'm afraid the bug is somewhere further down the line and can't be seen at the StableHLO stage.
Okay, I think I found something useful. At least hopefully something useful to someone who knows jax and XLA far better than I do:
I was further simplifying the HLO to see how much I could remove while still triggering the issue. And it turns out the map/loop is not needed at all:
import dataclasses
import os
import jax
from jax.export import export
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jaxlib.mlir.dialects import stablehlo
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
jax.config.update('jax_enable_x64', True)
mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))
jitted_f = jax.jit(lambda a: a)
abstract_input = jax.ShapeDtypeStruct((128,), jax.numpy.float64, sharding=sharding)
exported = export(jitted_f)(abstract_input)
# mlir_module_fixed = exported.mlir_module_serialized
# context = jax_mlir.make_ir_context()
# out: Module = stablehlo.deserialize_portable_artifact(context, mlir_module_fixed)
# stablehlo_str = str(out)
# stablehlo_str = out.operation.get_asm(enable_debug_info=True)
stablehlo_str = """
module @jit_f attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%input_array: tensor<128xf64> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<128xf64> {jax.result_info = ""}) {
    %const_scalar = stablehlo.constant dense<1> : tensor<i64>
    %const_scalar2 = stablehlo.constant dense<2> : tensor<i64>
    %slice = stablehlo.dynamic_slice %input_array, %const_scalar, sizes = [1] : (tensor<128xf64>, tensor<i64>) -> tensor<1xf64>
    %output = stablehlo.dynamic_update_slice %input_array, %slice, %const_scalar2 : (tensor<128xf64>, tensor<1xf64>, tensor<i64>) -> tensor<128xf64>
    return %output : tensor<128xf64>
  }
}
""".strip()
mlir_module_fixed = stablehlo.serialize_portable_artifact_str(stablehlo_str, "1.8.8")
exported = dataclasses.replace(exported, mlir_module_serialized=mlir_module_fixed)
exported.call(abstract_input)
I can translate this back into a Python example that also triggers this bug:
import os
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
jax.config.update('jax_enable_x64', True)
mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))
def f(y):
    slice = jax.lax.dynamic_slice(y, [0], [1])
    out = jax.lax.dynamic_update_slice(y, slice, [0])
    return out
jitted_f = jax.jit(f)
abstract_input = jax.ShapeDtypeStruct((128,), jax.numpy.float64, sharding=sharding)
jitted_f.lower(abstract_input).compile()
And this led me to the fact that even the example on https://docs.jax.dev/en/latest/_autosummary/jax.lax.dynamic_update_slice.html triggers this bug, using only lax.dynamic_update_slice (with sharding, jax_enable_x64, and jit compilation):
import os
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
jax.config.update('jax_enable_x64', True)
mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))
def f(arr):
    y = jax.numpy.ones(3)
    return jax.lax.dynamic_update_slice(arr, y, (2,))
arr = jax.device_put(jax.numpy.zeros(64), sharding)
jax.jit(f).lower(arr).compile()
I compiled jax myself to see if I could figure out more this way: after commenting out the HLO verifier check, the issue (unsurprisingly) just moves to the VerifyLlvmModule function in XLA, and after commenting out that too, it simply fails when trying to compile the IR.
But this led me down the route of using --xla_dump_to, and I created https://github.com/openxla/xla/issues/24186 with more details (as the same happens when passing that HLO module to XLA without using jax at all).
Based on the thoughts in https://github.com/openxla/xla/issues/24186#issuecomment-2770504355 I think I found a (limited) workaround for this issue in jax (that even works without having to recompile anything).
The issue seems to be essentially that, with sharding, the XLA compiler transforms the third argument of jax.lax.dynamic_update_slice into something that is compared against the partition id. Unfortunately, the partition id is always an s32 (int32)[^1], which means that whenever we pass int64 slice indices, this comparison is not possible and the error occurs.
But this also means that we can work around the issue without any modifications to jaxlib/XLA by making sure to only pass int32 indices to dynamic_update_slice: e.g. replacing jax.lax.dynamic_update_slice(arr, y, (2,)) with jax.lax.dynamic_update_slice(arr, y, (jax.numpy.array(2, dtype=jax.numpy.int32),)).
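Applied to the last repro from above, only the dtype of the start index changes (a sketch; same hypothetical 4-device CPU setup):

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
jax.config.update('jax_enable_x64', True)
mesh = Mesh(jax.devices(), ('a',))
sharding = NamedSharding(mesh, PartitionSpec('a'))

def f(arr):
    y = jax.numpy.ones(3)
    # an explicit int32 start index keeps the comparison against the
    # (s32) partition id type-consistent
    start = (jax.numpy.array(2, dtype=jax.numpy.int32),)
    return jax.lax.dynamic_update_slice(arr, y, start)

arr = jax.device_put(jax.numpy.zeros(64), sharding)
jax.jit(f).lower(arr).compile()  # compiles without the verifier error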
It is a bit more complicated in the scan/map[^2] example found by @jeffgortmaker, as there dynamic_update_slice is called by jax itself, but if you modify jax/_src/lax/control_flow/loops.py to change the dtype, it works:
diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py
index 9a66dd0..465ff7d 100644
--- a/jax/_src/lax/control_flow/loops.py
+++ b/jax/_src/lax/control_flow/loops.py
@@ -522,6 +522,7 @@ def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,
                                              allow_negative_indices=False)
            for xs in xss]
     carry, ys = inner(unroll, carry, xs)
+    i = lax.convert_element_type(i, "int32")
     yss = [slicing.dynamic_update_index_in_dim(y, upd, i, 0,
                                                allow_negative_indices=False)
            for y, upd in zip(yss, ys)]
Unfortunately, I am pretty sure that this will also produce quite wrong results if the length of the data being looped over is larger than what fits in an int32 (2147483647).
The more proper fix is probably to modify XLA so that the partition index is converted to a 64-bit integer before the comparison.
[^1]: Technically it is a u32 that is used to look up the index in an array of s32.
[^2]: map just calls scan.
FYI: In https://github.com/openxla/xla/issues/24186#issuecomment-2773228073 I have now found an XLA fix that I think should not have any side effects (but I make no promises).
Feel free to test it in your environments (I can share a jaxlib compiled with those fixes if you don't want to compile it yourself).
@dfm Did you make any progress on this issue by any chance? I am currently also hitting this problem, at a point where I can't switch to fp32 due to working with badly conditioned matrices.