[BUG] libc++abi crash when using recurrent layer and transformer
Describe the bug
(libc++abi: terminating due to uncaught exception of type std::runtime_error:
[compile] Too many inputs/outputs fused in the Metal Compiled primitive which
exhausted the available argument buffers for the kernel. Please file an issue with
the function that results in this error. The name of the kernel is
'Nf4MultiplyABOf4AddEFPf4AddOGQf4AddPHRf4AddQISf4MultiplyRJTf4AddDSUf4MultiplyCTVf4SquareRWf4MultiplyVLXf4AddKWYf4SqrtXZf4AddYMAAf4DivideUZABf4SubtractNAA_VVVVVVVVVVVVV_f4f4f4f4f4f4f4f4f4f4f4f4f4_11160318154034397263_strided_dynamic')
In:
def __call__(self, x):
    L = x.shape[1]
    mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
    x = self.embed(x)
    x = x + self.pe(mx.arange(L))
    x = self.transformer(x, mask)
    x = self.context_recurrent(x) + x
    x = self.transformer2(x, mask)
    x = self.out_proj(x)
    return x
If self.context_recurrent is any of nn.LSTM, nn.RNN, or nn.GRU, the crash above occurs.
To Reproduce
Inserting a recurrent layer such as nn.RNN between the two transformer layers causes the crash.
Complete code https://github.com/domschl/mlx-poet/blob/cedac548256a1bd2a1bb33362cf9d99f22a360c7/mlx_poet_bug.py
(requires pip install ml-indie-tools)
Expected behavior
No crash, or, if a failure is unavoidable, a clear error message. I've checked that there is no tensor-shape problem.
Desktop (please complete the following information):
- OS Version: 14.4.1 (23E224)
- Version 0.12.0
Hi @domschl, thanks for the bug report. The bug comes from compilation. There is a subgraph that is too big (more precisely, it has too many inputs) to fuse into a single kernel, but compile still tries to fuse it and fails. The long cryptic string is actually a representation of the graph to be fused.
We'll look into fixing it (i.e., compile should split the subgraph into two smaller ones). In the meantime, you could disable compile and the code should run fine.
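For anyone hitting the same error, here is a minimal sketch of one way to apply that workaround, assuming the global switch mx.disable_compile() is used instead of removing individual mx.compile decorators. The names run_uncompiled, fused, and HAVE_MLX are hypothetical and only for illustration; the small function is a stand-in, not the model from the report. The MLX docs also mention the MLX_DISABLE_COMPILE environment variable as an alternative.

```python
import importlib.util

# Hypothetical guard so the snippet can be imported on machines without MLX.
HAVE_MLX = importlib.util.find_spec("mlx") is not None


def run_uncompiled():
    """Evaluate a small mx.compile-decorated function with compilation disabled."""
    import mlx.core as mx

    # Globally disable compilation: functions wrapped in mx.compile
    # are then executed as ordinary, unfused MLX code.
    mx.disable_compile()
    try:
        @mx.compile
        def fused(a, b):
            # Stand-in for a chain of element-wise ops that compile would fuse.
            return mx.sqrt(mx.square(a) + b) * a

        out = fused(mx.ones((4,)), mx.ones((4,)))
        mx.eval(out)  # force evaluation of the lazy graph
        return out.tolist()
    finally:
        # Restore the default so later code is compiled again.
        mx.enable_compile()


if HAVE_MLX:
    print(run_uncompiled())
```

In a training script the call to mx.disable_compile() would typically go once near the top, before the model is first evaluated; the try/finally here only keeps the sketch from changing global state.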
Thanks! Confirmed: without compilation it works fine.
Could you please share the error message and a way to reproduce it?