
Not implemented: Non-trivial layouts unsupported

Open vanbasten23 opened this issue 4 months ago • 1 comment

Description

Hi. I am extending the Pallas paged attention kernel; my case is MQA (multi-query attention). When I run my kernel, I hit the following error, which reports an internal failure and asks me to file an issue here.
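For context, here is a minimal sketch of the pattern the traceback below points at: `pltpu.repeat` applied to a (4, 128) float32 value inside a Pallas TPU kernel, which is roughly what the accumulator rescale at line 180 of my kernel does when `query_len=4`. The kernel body, shapes, and names in this snippet are illustrative assumptions, not the real paged attention kernel, and I have not verified that this standalone snippet reproduces the error by itself; it only shows the operand shape, dtype, and op involved.

```python
# Illustrative sketch only, NOT the real kernel: it isolates the op the
# compiler rejects (tpu.repeat on a (4, 128) f32 vector with a padded layout).
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _repeat_kernel(x_ref, o_ref):
  # Roughly mirrors the failing line in _flash_attention:
  #   acc_scratch_ref[:] *= pltpu.repeat(acc_scale, acc_scale_repeats, axis=1)
  # With query_len=4 the operand is a (4, 128) f32 vector, and the repeat
  # count works out to 1 (times = 1 in the MLIR op shown below).
  o_ref[:] = pltpu.repeat(x_ref[:], 1, axis=1)


x = jnp.ones((4, 128), dtype=jnp.float32)
out = pl.pallas_call(
    _repeat_kernel,
    out_shape=jax.ShapeDtypeStruct((4, 128), jnp.float32),
)(x)
```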

root@t1v-n-f3643994-w-0:/workspaces/persist# rm -rf /workspaces/persist/tpu_logs && LIBTPU_INIT_ARGS="--xla_tpu_dump_logs_to_dir=/workspaces/persist/tpu_logs"  python pytorch/xla/test/test_pallas.py -v -k  PallasTest.test_extended_paged_attention_v1_multiple_queries 2>&1 | tee ~/out.txt
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
test_extended_paged_attention_v1_multiple_queries (__main__.PallasTest) ... The test test_extended_paged_attention_multiple_queries begins with query_len=4
ERROR

======================================================================
ERROR: test_extended_paged_attention_v1_multiple_queries (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 266, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Not implemented: Non-trivial layouts unsupported

at location: loc("/repeat"(callsite("_flash_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":180:0) at callsite("paged_flash_attention_kernel"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":335:0) at callsite("paged_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":558:0) at callsite("test_extended_paged_attention_v1_multiple_queries"("/workspaces/persist/pytorch/xla/test/test_pallas.py":773:0) at "<module>"("/workspaces/persist/pytorch/xla/test/test_pallas.py":1669:0)))))))

The MLIR operation involved:
  %186 = "tpu.repeat"(%177) <{dimension = 1 : i32, times = 1 : i32}> {in_layout = [#tpu.vpad<"32,{0,0},(4,128)">], out_layout = [#tpu.vpad<"32,{0,0},(4,128)">]} : (vector<4x128xf32>) -> vector<4x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke


The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 1669, in <module>
    test = unittest.main()
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 558, in paged_attention
    out = pl.pallas_call(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 335, in paged_flash_attention_kernel
    out_q_head_idx = _flash_attention(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 180, in _flash_attention
    acc_scratch_ref[:] *= pltpu.repeat(acc_scale, acc_scale_repeats, axis=1)
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Not implemented: Non-trivial layouts unsupported

The MLIR operation involved:
  %186 = "tpu.repeat"(%177) <{dimension = 1 : i32, times = 1 : i32}> {in_layout = [#tpu.vpad<"32,{0,0},(4,128)">], out_layout = [#tpu.vpad<"32,{0,0},(4,128)">]} : (vector<4x128xf32>) -> vector<4x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke


----------------------------------------------------------------------
Ran 1 test in 0.592s

FAILED (errors=1)

Here are my Pallas kernel and the test code that calls it.

Please let me know if you need more info.

cc @miladm @WoosukKwon

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.33.dev20240913
jaxlib: 0.4.33.dev20240913
numpy:  2.1.1
python: 3.10.15 (main, Sep 27 2024, 06:06:16) [GCC 10.2.1 20210110]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-f3643994-w-0', release='5.19.0-1030-gcp', version='#32~22.04.1-Ubuntu SMP Thu Jul 13 09:36:23 UTC 2023', machine='x86_64')

vanbasten23 · Oct 20 '24, 21:10