
[BUG] libc++abi crash when using recurrent layer and transformer

Open domschl opened this issue 1 year ago • 2 comments

Describe the bug

(libc++abi: terminating due to uncaught exception of type std::runtime_error: 
[compile] Too many inputs/outputs fused in the Metal Compiled primitive which 
exhausted the available argument buffers for the kernel. Please file an issue with 
the function that results in this error. The name of the kernel is
'Nf4MultiplyABOf4AddEFPf4AddOGQf4AddPHRf4AddQISf4MultiplyRJTf4AddDSUf4MultiplyCTVf4SquareRWf4MultiplyVLXf4AddKWYf4SqrtXZf4AddYMAAf4DivideUZABf4SubtractNAA_VVVVVVVVVVVVV_f4f4f4f4f4f4f4f4f4f4f4f4f4_11160318154034397263_strided_dynamic')

In:

def __call__(self, x):
    L = x.shape[1]
    mask = nn.MultiHeadAttention.create_additive_causal_mask(L)
    x = self.embed(x)
    x = x + self.pe(mx.arange(L))
    x = self.transformer(x, mask)
    x = self.context_recurrent(x) + x
    x = self.transformer2(x, mask)
    x = self.out_proj(x)
    return x

If self.context_recurrent is any of nn.LSTM, nn.RNN, or nn.GRU, the crash above occurs.

To Reproduce

Inserting a recurrent layer such as nn.RNN between two transformer layers causes the crash.

Complete code: https://github.com/domschl/mlx-poet/blob/cedac548256a1bd2a1bb33362cf9d99f22a360c7/mlx_poet_bug.py (requires pip install ml-indie-tools)

Expected behavior

No crash and, if necessary, a clear error message. I've checked that there is no tensor-shape problem.

Desktop (please complete the following information):

  • OS Version: 14.4.1 (23E224)
  • Version 0.12.0

domschl, May 01 '24 15:05

Hi @domschl, thanks for the bug report. The bug comes from compilation. There is a subgraph that is too big to fuse into a single kernel (more precisely, it has too many inputs), but compile still tries to fuse it and fails. The long cryptic string in the error is a representation of the graph being fused.

We'll look into fixing it (i.e. compile should break the subgraph into two smaller ones). In the meantime, you can disable compile and the code should run fine.
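
For anyone hitting the same error, the workaround can be sketched roughly as follows. This is a minimal illustration, not the code from the linked script: it assumes the model's step function is wrapped with mx.compile, and scaled_sum is a hypothetical stand-in for that function. mx.disable_compile() turns compilation off globally, so wrapped functions run uncompiled and the fused Metal kernel is never built.

```python
try:
    import mlx.core as mx
except ImportError:  # MLX is only installable on supported platforms
    mx = None


def scaled_sum(x):
    # Hypothetical stand-in for the real forward/training step.
    return (x * 2.0 + 1.0).sum()


def run_without_compile():
    """Run a function wrapped in mx.compile with compilation globally disabled."""
    if mx is None:
        return None  # MLX not installed; nothing to demonstrate
    mx.disable_compile()  # mx.compile-wrapped functions now run uncompiled
    try:
        fn = mx.compile(scaled_sum)
        return fn(mx.ones((4,))).item()
    finally:
        mx.enable_compile()  # restore the default for the rest of the program


result = run_without_compile()
print(result)
```

In a real script it is probably enough to call mx.disable_compile() once near the top, before training starts. Setting the MLX_DISABLE_COMPILE environment variable should have the same effect without touching the code.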

angeloskath, May 01 '24 17:05

Tx! Confirmed: without compilation it works fine.

domschl, May 01 '24 18:05

Could you please share the error message and a way to reproduce it?

awni, Jun 24 '24 03:06