pallas hello world not implemented for int8

Open vlad17 opened this issue 7 months ago • 0 comments

Description

from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  t = x.astype(jnp.int32) + y.astype(jnp.int32)
  o_ref[...] = t.astype(o_ref.dtype)
@partial(jax.jit,static_argnames='out_type')
def add_vectors(x: jax.Array, y: jax.Array, out_type) -> jax.Array:
  return pl.pallas_call(add_vectors_kernel,
                        out_shape=jax.ShapeDtypeStruct(x.shape, out_type)
                        )(x, y)
add_vectors(jnp.arange(8, dtype=jnp.int8), jnp.arange(8, dtype=jnp.int8), out_type=jnp.int32)

Running this on TPU triggers:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
[<ipython-input-45-d8659d57eed2>](https://colab.corp.google.com/drive/1APDmdaEZUyp_tu-v91vQurIaV3WblpUB#) in <module>()
     12   t = x.astype(jnp.int32) + y.astype(jnp.int32)
     13   o_ref[...] = t.astype(o_ref.dtype)
---> 14 add_vectors(jnp.arange(8, dtype=jnp.int8), jnp.arange(8, dtype=jnp.int8), out_type=jnp.int32)

XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Not implemented: Unsupported layout change for 'vector<8xi8>': "8,{0,0},(1,512),-2" -> "8,{0,0},(8,128),-2"
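
One way to narrow this down, sketched here under assumptions: run the same kernel with `interpret=True`, which evaluates the kernel through the Pallas interpreter instead of compiling it with Mosaic. If the interpreted run returns the expected sums, that would suggest the kernel logic itself is fine and the failure is specific to the TPU lowering of int8 vectors. The wrapper name `add_vectors_interpret` is introduced only for illustration and is not part of the original report.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Same kernel as in the report: upcast the int8 inputs, add, store.
  x, y = x_ref[...], y_ref[...]
  t = x.astype(jnp.int32) + y.astype(jnp.int32)
  o_ref[...] = t.astype(o_ref.dtype)


@partial(jax.jit, static_argnames='out_type')
def add_vectors_interpret(x: jax.Array, y: jax.Array, out_type) -> jax.Array:
  # Hypothetical wrapper: identical to add_vectors except for interpret=True,
  # which bypasses the Mosaic TPU compiler.
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, out_type),
      interpret=True,
  )(x, y)


x = jnp.arange(8, dtype=jnp.int8)
result = add_vectors_interpret(x, x, out_type=jnp.int32)

# Reference computed with NumPy only, independent of Pallas.
expected = np.arange(8, dtype=np.int32) * 2
np.testing.assert_array_equal(np.asarray(result), expected)
```

Interpret mode is meant for debugging and is much slower than a compiled kernel, so this is a diagnostic aid rather than a workaround.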

System info (python version, jaxlib version, accelerator, etc.)

tpu v3

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.3
python: 3.11.8 (stable, redacted, redacted) [Clang google3-trunk (da249cad8d398939e0c608d38d0c038954941316)]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='xxxx.borgtask.google.com', release='5.10.0-smp-1101.34.0.0', version='#1 [v5.10.0-1101.34.0.0] SMP @1712273364', machine='x86_64')

vlad17, Jun 30 '24 21:06