jax icon indicating copy to clipboard operation
jax copied to clipboard

Pallas NotImplementedError: unsupported layout change

Open faresobeid opened this issue 2 years ago • 0 comments

Description

I'm trying to write a simple RNN-like for loop in Pallas, but I get a Pallas `NotImplementedError: unsupported layout change` for some reason. If anyone can help fix this, that would be very helpful!

The code is:

# Pallas kernel body: a sequential scan over T steps that writes one output
# row per step and, at the end, the final loop-carried state.
#
# Args (everything except T is a ref supplied by pl.pallas_call):
#   k_ref, v_ref, r_ref, w_ref: inputs; row `t` of each is read on iteration `t`.
#   u_ref: read in full once, before the loop.
#   s_ref: initial value of the loop-carried state; read in full once.
#   out_ref: output ref; row `t` is written on iteration `t`.
#   s_out_ref: output ref; receives the state after the last iteration.
#   T: loop trip count (upper bound passed to jax.lax.fori_loop).
def pallas_scan(k_ref,v_ref,r_ref,w_ref,u_ref,s_ref,out_ref,s_out_ref,T):
    u = u_ref[...]  # whole block
    s = s_ref[...]  # whole block; used as the initial carry
    def loop(t,s):
        # Row t, all columns, of each per-step input.
        k_t = pl.load(k_ref,(t,slice(None)))
        v_t = pl.load(v_ref,(t,slice(None)))
        r_t = pl.load(r_ref,(t,slice(None)))
        w_t = pl.load(w_ref,(t,slice(None)))
        kv_t = k_t * v_t
        # The output for step t uses the state from *before* this step's update.
        # squeeze(0) drops the leading axis of the product before the store
        # (presumably because u/s are 2-D blocks while a row slot is 1-D -- TODO confirm).
        pl.store(out_ref,(t,slice(None)),(r_t * (u * kv_t + s)).squeeze(0))
        # State update: scale the previous state by w_t, then add kv_t.
        s = w_t * s + kv_t
        return s
    s = jax.lax.fori_loop(0,T,loop,s)  # run steps 0..T-1, threading s through
    pl.store(s_out_ref,(slice(None),slice(None)),s)  # write the final state over both axes

# Launches pallas_scan through pl.pallas_call on a 1-D grid of size D, where
# T, D = k.shape. Each BlockSpec below is written as (index_map, block_shape),
# and the index map receives the grid index j.
#
# Args:
#   k, v, r, w: inputs, each cut into (T, 1) blocks at block index (0, j).
#   u: input cut into (1, 1) blocks at block index (0, j).
#   s: input cut into (1, 1) blocks at block index (j, j).
#
# Returns:
#   The pallas_call result for `out_shape`: one array with k's shape/dtype and
#   one with s's shape/dtype.
#
# The kernel is passed as jax.jit(pallas_scan, static_argnums=8) with T bound
# by keyword via ft.partial; index 8 is T's position in pallas_scan's signature.
# NOTE(review): the jax.jit wrapper around the kernel looks unnecessary --
# confirm whether ft.partial(pallas_scan, T=T) alone is what was intended.
# NOTE(review): s (and the second output) use block index (j, j), i.e. both
# axes advance with the grid index, unlike u's (0, j) -- confirm this is intended.
def rwkv_pallas(k,v,r,w,u,s):
    T,D = k.shape
    return pl.pallas_call(
        ft.partial(jax.jit(pallas_scan,static_argnums=8), T=T),
        out_shape=[jax.ShapeDtypeStruct(k.shape, k.dtype),
                  jax.ShapeDtypeStruct(s.shape, s.dtype)],
        grid=(D,),
        in_specs=[
          pl.BlockSpec(lambda j: (0, j), (T, 1)),  # k
          pl.BlockSpec(lambda j: (0, j), (T, 1)),  # v
          pl.BlockSpec(lambda j: (0, j), (T, 1)),  # r
          pl.BlockSpec(lambda j: (0, j), (T, 1)),  # w
          pl.BlockSpec(lambda j: (0, j), (1, 1)),  # u
          pl.BlockSpec(lambda j: (j, j), (1, 1)),  # s
        ],
        out_specs=[
          pl.BlockSpec(lambda j: (0, j), (T, 1)),  # first output (k-shaped)
          pl.BlockSpec(lambda j: (j, j), (1, 1)),  # second output (s-shaped)
        ],
      )(k,v,r,w,u,s)

