Logical negation ~lax.dot_general fails on GPU (but not CPU/TPU)

Open • romanngg opened this issue 3 years ago • 1 comment

Example:

from jax import lax
from jax import numpy as jnp

# Contract dimension 0 of both operands, with no batch dimensions:
# lhs has shape (2, 1) and rhs has shape (2, 3), so the result has shape (1, 3).
a = lax.dot_general(lhs=jnp.array([[True],
                                   [True]]),
                    rhs=jnp.array([[True,  True,  True],
                                   [True,  True,  True]]),
                    dimension_numbers=(((0,), (0,)), ((), ())))
print(repr(a))
# DeviceArray([[ True,  True,  True]], dtype=bool)
print(repr(~a))
# On GPU:    DeviceArray([[ True,  True,  True]], dtype=bool)
# Should be: DeviceArray([[False, False, False]], dtype=bool), which it is on CPU or TPU
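A backend-independent way to catch this kind of bug is to compare the JAX result with a NumPy reference, since NumPy's `matmul` on bool arrays combines terms with logical AND/OR. The sketch below does that for the same inputs; the helper name `negated_bool_contraction` is hypothetical and only for illustration, and it assumes a JAX version in which `lax.dot_general` accepts bool operands.

```python
import numpy as onp
import jax.numpy as jnp
from jax import lax


def negated_bool_contraction(lhs, rhs):
    # Hypothetical helper: the expression from the report, i.e. contract
    # dimension 0 of both operands (no batch dims), then logically negate.
    out = lax.dot_general(jnp.asarray(lhs), jnp.asarray(rhs),
                          dimension_numbers=(((0,), (0,)), ((), ())))
    return ~out


lhs = onp.array([[True], [True]])
rhs = onp.ones((2, 3), dtype=bool)

# NumPy reference: contracting axis 0 of both operands is lhs.T @ rhs.
# For bool arrays, NumPy's matmul uses logical AND for products and OR for sums.
reference = ~(lhs.T @ rhs)

# Runs on JAX's default backend (CPU, GPU, or TPU).
result = onp.asarray(negated_bool_contraction(lhs, rhs))
print(result)

assert result.dtype == bool and result.shape == (1, 3)
assert (result == reference).all(), f"mismatch: {result} vs {reference}"
```

Running this on each available backend turns the report into a regression check: a miscompiled GPU build would fail the final assertion instead of silently returning `True` values.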

I think this might be related to https://github.com/google/jax/pull/5137

romanngg • Jan 14 '21 03:01

This appears to be an XLA miscompilation.

(For tracking purposes, this is XLA bug b/177524741.)
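To tell a JAX lowering problem apart from an XLA miscompilation, one option is to compare the program JAX hands to XLA with the program XLA produces. The sketch below assumes a recent JAX release that provides the ahead-of-time API (`jax.jit(...).lower(...)`, `Lowered.as_text()`, `Lowered.compile()`, `Compiled.as_text()`); the function name is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def negated_bool_contraction(lhs, rhs):
    # Hypothetical name; same computation as the report:
    # contract dimension 0 of both operands, then logically negate.
    out = lax.dot_general(lhs, rhs, dimension_numbers=(((0,), (0,)), ((), ())))
    return ~out


lhs = jnp.array([[True], [True]])
rhs = jnp.ones((2, 3), dtype=bool)

lowered = jax.jit(negated_bool_contraction).lower(lhs, rhs)
# The program JAX emits, before any XLA optimization passes.
print(lowered.as_text())

# The program after XLA compilation for the default backend. If the lowered
# text looks correct but results are wrong, the bug is likely in this stage.
compiled = lowered.compile()
print(compiled.as_text())

# Execute the compiled program on the same inputs.
out = compiled(lhs, rhs)
print(out)
```

If the lowered program is a plain `dot_general` followed by a `not`, yet the GPU result is still wrong, that points at XLA rather than JAX, which is consistent with the diagnosis above.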

hawkinsp • Jan 14 '21 14:01

This is fixed and has been for a while!

hawkinsp • Aug 12 '22 19:08