Repository: jax
Pallas "hello world" kernel fails to compile on TPU for int8 inputs (Mosaic: unsupported layout change)
Description
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_vectors_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel: store the element-wise sum of two input refs in ``o_ref``.

    Both operands are loaded in full, widened to int32, and added. The int32
    sum is then cast to ``o_ref``'s dtype when it is written back.

    Args:
        x_ref: Ref holding the first operand.
        y_ref: Ref holding the second operand.
        o_ref: Ref that receives the result; its dtype selects the stored type.

    NOTE(review): with int8 inputs of shape (8,) on TPU v3 (jax/jaxlib
    0.4.31), Mosaic rejects this kernel with "Unsupported layout change for
    'vector<8xi8>'". This is presumably triggered by the int8 -> int32
    widening below -- TODO confirm against the Pallas TPU layout docs.
    """
    # Load both operands before converting, mirroring the original read order.
    x_val = x_ref[...]
    y_val = y_ref[...]
    # Widen first so the addition itself is carried out in int32.
    widened_sum = x_val.astype(jnp.int32) + y_val.astype(jnp.int32)
    o_ref[...] = widened_sum.astype(o_ref.dtype)
@partial(jax.jit, static_argnames=("out_type",))
def add_vectors(x: jax.Array, y: jax.Array, out_type) -> jax.Array:
    """Add ``x`` and ``y`` element-wise using ``add_vectors_kernel``.

    Args:
        x: First operand. Its shape fixes the output shape.
        y: Second operand. Presumably it must match ``x``'s shape -- not
            checked here.
        out_type: dtype of the returned array. It is a static jit argument,
            so it must be hashable and each distinct value compiles
            separately.

    Returns:
        An array with ``x.shape`` and dtype ``out_type``, holding ``x + y``.
    """
    result_spec = jax.ShapeDtypeStruct(x.shape, out_type)
    kernel_call = pl.pallas_call(add_vectors_kernel, out_shape=result_spec)
    return kernel_call(x, y)
# Reproducer: two length-8 int8 ramps summed into an int32 output.
x_int8 = jnp.arange(8, dtype=jnp.int8)
y_int8 = jnp.arange(8, dtype=jnp.int8)
add_vectors(x_int8, y_int8, out_type=jnp.int32)
triggers the following error:
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
[<ipython-input-45-d8659d57eed2>](https://colab.corp.google.com/drive/1APDmdaEZUyp_tu-v91vQurIaV3WblpUB#) in <module>()
12 t = x.astype(jnp.int32) + y.astype(jnp.int32)
13 o_ref[...] = t.astype(o_ref.dtype)
---> 14 add_vectors(jnp.arange(8, dtype=jnp.int8), jnp.arange(8, dtype=jnp.int8), out_type=jnp.int32)
XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Not implemented: Unsupported layout change for 'vector<8xi8>': "8,{0,0},(1,512),-2" -> "8,{0,0},(8,128),-2"
System info (python version, jaxlib version, accelerator, etc.)
Accelerator: TPU v3
jax: 0.4.31
jaxlib: 0.4.31
numpy: 1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (da249cad8d398939e0c608d38d0c038954941316)]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='xxxx.borgtask.google.com', release='5.10.0-smp-1101.34.0.0', version='#1 [v5.10.0-1101.34.0.0] SMP @1712273364', machine='x86_64')
[43]