And the full error message:

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File /usr/local/lib/python3.10/runpy.py:196, in _run_module_as_main()
    195     sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
    197                  "__main__", mod_spec)

File /usr/local/lib/python3.10/runpy.py:86, in _run_code()
     79 run_globals.update(__name__ = mod_name,
     80                    __file__ = fname,
     81                    __cached__ = cached,
   (...)
     84                    __package__ = pkg_name,
     85                    __spec__ = mod_spec)
---> 86 exec(code, run_globals)
     87 return run_globals

File /usr/local/lib/python3.10/site-packages/ipykernel_launcher.py:17
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File /usr/local/lib/python3.10/site-packages/traitlets/config/application.py:1053, in launch_instance()
   1052 app.initialize(argv)
-> 1053 app.start()

File /usr/local/lib/python3.10/site-packages/ipykernel/kernelapp.py:737, in start()
    736 try:
--> 737     self.io_loop.start()
    738 except KeyboardInterrupt:

File /usr/local/lib/python3.10/site-packages/tornado/platform/asyncio.py:195, in start()
    194 def start(self) -> None:
--> 195     self.asyncio_loop.run_forever()

File /usr/local/lib/python3.10/asyncio/base_events.py:603, in run_forever()
    602 while True:
--> 603     self._run_once()
    604     if self._stopping:

File /usr/local/lib/python3.10/asyncio/base_events.py:1909, in _run_once()
   1908     else:
-> 1909         handle._run()
   1910 handle = None

File /usr/local/lib/python3.10/asyncio/events.py:80, in _run()
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:524, in dispatch_queue()
    523 try:
--> 524     await self.process_one()
    525 except Exception:

File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:513, in process_one()
    512         return None
--> 513 await dispatch(*args)

File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:418, in dispatch_shell()
    417     if inspect.isawaitable(result):
--> 418         await result
    419 except Exception:

File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:758, in execute_request()
    757 if inspect.isawaitable(reply_content):
--> 758     reply_content = await reply_content
    760 # Flush output before sending the reply.

File /usr/local/lib/python3.10/site-packages/ipykernel/ipkernel.py:426, in do_execute()
    425 if with_cell_id:
--> 426     res = shell.run_cell(
    427         code,
    428         store_history=store_history,
    429         silent=silent,
    430         cell_id=cell_id,
    431     )
    432 else:

File /usr/local/lib/python3.10/site-packages/ipykernel/zmqshell.py:549, in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3046, in run_cell()
   3045 try:
-> 3046     result = self._run_cell(
   3047         raw_cell, store_history, silent, shell_futures, cell_id
   3048     )
   3049 finally:

File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3101, in _run_cell()
   3100 try:
-> 3101     result = runner(coro)
   3102 except BaseException as e:

File /usr/local/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3306, in run_cell_async()
   3303 interactivity = "none" if silent else self.ast_node_interactivity
-> 3306 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3307        interactivity=interactivity, compiler=compiler, result=result)
   3309 self.last_execution_succeeded = not has_raised

File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3488, in run_ast_nodes()
   3487     asy = compare(code)
-> 3488 if await self.run_code(code, result, async_=asy):
   3489     return True

File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3548, in run_code()
   3547     else:
-> 3548         exec(code_obj, self.user_global_ns, self.user_ns)
   3549 finally:
   3550     # Reset our crash handler in place

Cell In[113], line 1
----> 1 rwkv_pallas(k,v,r,w,u,s)

