Pallas NotImplementedError: unsupported layout change
Description
I'm trying to write a simple RNN-like for loop in Pallas, but it fails with a NotImplementedError: unsupported layout change. If anyone can help fix this, that would be very helpful!
The code is:
import functools as ft

import jax
from jax.experimental import pallas as pl

def pallas_scan(k_ref, v_ref, r_ref, w_ref, u_ref, s_ref, out_ref, s_out_ref, T):
    u = u_ref[...]
    s = s_ref[...]

    def loop(t, s):
        k_t = pl.load(k_ref, (t, slice(None)))
        v_t = pl.load(v_ref, (t, slice(None)))
        r_t = pl.load(r_ref, (t, slice(None)))
        w_t = pl.load(w_ref, (t, slice(None)))
        kv_t = k_t * v_t
        pl.store(out_ref, (t, slice(None)), (r_t * (u * kv_t + s)).squeeze(0))
        s = w_t * s + kv_t
        return s

    s = jax.lax.fori_loop(0, T, loop, s)
    pl.store(s_out_ref, (slice(None), slice(None)), s)

def rwkv_pallas(k, v, r, w, u, s):
    T, D = k.shape
    return pl.pallas_call(
        ft.partial(jax.jit(pallas_scan, static_argnums=8), T=T),
        out_shape=[jax.ShapeDtypeStruct(k.shape, k.dtype),
                   jax.ShapeDtypeStruct(s.shape, s.dtype)],
        grid=(D,),
        in_specs=[
            pl.BlockSpec(lambda j: (0, j), (T, 1)),
            pl.BlockSpec(lambda j: (0, j), (T, 1)),
            pl.BlockSpec(lambda j: (0, j), (T, 1)),
            pl.BlockSpec(lambda j: (0, j), (T, 1)),
            pl.BlockSpec(lambda j: (0, j), (1, 1)),
            pl.BlockSpec(lambda j: (j, j), (1, 1)),
        ],
        out_specs=[
            pl.BlockSpec(lambda j: (0, j), (T, 1)),
            pl.BlockSpec(lambda j: (j, j), (1, 1)),
        ],
    )(k, v, r, w, u, s)
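For reference, here is the recurrence the kernel is meant to compute, written in plain JAX with jax.lax.scan (a sketch; it assumes u and s are both shaped (1, D), so treat the shapes as my guess):

def rwkv_ref(k, v, r, w, u, s):
    # k, v, r, w: (T, D); u, s: (1, D)
    def step(s_t, xs):
        k_t, v_t, r_t, w_t = xs
        kv_t = k_t * v_t
        out_t = r_t * (u[0] * kv_t + s_t)  # readout at step t
        return w_t * s_t + kv_t, out_t     # decay the carried state
    s_final, out = jax.lax.scan(step, s[0], (k, v, r, w))
    return out, s_final[None]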
And here is the full error message:
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File /usr/local/lib/python3.10/runpy.py:196, in _run_module_as_main()
195 sys.argv[0] = mod_spec.origin
--> 196 return _run_code(code, main_globals, None,
197 "__main__", mod_spec)
File /usr/local/lib/python3.10/runpy.py:86, in _run_code()
79 run_globals.update(__name__ = mod_name,
80 __file__ = fname,
81 __cached__ = cached,
(...)
84 __package__ = pkg_name,
85 __spec__ = mod_spec)
---> 86 exec(code, run_globals)
87 return run_globals
File /usr/local/lib/python3.10/site-packages/ipykernel_launcher.py:17
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File /usr/local/lib/python3.10/site-packages/traitlets/config/application.py:1053, in launch_instance()
1052 app.initialize(argv)
-> 1053 app.start()
File /usr/local/lib/python3.10/site-packages/ipykernel/kernelapp.py:737, in start()
736 try:
--> 737 self.io_loop.start()
738 except KeyboardInterrupt:
File /usr/local/lib/python3.10/site-packages/tornado/platform/asyncio.py:195, in start()
194 def start(self) -> None:
--> 195 self.asyncio_loop.run_forever()
File /usr/local/lib/python3.10/asyncio/base_events.py:603, in run_forever()
602 while True:
--> 603 self._run_once()
604 if self._stopping:
File /usr/local/lib/python3.10/asyncio/base_events.py:1909, in _run_once()
1908 else:
-> 1909 handle._run()
1910 handle = None
File /usr/local/lib/python3.10/asyncio/events.py:80, in _run()
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:524, in dispatch_queue()
523 try:
--> 524 await self.process_one()
525 except Exception:
File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:513, in process_one()
512 return None
--> 513 await dispatch(*args)
File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:418, in dispatch_shell()
417 if inspect.isawaitable(result):
--> 418 await result
419 except Exception:
File /usr/local/lib/python3.10/site-packages/ipykernel/kernelbase.py:758, in execute_request()
757 if inspect.isawaitable(reply_content):
--> 758 reply_content = await reply_content
760 # Flush output before sending the reply.
File /usr/local/lib/python3.10/site-packages/ipykernel/ipkernel.py:426, in do_execute()
425 if with_cell_id:
--> 426 res = shell.run_cell(
427 code,
428 store_history=store_history,
429 silent=silent,
430 cell_id=cell_id,
431 )
432 else:
File /usr/local/lib/python3.10/site-packages/ipykernel/zmqshell.py:549, in run_cell()
548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3046, in run_cell()
3045 try:
-> 3046 result = self._run_cell(
3047 raw_cell, store_history, silent, shell_futures, cell_id
3048 )
3049 finally:
File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3101, in _run_cell()
3100 try:
-> 3101 result = runner(coro)
3102 except BaseException as e:
File /usr/local/lib/python3.10/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner()
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3306, in run_cell_async()
3303 interactivity = "none" if silent else self.ast_node_interactivity
-> 3306 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3307 interactivity=interactivity, compiler=compiler, result=result)
3309 self.last_execution_succeeded = not has_raised
File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3488, in run_ast_nodes()
3487 asy = compare(code)
-> 3488 if await self.run_code(code, result, async_=asy):
3489 return True
File /usr/local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3548, in run_code()
3547 else:
-> 3548 exec(code_obj, self.user_global_ns, self.user_ns)
3549 finally:
3550 # Reset our crash handler in place
Cell In[113], line 1
----> 1 rwkv_pallas(k,v,r,w,u,s)
Cell In[110], line 3, in rwkv_pallas()
2 T,D = k.shape
----> 3 return pl.pallas_call(
4 ft.partial(jax.jit(pallas_scan,static_argnums=8), T=T),
5 out_shape=[jax.ShapeDtypeStruct(k.shape, k.dtype),
6 jax.ShapeDtypeStruct(s.shape, s.dtype)],
7 grid=(D,),
8 in_specs=[
9 pl.BlockSpec(lambda j: (0, j), (T, 1)),
10 pl.BlockSpec(lambda j: (0, j), (T, 1)),
11 pl.BlockSpec(lambda j: (0, j), (T, 1)),
12 pl.BlockSpec(lambda j: (0, j), (T, 1)),
13 pl.BlockSpec(lambda j: (0, j), (1, 1)),
14 pl.BlockSpec(lambda j: (j, j), (1, 1)),
15 ],
16 out_specs=[
17 pl.BlockSpec(lambda j: (0, j), (T, 1)),
18 pl.BlockSpec(lambda j: (j, j), (1, 1)),
19 ],
20 )(k,v,r,w,u,s)
File /usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py:383, in wrapped()
382 which_linear = (False,) * len(flat_args)
--> 383 out_flat = pallas_call_p.bind(
384 *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
385 in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
386 for a in flat_args),
387 out_shapes=tuple(flat_out_shapes), debug=debug,
388 interpret=interpret,
389 grid_mapping=grid_mapping,
390 input_output_aliases=tuple(input_output_aliases.items()),
391 **compiler_params)
392 out = tree_util.tree_unflatten(out_tree, out_flat)
JaxStackTraceBeforeTransformation: NotImplementedError: unsupported layout change for vector<1x1xf32>: VectorLayout(bitwidth=32, offsets=(0, 0), tiling=(2, 128), implicit_dim=None) -> VectorLayout(bitwidth=32, offsets=(*, 0), tiling=(8, 128), implicit_dim=None)
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
Cell In[113], line 1
----> 1 rwkv_pallas(k,v,r,w,u,s)
Cell In[110], line 3, in rwkv_pallas(k, v, r, w, u, s)
1 def rwkv_pallas(k,v,r,w,u,s):
2 T,D = k.shape
----> 3 return pl.pallas_call(
4 ft.partial(jax.jit(pallas_scan,static_argnums=8), T=T),
5 out_shape=[jax.ShapeDtypeStruct(k.shape, k.dtype),
6 jax.ShapeDtypeStruct(s.shape, s.dtype)],
7 grid=(D,),
8 in_specs=[
9 pl.BlockSpec(lambda j: (0, j), (T, 1)),
10 pl.BlockSpec(lambda j: (0, j), (T, 1)),
11 pl.BlockSpec(lambda j: (0, j), (T, 1)),
12 pl.BlockSpec(lambda j: (0, j), (T, 1)),
13 pl.BlockSpec(lambda j: (0, j), (1, 1)),
14 pl.BlockSpec(lambda j: (j, j), (1, 1)),
15 ],
16 out_specs=[
17 pl.BlockSpec(lambda j: (0, j), (T, 1)),
18 pl.BlockSpec(lambda j: (j, j), (1, 1)),
19 ],
20 )(k,v,r,w,u,s)
[... skipping hidden 17 frame]
File /usr/local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py:87, in pallas_call_tpu_lowering_rule(ctx, jaxpr, name, which_linear, grid_mapping, input_output_aliases, in_shapes, out_shapes, debug, interpret, mosaic_params, *in_nodes, **compiler_params)
78 def _lower_fun(*args):
79 return mosaic.as_tpu_kernel(
80 mosaic_module,
81 out_avals,
(...)
85 cost_estimate=mosaic_params.get('cost_estimate', None),
86 )(*extra_args, *args)
---> 87 return mlir.lower_fun(_lower_fun, multiple_results=True)(
88 ctx, *in_nodes)
[... skipping hidden 5 frame]
File /usr/local/lib/python3.10/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py:79, in pallas_call_tpu_lowering_rule.<locals>._lower_fun(*args)
78 def _lower_fun(*args):
---> 79 return mosaic.as_tpu_kernel(
80 mosaic_module,
81 out_avals,
82 backend=ctx.module_context.backend,
83 kernel_name=name,
84 kernel_regeneration_metadata=kernel_regeneration_metadata,
85 cost_estimate=mosaic_params.get('cost_estimate', None),
86 )(*extra_args, *args)
File /usr/local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py:339, in as_tpu_kernel(module, out_type, cost_estimate, backend, device_type, kernel_name, kernel_regeneration_metadata)
335 hardware_generation = int(device_kind[len("TPU v")])
336 has_communication, has_custom_barrier = tpu.private_has_communication(
337 module.operation
338 )
--> 339 lowered_module_asm, constants = _lower_tpu_kernel(
340 module, hardware_generation, device_type=device_type
341 )
342 # TODO(amagni): Kernel name and regeneration metadata could alternatively be
343 # added as a custom attribute to the MLIR call op rather than including them
344 # in the backend_config.
345 return _lowered_as_tpu_kernel(
346 lowered_module_asm,
347 out_type,
(...)
354 cost_estimate=cost_estimate,
355 )
File /usr/local/lib/python3.10/site-packages/jax/_src/tpu_custom_call.py:278, in _lower_tpu_kernel(module, hardware_generation, device_type)
276 pipeline.run(module.operation)
277 else:
--> 278 apply_vector_layout.apply(module, hardware_generation)
279 module.operation.verify()
280 dump_mlir(module, "after apply vector layout pass")
File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1588, in apply(module, hardware_generation)
1586 if not isinstance(f, func.FuncOp):
1587 raise ValueError(f"Unexpected op in module body: {f.OPERATION_NAME}")
-> 1588 apply_layout_func(ctx, f)
File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1495, in apply_layout_func(ctx, f)
1488 """Rewrites the function according to layout annotations of its operations.
1489
1490 Args:
1491 ctx: The context used for rewriting.
1492 f: An MLIR function to be rewritten.
1493 """
1494 (entry_block,) = f.body
-> 1495 apply_layout_block(ctx, entry_block)
File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1501, in apply_layout_block(ctx, block)
1498 def apply_layout_block(ctx: RewriteContext, block: ir.Block):
1499 # We'll be modifying the block, so make a list of operations beforehand.
1500 for op in list(block):
-> 1501 apply_layout_op(ctx, op)
File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1550, in apply_layout_op(ctx, op)
1548 continue
1549 with ir.InsertionPoint(op), op.location:
-> 1550 new_v = relayout(
1551 v, src=lo, dst=li, hw_generation=ctx.hardware_generation
1552 ).result
1553 ctx.set_operand(op, idx, new_v)
1554 else:
File /usr/local/lib/python3.10/site-packages/jaxlib/mosaic/python/apply_vector_layout.py:1400, in relayout(v, src, dst, hw_generation)
1398 return assemble(vty, dst, dst_tiles)
1399 # TODO(apaszke): Implement general relayout
-> 1400 raise NotImplementedError(
1401 f"unsupported layout change for {vty}: {src} -> {dst}")
NotImplementedError: unsupported layout change for vector<1x1xf32>: VectorLayout(bitwidth=32, offsets=(0, 0), tiling=(2, 128), implicit_dim=None) -> VectorLayout(bitwidth=32, offsets=(*, 0), tiling=(8, 128), implicit_dim=None)
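The relayout that fails is on a vector<1x1xf32>, which comes from the (1, 1) blocks used for u and s; Mosaic apparently has no relayout rule for vectors that small. One workaround that seems to sidestep it is to tile along D in full lane-width chunks so every vector in the kernel is at least (1, 128). A minimal sketch (untested; it assumes u and s are shaped (1, D), that D is a multiple of 128, and it drops the jax.jit wrapper and the static T argument, since T can be read off the ref inside the kernel):

def pallas_scan(k_ref, v_ref, r_ref, w_ref, u_ref, s_ref, out_ref, s_out_ref):
    u = u_ref[...]  # (1, 128) block of u
    s = s_ref[...]  # (1, 128) running state
    T = k_ref.shape[0]

    def loop(t, s):
        k_t = pl.load(k_ref, (pl.dslice(t, 1), slice(None)))  # (1, 128)
        v_t = pl.load(v_ref, (pl.dslice(t, 1), slice(None)))
        r_t = pl.load(r_ref, (pl.dslice(t, 1), slice(None)))
        w_t = pl.load(w_ref, (pl.dslice(t, 1), slice(None)))
        kv_t = k_t * v_t
        pl.store(out_ref, (pl.dslice(t, 1), slice(None)), r_t * (u * kv_t + s))
        return w_t * s + kv_t

    s = jax.lax.fori_loop(0, T, loop, s)
    s_out_ref[...] = s

def rwkv_pallas(k, v, r, w, u, s):
    T, D = k.shape
    block = 128  # one full lane width per grid step
    spec_td = pl.BlockSpec(lambda j: (0, j), (T, block))
    spec_1d = pl.BlockSpec(lambda j: (0, j), (1, block))
    return pl.pallas_call(
        pallas_scan,
        out_shape=[jax.ShapeDtypeStruct(k.shape, k.dtype),
                   jax.ShapeDtypeStruct(s.shape, s.dtype)],
        grid=(D // block,),
        in_specs=[spec_td, spec_td, spec_td, spec_td, spec_1d, spec_1d],
        out_specs=[spec_td, spec_1d],
    )(k, v, r, w, u, s)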
What jax/jaxlib version are you using?
jax v0.4.21, jaxlib v0.4.21
Which accelerator(s) are you using?
TPU
Additional system info?
Python 3.10, Kaggle Notebook
NVIDIA GPU info
No response