
feat(mlx_lm): support batch input in `generate()`

llllvvuu opened this issue.

The prompt argument can now be either a str or list[str].
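Accepting `list[str]` means prompts of different lengths must be padded to a common length before they can be stacked into one array; the discussion below assumes left padding. A minimal sketch of that step, assuming the prompts are already tokenized into lists of ints and assuming a hypothetical pad id of 0 (the helper name and pad id are illustrative, not taken from the PR):

```python
from typing import List, Tuple


def left_pad(batch: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[int]]:
    """Left-pad token sequences to the same length.

    Returns the padded batch and the number of pad tokens added to each row.
    Left padding keeps the last real token of every row in the final column,
    so next-token logits for all rows can be read from position -1.
    """
    max_len = max(len(seq) for seq in batch)
    padded: List[List[int]] = []
    pad_counts: List[int] = []
    for seq in batch:
        n_pad = max_len - len(seq)
        padded.append([pad_id] * n_pad + list(seq))
        pad_counts.append(n_pad)
    return padded, pad_counts
```

For example, `left_pad([[5, 6, 7], [8]])` returns `([[5, 6, 7], [0, 0, 8]], [0, 2])`. The per-row pad counts are what a padding-aware attention mask would need later.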

This is based on @willccbb's implementation at https://github.com/willccbb/mlx_parallm; I noticed that it aligned with the KVCache upgrades in #911.

The change to generate() is backwards-compatible.

The changes to generate_step(), top_p_sampling(), and min_p_sampling() are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred.
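Unifying shapes here means the samplers always take a `(batch, vocab)` array of logits and return one token per row, even when the batch size is 1. A pure-Python sketch of that contract for min-p sampling follows; it is an illustration only, not mlx_lm's actual `min_p_sampling` (whose exact signature is not reproduced here), and the function name is hypothetical:

```python
import math
import random
from typing import List, Optional


def min_p_sample_batch(
    logits: List[List[float]], min_p: float, rng: Optional[random.Random] = None
) -> List[int]:
    """Sample one token id per row with min-p filtering.

    `logits` has shape (batch, vocab). In each row, tokens whose probability
    is below min_p times that row's top probability are dropped, and the
    survivors are sampled in proportion to their probabilities.
    Assumes 0 <= min_p <= 1, so the top token always survives.
    """
    rng = rng or random.Random()
    out: List[int] = []
    for row in logits:
        m = max(row)
        # Numerically stable softmax over this row.
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        threshold = min_p * max(probs)
        kept = [p if p >= threshold else 0.0 for p in probs]
        out.append(rng.choices(range(len(row)), weights=kept, k=1)[0])
    return out
```

With a single row the function still returns a list of length 1, which is the kind of shape consistency that makes the unbatched and batched code paths identical.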

llllvvuu, Aug 21 '24

Kind of interesting: for quantized models, the throughput doesn't go up much between small batch sizes (bs=1, 2, 3, 4), but it starts to go up a lot at larger bs, which is the opposite of what I intuitively expected. For unquantized models, the throughput does go up between small bs. I observe the same in @willccbb's original repo.

llllvvuu, Aug 26 '24

I think that, to minimize the complexity of the generate function (which is becoming a bit spaghetti), it makes sense to split batched generation out into a separate function called batch_generate. I would simplify that function to have fewer arguments (e.g. no formatter and no printing during generation; verbose only prints the timings, as you have it now).
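A minimal sketch of what a stripped-down batch_generate loop could look like, with no formatter and no printing. The signature, the `next_tokens` callback, and the eos handling are assumptions for illustration, not the PR's code or mlx_lm's API; the callback stands in for a model forward pass plus sampling.

```python
from typing import Callable, List

# Hypothetical callback: takes the current batch of token lists and returns
# one new token per row (a stand-in for model forward + sampling).
NextTokens = Callable[[List[List[int]]], List[int]]


def batch_generate(
    prompts: List[List[int]],
    next_tokens: NextTokens,
    max_tokens: int,
    eos_id: int,
) -> List[List[int]]:
    """Batched decoding loop without printing or formatting.

    Every row gets one new token per step. Rows that have produced eos_id
    stop collecting output but stay in the batch so shapes remain aligned.
    """
    seqs = [list(p) for p in prompts]
    outputs: List[List[int]] = [[] for _ in prompts]
    done = [False] * len(prompts)
    for _ in range(max_tokens):
        if all(done):
            break
        new = next_tokens(seqs)
        for b, tok in enumerate(new):
            seqs[b].append(tok)
            if done[b]:
                continue
            if tok == eos_id:
                done[b] = True
            else:
                outputs[b].append(tok)
    return outputs
```

For example, with a toy callback that returns each row's last token plus one, `batch_generate([[1], [5]], lambda seqs: [s[-1] + 1 for s in seqs], max_tokens=5, eos_id=4)` returns `[[2, 3], [6, 7, 8, 9, 10]]`: the first row stops collecting output after emitting 4, while the second runs to `max_tokens`. Keeping finished rows in the batch wastes some compute but avoids reshaping the cache mid-generation.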

Also, maybe trickier: I think that for this to be correct, the causal masks need to account for the left padding in the input (please correct me if I'm wrong about that). This has two implications:

  1. Probably we'd need to add a mask parameter to the model __call__ functions and provide an appropriately constructed mask for the batch case.
  2. The Rotating KV cache will be broken in this case (it keeps the initial tokens, which would be the padded tokens), and when it rotates, the mask would need to be updated to account for the padding (which is a bit complicated/tedious). In this case I'd suggest disabling this option entirely.
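Point 1 can be made concrete with a small sketch. Assuming left padding where row `b` has `pad_counts[b]` pad tokens at the front, the mask combines the usual causal constraint with a "key is not padding" constraint. The helper name and the nested-list boolean layout are illustrative assumptions, not the mask format mlx_lm's models actually accept.

```python
from typing import List


def batch_causal_mask(pad_counts: List[int], seq_len: int) -> List[List[List[bool]]]:
    """Boolean attention mask of shape (batch, seq_len, seq_len).

    mask[b][i][j] is True when query i in row b may attend to key j:
    the key must not be in the future (j <= i) and must not be a
    left-padding slot (j >= pad_counts[b]).
    """
    masks: List[List[List[bool]]] = []
    for n_pad in pad_counts:
        rows: List[List[bool]] = []
        for i in range(seq_len):
            if i < n_pad:
                # Padding queries attend only to themselves, so no row is
                # fully masked (a fully masked row would make an additive
                # -inf softmax produce NaN). Their outputs are never used.
                rows.append([j == i for j in range(seq_len)])
            else:
                rows.append([n_pad <= j <= i for j in range(seq_len)])
        masks.append(rows)
    return masks
```

For row 0 with no padding this reduces to the ordinary lower-triangular causal mask. In the additive form used inside attention, True entries become 0 and False entries become a large negative value.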

Let me know what you think about the above.

awni, Aug 29 '24

> I think that, to minimize the complexity of the generate function (which is becoming a bit spaghetti), it makes sense to split batched generation out into a separate function called batch_generate. I would simplify that function to have fewer arguments (e.g. no formatter and no printing during generation; verbose only prints the timings, as you have it now).

Makes sense to me, will implement.

> Also, maybe trickier: I think that for this to be correct, the causal masks need to account for the left padding in the input (please correct me if I'm wrong about that). This has two implications:

> 1. Probably we'd need to add a mask parameter to the model __call__ functions and provide an appropriately constructed mask for the batch case.

Yes, this sounds straightforward enough.

> 2. The Rotating KV cache will be broken in this case (it keeps the initial tokens, which would be the padded tokens), and when it rotates, the mask would need to be updated to account for the padding (which is a bit complicated/tedious). In this case I'd suggest disabling this option entirely.

I'll think about whether there's an easy way to handle this; otherwise I'll remove that parameter from batch_generate.

Will update when these changes are ready!

llllvvuu, Aug 29 '24

@llllvvuu are you coming back to this?

awni, Sep 27 '24

Hey @awni, sorry for the delay; I'd been job hunting this month. I should be able to get back to this in about a week.

llllvvuu, Sep 28 '24

No worries, just checking. I'll follow up in a week or so.

awni, Sep 28 '24

Just realised the attention mask has already been mentioned in this PR; it's the reason I raised issue #1044.

nath1295, Oct 15 '24

Sorry for the delay, @awni. I took advantage of https://github.com/ml-explore/mlx-examples/pull/1173 to update this PR. It is pending a versioned release of https://github.com/ml-explore/mlx/pull/1726 for the mask dtype.

I noticed one other potential issue: for absolute/rotary positional encodings, the position IDs of padded prompts won't start from 0 (this becomes trickier if a padded prompt cache is added, as the position IDs would then become non-contiguous, IIUC). I'm not sure what the priority of this is or whether it requires any change to mx.fast.rope.
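One way to give each left-padded row positions that start from 0 is to subtract that row's pad count from the column index. The helper below is hypothetical; it only computes the ids and assumes the positional encoding can consume per-row offsets, which is exactly the open question about mx.fast.rope above.

```python
from typing import List


def position_ids(pad_counts: List[int], seq_len: int) -> List[List[int]]:
    """Per-row position ids that start at 0 on the first real token.

    Padding slots are clamped to 0; their values don't matter as long as
    the attention mask prevents real tokens from attending to them.
    """
    return [
        [max(0, t - n_pad) for t in range(seq_len)]
        for n_pad in pad_counts
    ]
```

For example, `position_ids([0, 2], 4)` returns `[[0, 1, 2, 3], [0, 0, 0, 1]]`. For standard rotary encodings, the attention score between a query and a key depends only on the difference of their positions, so a constant per-row shift leaves scores among real tokens unchanged in exact arithmetic; absolute encodings do not have this property.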

llllvvuu, Dec 27 '24

Any update? We need this to support parallel generation.

qinxuye, Jan 13 '25

Will get to this soon. Sorry for the delay.

awni, Jan 13 '25

Sorry to follow up again, but we really need this capability. May I ask about the progress?

qinxuye, Mar 10 '25

@qinxuye could you say a bit more about what you are looking for?

awni, Mar 11 '25

> @qinxuye could you say a bit more about what you are looking for?

We want to support continuous batching for MLX in Xinference, but if MLX can't do batch inference, we don't know how to move forward.

qinxuye, Mar 13 '25

The code to implement batch inference is pretty simple; you can just copy the code from mlx_parallm (or, probably better, https://github.com/N8python/gsm-mlx) and add it to your codebase. The hard part is orchestrating continuous batching, not the batched KV caching itself. This is largely why I stopped development on mlx_parallm: continuous batching would have been the natural next step, but I didn't have an immediate need, and it seemed like it would take a lot of time to get it working properly.

willccbb, Mar 14 '25

> The code to implement batch inference is pretty simple; you can just copy the code from mlx_parallm (or, probably better, https://github.com/N8python/gsm-mlx) and add it to your codebase. The hard part is orchestrating continuous batching, not the batched KV caching itself. This is largely why I stopped development on mlx_parallm: continuous batching would have been the natural next step, but I didn't have an immediate need, and it seemed like it would take a lot of time to get it working properly.

Thanks for your advice, that makes sense. We could implement the batch logic first. As for continuous batching, we already have some code for it in our project, and I will work on integrating the two.

qinxuye, Mar 16 '25

@qinxuye it looks like there is a new repo for mlx-lm; I'll take a look sometime and see if I can open something over there. I also never resolved the position IDs issue (although it didn't seem to cause any problems in my tests; maybe that's down to the LLMs I used).

llllvvuu, Mar 23 '25