Cell In[110], line 3, in rwkv_pallas()
      2 T,D = k.shape
----> 3 return pl.pallas_call(
      4     ft.partial(jax.jit(pallas_scan,static_argnums=8), T=T),
      5     out_shape=[jax.ShapeDtypeStruct(k.shape, k.dtype),
      6               jax.ShapeDtypeStruct(s.shape, s.dtype)],
      7     grid=(D,),
      8     in_specs=[
      9       pl.BlockSpec(lambda j: (0, j), (T, 1)),
     10       pl.BlockSpec(lambda j: (0, j), (T, 1)),
     11       pl.BlockSpec(lambda j: (0, j), (T, 1)),
     12       pl.BlockSpec(lambda j: (0, j), (T, 1)),
     13       pl.BlockSpec(lambda j: (0, j), (1, 1)),
     14       pl.BlockSpec(lambda j: (j, j), (1, 1)),
     15     ],
     16     out_specs=[
     17       pl.BlockSpec(lambda j: (0, j), (T, 1)),
     18       pl.BlockSpec(lambda j: (j, j), (1, 1)),
     19     ],
     20   )(k,v,r,w,u,s)

File /usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:383, in wrapped()
    382 which_linear = (False,) * len(flat_args)
--> 383 out_flat = pallas_call_p.bind(
    384     *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
    385     in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
    386                     for a in flat_args),
    387     out_shapes=tuple(flat_out_shapes), debug=debug,
    388     interpret=interpret,
    389     grid_mapping=grid_mapping,
    390     input_output_aliases=tuple(input_output_aliases.items()),
    391     **compiler_params)
    392 out = tree_util.tree_unflatten(out_tree, out_flat)

JaxStackTraceBeforeTransformation: NotImplementedError: unsupported layout change for vector<1x1xf32>: VectorLayout(bitwidth=32, offsets=(0, 0), tiling=(2, 128), implicit_dim=None) -> VectorLayout(bitwidth=32, offsets=(*, 0), tiling=(8, 128), implicit_dim=None)

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
Cell In[113], line 1
----> 1 rwkv_pallas(k,v,r,w,u,s)

Cell In[110], line 3, in rwkv_pallas(k, v, r, w, u, s)
      1 def rwkv_pallas(k,v,r,w,u,s):
      2     T,D = k.shape
