[Pallas TPU] Error when negating a boolean value

Open · ayaka14732 opened this issue on Oct 11 '24

Description

import functools
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((2,), jnp.bool_),
)
def kernel(x_ref, o_ref):
    o_ref[...] = jnp.logical_not(x_ref[...])

def main() -> None:
    x = jnp.array([False, True], dtype=jnp.bool_)
    out = kernel(x)
    print(out)

if __name__ == '__main__':
    main()

Error:

Traceback (most recent call last):
  File "/home/ayx/jax/2.py", line 19, in <module>
    main()
    ~~~~^^
  File "/home/ayx/jax/2.py", line 15, in main
    out = kernel(x)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Can't change bitwidth during a relayout

at location: loc(unknown)


Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35.dev20241010+3bd8ca480
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.13.0 (main, Oct  8 2024, 01:04:00) [Clang 18.1.8 ]
device info: TPU v5 lite-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')
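The environment block above has the format printed by jax.print_environment_info(). A minimal sketch for regenerating it when filing a similar report follows. The exact fields vary between JAX versions (newer releases, for example, list devices differently), so treat the output layout as version-dependent.

```python
import jax

# return_string=True returns the report as a string instead of printing it.
info = jax.print_environment_info(return_string=True)
print(info)
```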

ayaka14732, Oct 11 '24 00:10

This is caused by Mosaic not supporting i1 (boolean) splats. Logical not is lowered to xor(ones, x), but Mosaic fails to materialize the all-ones array when the element type is boolean.
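For booleans, xor with an all-True array flips every element, which is why xor(ones, x) is a valid lowering of logical not. The sketch below checks that identity in plain jax.numpy on the reproducer's input. It runs outside any Pallas kernel, so it does not exercise Mosaic and works on any backend; it only illustrates the math, not the failing lowering.

```python
import jax.numpy as jnp

x = jnp.array([False, True], dtype=jnp.bool_)

# The "ones" operand: for a boolean dtype this is an all-True array.
# Per the diagnosis above, building this constant inside a TPU kernel
# is the i1 splat that Mosaic rejects.
ones = jnp.ones_like(x)

via_not = jnp.logical_not(x)
via_xor = jnp.logical_xor(ones, x)

# Both formulations agree element-wise.
assert (via_not == via_xor).all()
print(via_xor.tolist())
```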

This should be fixed in Mosaic, not in Pallas.
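Until a Mosaic-side fix is available, one possible user-side workaround is to keep booleans out of the kernel entirely: cast to int32 before the call, flip the values with 1 - x, and cast back afterwards. This is a sketch under assumptions, not a confirmed fix: the helper names (_not_kernel_i32, logical_not_via_int32, INTERPRET) are hypothetical, it has not been verified on a TPU, and it assumes Mosaic lowers int32 splats correctly. When no TPU is present it falls back to the Pallas interpreter so it can also run on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Hypothetical switch: use the Pallas interpreter when no TPU is available.
INTERPRET = jax.default_backend() != "tpu"

def _not_kernel_i32(x_ref, o_ref):
    # Inputs are 0/1 int32 values, so 1 - x flips them without an i1 constant.
    o_ref[...] = 1 - x_ref[...]

def logical_not_via_int32(x):
    """Hypothetical helper: logical NOT of a bool array via an int32 kernel."""
    x_i32 = x.astype(jnp.int32)
    out = pl.pallas_call(
        _not_kernel_i32,
        out_shape=jax.ShapeDtypeStruct(x_i32.shape, jnp.int32),
        interpret=INTERPRET,
    )(x_i32)
    # Convert back to bool outside the kernel.
    return out.astype(jnp.bool_)

x = jnp.array([False, True], dtype=jnp.bool_)
print(logical_not_via_int32(x).tolist())
```

Doing the casts outside the kernel means Mosaic only ever sees int32 values; the cost is two extra conversion ops in the surrounding XLA program.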

justinjfu, Oct 17 '24 18:10

Reassigned internally.

ayaka14732, Oct 29 '24 00:10

Caused by https://github.com/jax-ml/jax/issues/24464

ayaka14732, Oct 29 '24 00:10