batched inference

Open: varshith15 opened this issue 1 year ago • 2 comments

Original PR: https://github.com/exo-explore/exo/pull/108

~~It was built on top of older code and needs a lot of refactoring.~~ Refactored now. Currently on pause: I need to think more broadly about distributed training on exo.

varshith15, Sep 10 '24 18:09

Does this work as expected? If you look at mlx_parallm, it uses a `BatchedKVCache` implementation to handle the KV cache for batches: https://github.com/willccbb/mlx_parallm/blob/80b18ab49b80e6f8d82d89347ab32f44b35f8942/mlx_parallm/utils.py#L201

I'm not sure how that would work with the current implementation. It looks like a single cache is shared by all the requests, which is probably not what we want here.
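A batched KV cache does not have to mean one cache object per request: a single object can hold every request's history as separate rows along a leading batch axis. The sketch below is a minimal illustration under stated assumptions: it assumes the common `(batch, n_kv_heads, seq_len, head_dim)` layout, uses numpy in place of mlx.core so it runs anywhere, and the class `FixedBatchKVCache`, its `step` chunk size, and the `update_and_fetch` method are hypothetical names, not code from exo or mlx_parallm. Here the batch size is passed in explicitly and storage is preallocated in chunks.

```python
import numpy as np


class FixedBatchKVCache:
    """Hypothetical batched KV cache with an explicit batch_size.

    Storage has shape (batch_size, n_kv_heads, capacity, head_dim) and grows
    in chunks of `step` time steps. Row i along axis 0 belongs to request i,
    so a single cache object never mixes the history of different requests.
    """

    def __init__(self, batch_size, n_kv_heads, head_dim, step=256):
        self.batch_size = batch_size
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.step = step
        self.offset = 0  # number of time steps written so far
        self.keys = None
        self.values = None

    def _grow(self, needed, dtype):
        # Round capacity up to a multiple of `step`, then copy the filled prefix.
        capacity = -(-needed // self.step) * self.step
        shape = (self.batch_size, self.n_kv_heads, capacity, self.head_dim)
        new_k = np.zeros(shape, dtype=dtype)
        new_v = np.zeros(shape, dtype=dtype)
        if self.keys is not None:
            new_k[:, :, : self.offset] = self.keys[:, :, : self.offset]
            new_v[:, :, : self.offset] = self.values[:, :, : self.offset]
        self.keys, self.values = new_k, new_v

    def update_and_fetch(self, keys, values):
        # keys/values: (batch_size, n_kv_heads, new_steps, head_dim)
        if keys.shape[0] != self.batch_size:
            raise ValueError(f"expected batch of {self.batch_size}, got {keys.shape[0]}")
        end = self.offset + keys.shape[2]
        if self.keys is None or end > self.keys.shape[2]:
            self._grow(end, keys.dtype)
        self.keys[:, :, self.offset : end] = keys
        self.values[:, :, self.offset : end] = values
        self.offset = end
        # Only the filled prefix is returned; the zero padding stays hidden.
        return self.keys[:, :, :end], self.values[:, :, :end]


# Two requests with a 3-token prompt each, then two decode steps.
cache = FixedBatchKVCache(batch_size=2, n_kv_heads=1, head_dim=4, step=4)
prompt = np.stack([np.full((1, 3, 4), 1.0), np.full((1, 3, 4), 2.0)])
cache.update_and_fetch(prompt, prompt)
for a, b in [(10.0, 20.0), (11.0, 21.0)]:
    tok = np.stack([np.full((1, 1, 4), a), np.full((1, 1, 4), b)])
    k, v = cache.update_and_fetch(tok, tok)
# Row 0 holds 1, 1, 1, 10, 11 and row 1 holds 2, 2, 2, 20, 21.
print(k[0, 0, :, 0], k[1, 0, :, 0])
```

The second decode step overflows the initial 4-step capacity, so the cache grows to 8 steps while each row keeps only its own request's keys and values.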

AlexCheema, Sep 24 '24 18:09

@AlexCheema The `BatchedKVCache` at https://github.com/willccbb/mlx_parallm/blob/80b18ab49b80e6f8d82d89347ab32f44b35f8942/mlx_parallm/utils.py#L201 is exactly the same as what I did. The only difference is that it stores `batch_size` as a variable, whereas I infer it from the input shape.
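The distinction between storing `batch_size` and inferring it can be sketched as a cache that never receives a batch size: it reads it from `keys.shape[0]` on the first update and rejects later updates whose batch differs. This is an illustrative sketch, not the code from the PR; numpy stands in for mlx.core, the `(batch, n_kv_heads, seq_len, head_dim)` layout is assumed, and `InferredBatchKVCache` and `update_and_fetch` are hypothetical names.

```python
import numpy as np


class InferredBatchKVCache:
    """Hypothetical batched KV cache that never takes a batch_size argument.

    Keys/values have shape (batch, n_kv_heads, seq_len, head_dim); the batch
    size is read from the first update and enforced on later ones.
    """

    def __init__(self):
        self.keys = None
        self.values = None

    @property
    def batch_size(self):
        # Inferred from the stored arrays instead of being stored separately.
        return None if self.keys is None else self.keys.shape[0]

    @property
    def offset(self):
        # Number of cached time steps, shared by every row in the batch.
        return 0 if self.keys is None else self.keys.shape[2]

    def update_and_fetch(self, keys, values):
        if self.keys is None:
            self.keys, self.values = keys, values
        else:
            if keys.shape[0] != self.batch_size:
                raise ValueError(
                    f"batch size changed from {self.batch_size} to {keys.shape[0]}"
                )
            # Append along the time axis; row i only ever extends row i.
            self.keys = np.concatenate([self.keys, keys], axis=2)
            self.values = np.concatenate([self.values, values], axis=2)
        return self.keys, self.values


cache = InferredBatchKVCache()
prompt = np.arange(2 * 1 * 3 * 4, dtype=np.float32).reshape(2, 1, 3, 4)
cache.update_and_fetch(prompt, prompt)
step = np.ones((2, 1, 1, 4), dtype=np.float32)
k, v = cache.update_and_fetch(step, step)
print(cache.batch_size, k.shape)  # 2 (2, 1, 4, 4)
```

Both styles hold the same per-request data; inferring the batch size just removes a constructor argument that could drift out of sync with the actual input. Concatenating on every step copies the whole history each time, so a production cache would more likely preallocate in chunks.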

varshith15, Sep 25 '24 05:09