batched inference
Original PR: https://github.com/exo-explore/exo/pull/108
~~It was built on top of older code, needs a lot of refactoring~~ Refactored now. Currently on pause: we need to think more broadly about distributed training on exo.
Does this work as expected? If you look at mlx_parallm, it uses a BatchedKVCache implementation to handle the KV cache for batches: https://github.com/willccbb/mlx_parallm/blob/80b18ab49b80e6f8d82d89347ab32f44b35f8942/mlx_parallm/utils.py#L201
I'm not sure how it would work with the current implementation. It looks like one cache is used for all the requests, which is probably not what we want here.
@AlexCheema https://github.com/willccbb/mlx_parallm/blob/80b18ab49b80e6f8d82d89347ab32f44b35f8942/mlx_parallm/utils.py#L201 is exactly the same as what I did; the only difference is that it stores batch_size as a variable, whereas I infer it from the input shape.
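To make the point of the discussion concrete, here is a minimal sketch of a batched KV cache in the style of mlx_parallm's `BatchedKVCache`, but inferring the batch size from the input shape on each update rather than storing it as a constructor argument. This is a hypothetical illustration using numpy arrays as a stand-in for `mx.array`; the class name, method name, and `step` growth parameter mirror the mlx_parallm code but the body is an assumption, not the actual implementation from either PR.

```python
import numpy as np


class BatchedKVCache:
    """Sketch of a KV cache that holds one slot per request in a batch.

    Keys/values are stored as (batch, n_kv_heads, seq_len, head_dim).
    The batch size is inferred from the first dimension of each update,
    instead of being passed in at construction time.
    """

    def __init__(self, head_dim, n_kv_heads, step=256):
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads
        self.step = step          # grow the sequence axis in chunks of `step`
        self.keys = None
        self.values = None
        self.offset = 0           # number of positions cached so far

    def update_and_fetch(self, keys, values):
        # keys, values: (batch, n_kv_heads, seq_len, head_dim)
        batch = keys.shape[0]     # batch size inferred from the input shape
        seq_len = keys.shape[2]
        prev = self.offset
        needed = prev + seq_len
        if self.keys is None or needed > self.keys.shape[2]:
            # Allocate a larger buffer (rounded up to a multiple of `step`)
            # and copy over whatever was already cached.
            n_steps = (needed + self.step - 1) // self.step
            shape = (batch, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = np.zeros(shape, dtype=keys.dtype)
            new_v = np.zeros(shape, dtype=values.dtype)
            if self.keys is not None:
                new_k[..., :prev, :] = self.keys[..., :prev, :]
                new_v[..., :prev, :] = self.values[..., :prev, :]
            self.keys, self.values = new_k, new_v
        # Write the new positions and return views over the valid prefix.
        self.keys[..., prev:needed, :] = keys
        self.values[..., prev:needed, :] = values
        self.offset = needed
        return self.keys[..., :needed, :], self.values[..., :needed, :]
```

The upshot for the discussion above: each entry in the batch gets its own row along the first axis, so requests do not share cache state, while the caller never has to tell the cache how many requests are in flight.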