----> 3     return pl.pallas_call(
      4         ft.partial(jax.jit(pallas_scan,static_argnums=8), T=T),
      5         out_shape=[jax.ShapeDtypeStruct(k.shape, k.dtype),
      6                   jax.ShapeDtypeStruct(s.shape, s.dtype)],
      7         grid=(D,),
      8         in_specs=[
      9           pl.BlockSpec(lambda j: (0, j), (T, 1)),
     10           pl.BlockSpec(lambda j: (0, j), (T, 1)),
     11           pl.BlockSpec(lambda j: (0, j), (T, 1)),
     12           pl.BlockSpec(lambda j: (0, j), (T, 1)),
     13           pl.BlockSpec(lambda j: (0, j), (1, 1)),
     14           pl.BlockSpec(lambda j: (j, j), (1, 1)),
     15         ],
     16         out_specs=[
     17           pl.BlockSpec(lambda j: (0, j), (T, 1)),
     18           pl.BlockSpec(lambda j: (j, j), (1, 1)),
     19         ],
     20       )(k,v,r,w,u,s)

    [... skipping hidden 17 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py:87, in pallas_call_tpu_lowering_rule(ctx, jaxpr, name, which_linear, grid_mapping, input_output_aliases, in_shapes, out_shapes, debug, interpret, mosaic_params, *in_nodes, **compiler_params)
     78 def _lower_fun(*args):
     79   return mosaic.as_tpu_kernel(
     80       mosaic_module,
     81       out_avals,
   (...)
     85       cost_estimate=mosaic_params.get('cost_estimate', None),
     86   )(*extra_args, *args)
---> 87 return mlir.lower_fun(_lower_fun, multiple_results=True)(
     88     ctx, *in_nodes)

    [... skipping hidden 5 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py:79, in pallas_call_tpu_lowering_rule.<locals>._lower_fun(*args)
     78 def _lower_fun(*args):
---> 79   return mosaic.as_tpu_kernel(
     80       mosaic_module,
     81       out_avals,
     82       backend=ctx.module_context.backend,
     83       kernel_name=name,
     84       kernel_regeneration_metadata=kernel_regeneration_metadata,
     85       cost_estimate=mosaic_params.get('cost_estimate', None),
     86   )(*extra_args, *args)

File /usr/local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py:339, in as_tpu_kernel(module, out_type, cost_estimate, backend, device_type, kernel_name, kernel_regeneration_metadata)
    335 hardware_generation = int(device_kind[len("TPU v")])
    336 has_communication, has_custom_barrier = tpu.private_has_communication(
    337     module.operation
    338 )
--> 339 lowered_module_asm, constants = _lower_tpu_kernel(
    340     module, hardware_generation, device_type=device_type
    341 )
    342 # TODO(amagni): Kernel name and regeneration metadata could alternatively be
    343 # added as a custom attribute to the MLIR call op rather than including them
    344 # in the backend_config.
    345 return _lowered_as_tpu_kernel(
    346     lowered_module_asm,
    347     out_type,
   (...)
    354     cost_estimate=cost_estimate,
    355 )

File /usr/local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py:278, in _lower_tpu_kernel(module, hardware_generation, device_type)
    276   pipeline.run(module.operation)
    277 else:
--> 278   apply_vector_layout.apply(module, hardware_generation)
    279 module.operation.verify()
    280 dump_mlir(module, "after apply vector layout pass")

File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1588, in apply(module, hardware_generation)
   1586 if not isinstance(f, func.FuncOp):
   1587   raise ValueError(f"Unexpected op in module body: {f.OPERATION_NAME}")
-> 1588 apply_layout_func(ctx, f)

File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1495, in apply_layout_func(ctx, f)
   1488 """Rewrites the function according to layout annotations of its operations.
   1489 
   1490 Args:
   1491   ctx: The context used for rewriting.
   1492   f: An MLIR function to be rewritten.
   1493 """
   1494 (entry_block,) = f.body
-> 1495 apply_layout_block(ctx, entry_block)

File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1501, in apply_layout_block(ctx, block)
   1498 def apply_layout_block(ctx: RewriteContext, block: ir.Block):
   1499   # We'll be modifying the block, so make a list of operations beforehand.
   1500   for op in list(block):
-> 1501     apply_layout_op(ctx, op)

File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1550, in apply_layout_op(ctx, op)
   1548       continue
   1549     with ir.InsertionPoint(op), op.location:
-> 1550       new_v = relayout(
   1551           v, src=lo, dst=li, hw_generation=ctx.hardware_generation
   1552       ).result
   1553       ctx.set_operand(op, idx, new_v)
   1554 else:

File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1400, in relayout(v, src, dst, hw_generation)
   1398   return assemble(vty, dst, dst_tiles)
   1399 # TODO(apaszke): Implement general relayout
-> 1400 raise NotImplementedError(
   1401     f"unsupported layout change for {vty}: {src} -> {dst}")

NotImplementedError: unsupported layout change for vector<1x1xf32>: VectorLayout(bitwidth=32, offsets=(0, 0), tiling=(2, 128), implicit_dim=None) -> VectorLayout(bitwidth=32, offsets=(*, 0), tiling=(8, 128), implicit_dim=None)

What jax/jaxlib version are you using?

jax v0.4.21 jaxlib v0.4.21

Which accelerator(s) are you using?

TPU

Additional system info?

Python 3.10, Kaggle Notebook

NVIDIA GPU info

No response

faresobeid avatar Dec 20 '23 04:12 faresobeid