jax
Logical negation (~) of a boolean lax.dot_general result is wrong on GPU (but correct on CPU/TPU)
Example:
```python
from jax import lax
import jax.numpy as jnp

a = lax.dot_general(
    lhs=jnp.array([[True],
                   [True]]),                      # shape (2, 1)
    rhs=jnp.array([[True, True, True],
                   [True, True, True]]),          # shape (2, 3)
    dimension_numbers=(((0,), (0,)), ((), ())),   # contract axis 0 of both; no batch dims
)
print(a)   # DeviceArray([[ True, True, True]], dtype=bool)
print(~a)  # On GPU: DeviceArray([[ True, True, True]], dtype=bool)
# Should be DeviceArray([[False, False, False]], dtype=bool), which it is on CPU or TPU.
```
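A minimal sketch for checking whether a given install and backend reproduces the problem, assuming a recent JAX and NumPy are available. It reruns the reproducer and compares ~a with a NumPy reference computed as an OR over elementwise ANDs. The helper name negated_bool_dot_general is hypothetical, not part of JAX. For this all-True input every product is True, so both a and the reference are all True and ~a should be all False.

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jax import lax


def negated_bool_dot_general():
    """Hypothetical helper: rerun the reproducer and compare ~a to a NumPy reference."""
    lhs = jnp.array([[True], [True]])                  # shape (2, 1)
    rhs = jnp.array([[True, True, True],
                     [True, True, True]])             # shape (2, 3)
    dims = (((0,), (0,)), ((), ()))                   # contract axis 0 of both; no batch dims
    a = lax.dot_general(lhs, rhs, dimension_numbers=dims)

    # NumPy reference: OR over the contracted axis of the elementwise ANDs.
    lhs_np, rhs_np = onp.asarray(lhs), onp.asarray(rhs)
    ref = onp.any(lhs_np[:, :, None] & rhs_np[:, None, :], axis=0)  # shape (1, 3)

    negated = onp.asarray(~a)
    ok = onp.array_equal(negated, ~ref)
    print(f"backend={jax.default_backend()} ~a={negated.tolist()} matches_reference={ok}")
    return negated, ok


negated, ok = negated_bool_dot_general()
```

Per the report, an affected GPU build would print matches_reference=False, while CPU, TPU, or a build with the XLA fix should print True.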
I think this might be related to https://github.com/google/jax/pull/5137
This appears to be an XLA miscompilation.
(For tracking purposes, this is XLA bug b/177524741.)
This is fixed and has been for a while!