
Crash on `concatenate` after latest update

Open DePasqualeOrg opened this issue 1 year ago • 17 comments

After https://github.com/ml-explore/mlx-swift-examples/commit/ab94ffc2f31a70ead3c7007afaf97a225ed3ec90, I'm getting a crash the second time I try to generate text with my app, which uses mlx-libraries. I can't reproduce this with the LLMEval example app at the moment, but I'll try to find the cause.

MLX error: [concatenate] All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (512,8,1,128), (1,8,512,128), and the concatenation axis is 2. at /Users/<user>/Library/Developer/Xcode/DerivedData/<app name>-ejjtjaklhfhyarhbwjdbxiatlsar/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/ops.cpp:217

DePasqualeOrg avatar Aug 29 '24 21:08 DePasqualeOrg

Which model were you running?

awni avatar Aug 29 '24 21:08 awni

So far it has happened on Phi 3.5 mini 4-bit and Llama 3.1 8B 4-bit. I haven't tested other models yet.

DePasqualeOrg avatar Aug 29 '24 21:08 DePasqualeOrg

Can you provide more details on what exactly you are running? Most likely the line it's breaking at is in the new KV cache. It looks like one of the inputs to that function has the wrong order.

Are you using custom model code or the same models as in the example?

awni avatar Aug 29 '24 21:08 awni

I'm using the models from mlx-libraries. Generally this happens on the second or third prompt in a conversation. I'm still trying to investigate this on my end but wanted to open this issue in case others are having similar problems.

DePasqualeOrg avatar Aug 29 '24 21:08 DePasqualeOrg

I was using mlx-community/Phi-3-mini-4k-instruct-4bit as the primary use case so I know that one works generally.

Is there something I can do to reproduce the issue? I am happy to debug it.

davidkoski avatar Aug 29 '24 21:08 davidkoski

Generally this happens on the second or third prompt in a conversation

How do you do the conversation? How is the state carried from one call to generate to the next?

davidkoski avatar Aug 29 '24 21:08 davidkoski

To represent the conversation, I use the prompt template that each model expects, appending each new prompt and response and passing the updated template to generate when a new prompt is submitted. I've had to build this myself, since swift-transformers doesn't include it, although this may change soon. I'll post an update here when I can reproduce this with LLMEval.
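For context, a minimal sketch of this kind of manual template assembly (the Llama 3 tags are the published ones; the Message type and buildPrompt helper are hypothetical):

    struct Message {
        let role: String   // "system", "user", or "assistant"
        let content: String
    }

    // Assemble a Llama 3 style prompt from the running conversation.
    func buildPrompt(_ messages: [Message]) -> String {
        var prompt = "<|begin_of_text|>"
        for message in messages {
            prompt += "<|start_header_id|>\(message.role)<|end_header_id|>\n\n"
            prompt += message.content + "<|eot_id|>"
        }
        // Cue the model to produce the next assistant turn.
        prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
    }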

DePasqualeOrg avatar Aug 29 '24 22:08 DePasqualeOrg

I've been able to reproduce this with LLMEval after updating the swift-transformers dependency to use the latest commit on the main branch. Short prompts work as expected, but after submitting a long prompt (about 5600 characters) with Llama 3.1 8B 4-bit, I get this crash:

MLX error: [scaled_dot_product_attention] mismatching batch dimension for input with shape (512,8,170,128). at /Users/<user>/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eiqkagimbcumwufwrjncqseqpfjo/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/fast.cpp:61

And with Phi 3.5 mini:

MLX error: [concatenate] All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (512,32,2,96), (1,32,512,96), and the concatenation axis is 2. at /Users/<user>/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eiqkagimbcumwufwrjncqseqpfjo/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/ops.cpp:217

I noticed that the active memory indicator in the app grows to a very large number when this happens.

DePasqualeOrg avatar Aug 30 '24 01:08 DePasqualeOrg

ok perfect, I have a repro using the output of the previous runs:

p2.txt

davidkoski avatar Aug 30 '24 01:08 davidkoski

actually a different error:

<|assistant|>MLX error: [scaled_dot_product_attention] mismatching batch dimension for input with shape (512,32,182,96). at /Users/dkoski/Library/Developer/Xcode/DerivedData/mlx-swift-examples-eimbjcofifunwybkcvhnzjbqwyri/SourcePackages/checkouts/mlx-swift/Source/Cmlx/include/mlx/c/fast.cpp:61

hopefully that is related though. I also see the big memory spike

(actually this matches your other error)

davidkoski avatar Aug 30 '24 01:08 davidkoski

So something is off in the KVCache. The shape of the keys after the prompt on the python side:

(1, 32, 512, 96)

swift:

- 0 : 256
- 1 : 32
- 2 : 256
- 3 : 96

The 256 vs. 512 is because I messed with the prefill step size. Either way, dimension 0 is not right.

davidkoski avatar Aug 30 '24 02:08 davidkoski

This comes from different shapes as input to Attention:

    python: (1, 512, 3072)
    swift:  [256, 1, 3072]

aha:

        model(y[:prefill_step_size][None], cache=cache)

does not translate to:

            _ = model(y[..<parameters.prefillStepSize, .newAxis], cache: cache)

the trailing [None] in Python adds a leading axis, so the .newAxis must come first in Swift:

y[.newAxis, ..<parameters.prefillStepSize]
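
A quick way to see the difference (an illustrative snippet; only the shapes matter):

    import MLX

    let y = MLXArray(0 ..< 512)       // shape [512]
    let wrong = y[..<256, .newAxis]   // shape [256, 1]: the new axis trails
    let right = y[.newAxis, ..<256]   // shape [1, 256]: matches Python's y[:256][None]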

davidkoski avatar Aug 30 '24 02:08 davidkoski

It is amazing that this gross mismatch of shapes ... mostly works. It sure would be nice to have some typing on shapes. I suppose we could use precondition checks.
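
For example, a hypothetical check at the top of the attention call (the keys variable and the expected layout are illustrative):

    // Fail fast on an unexpected layout instead of letting
    // concatenate produce a confusing error deep in the KV cache.
    precondition(
        keys.ndim == 4 && keys.dim(0) == 1,
        "expected keys shaped [1, heads, seq, headDim], got \(keys.shape)"
    )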

davidkoski avatar Aug 30 '24 02:08 davidkoski

@DePasqualeOrg thank you so much for reporting this! Your info got a quick repro and I was able to track down the issue. You can try kvcache2 from #115

davidkoski avatar Aug 30 '24 02:08 davidkoski

Fantastic, thank you! I tested this with Phi 3.5 mini and Llama 3.1 8B, and it mostly seems to work, but on longer, multi-turn prompts I got garbled output from Phi 3.5 mini and special tokens like assistant<|end_header_id|> from Llama 3.1 8B. I guess this is due to the new KV cache?

I'm also curious how you would recommend estimating the required memory for a given prompt with this new approach.

DePasqualeOrg avatar Aug 30 '24 11:08 DePasqualeOrg

I tested this with Phi 3.5 mini and Llama 3.1 8B, and it mostly seems to work, but on longer, multi-turn prompts I got garbled output from Phi 3.5 mini and special tokens like assistant<|end_header_id|> from Llama 3.1 8B

The handling of the RoPE positional encodings is not quite right for both Llama 3.1 and Phi 3.5, so if your prompt + generation is very long (4k tokens or more), that might explain it. The new KV cache shouldn't change the results at all; if you are finding that it does, then that is a bug. We'll want to update to the latest MLX to fix this.

I'm also curious how you would recommend estimating the required memory for a given prompt with this new approach.

Since the attention steps are fixed at 512, the maximum size of the attention scores is now 512 * 512 * num_heads * 2 bytes, which is not that big. The memory bottleneck for long prompts will most likely be the KV cache itself, which scales as the product of the following factors (see the worked example after this list):

  • num layers
  • 2 (keys and values)
  • length of prompt + generation
  • num_kv_heads
  • head_dim
  • 2 bytes
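
As a worked example, a rough sketch assuming a float16 cache (the Llama 3.1 8B figures used here, 32 layers, 8 KV heads, head_dim 128, are illustrative):

    // layers * 2 (keys and values) * tokens * kv_heads * head_dim * 2 bytes
    func kvCacheBytes(layers: Int, tokens: Int, kvHeads: Int, headDim: Int) -> Int {
        layers * 2 * tokens * kvHeads * headDim * 2
    }

    // e.g. Llama 3.1 8B with a 4096-token prompt + generation:
    // 32 * 2 * 4096 * 8 * 128 * 2 = 536,870,912 bytes, about 0.5 GB
    let bytes = kvCacheBytes(layers: 32, tokens: 4096, kvHeads: 8, headDim: 128)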

awni avatar Aug 30 '24 13:08 awni

The new KV cache shouldn't change the results at all; if you are finding that it does, then that is a bug. We'll want to update to the latest MLX to fix this.

Whenever you're able to update to the latest MLX, I'll test this again and see if that solves the problem.

DePasqualeOrg avatar Sep 15 '24 18:09 DePasqualeOrg

I haven't encountered this lately, so I believe the issue is resolved.

DePasqualeOrg avatar Nov 15 '24 13:11 DePasqualeOrg