[BOUNTY - $500] JAX Inference Engine
Why JAX? Read this: https://neel04.github.io/my-website/blog/pytorch_rant/
The deliverable is a JAXInferenceEngine that can run Llama.
Hi Alex, I'd like to take up this bounty. Can you elaborate more on what needs to be done? Just the title and the commit aren't that helpful :-)
Assigned!
We want a JAX InferenceEngine implementation.
A PyTorch InferenceEngine was just implemented here if you want a reference: https://github.com/exo-explore/exo/pull/139
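To give a rough idea of the shape of the deliverable, here is a minimal sketch of what a JAX-based engine could look like. The class and method names here are assumptions for illustration, not exo's actual InferenceEngine interface (check the PyTorch PR above for the real one), and the toy model stands in for a Llama forward pass:

```python
# Hypothetical sketch of a JAX inference engine skeleton. The class/method
# names are illustrative assumptions, not exo's actual interface.
import jax
import jax.numpy as jnp

class JAXInferenceEngine:
    """Minimal greedy-decoding skeleton; a real engine would load Llama
    weights and run the actual transformer forward pass."""

    def __init__(self, model_fn, params):
        self.params = params
        # jit-compile the forward pass once; later calls reuse the trace
        self._forward = jax.jit(model_fn)

    def infer_tensor(self, tokens: jnp.ndarray) -> int:
        # Run the model and greedily pick the next token from the
        # logits of the last position.
        logits = self._forward(self.params, tokens)
        return int(jnp.argmax(logits[-1]))

# Toy "model": a fixed projection standing in for a Llama forward pass.
def toy_model(params, tokens):
    onehot = jax.nn.one_hot(tokens, params["vocab"].shape[0])
    return onehot @ params["vocab"]

params = {"vocab": jnp.eye(8)}
engine = JAXInferenceEngine(toy_model, params)
next_tok = engine.infer_tensor(jnp.array([1, 2, 3]))
print(next_tok)  # greedy argmax of the last row of logits
```

The real implementation would of course also need tokenization, KV caching, and sampling beyond greedy argmax, but the jit-compiled forward pass plus a decode loop is the core of it.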
Hi @AlexCheema,
I would like to take up this bounty.
As I was looking into the existing implementations, I found that tinygrad only supports llama while mlx supports multiple models.
Do we have any priority on which models we want to be supported using JAX?
Let's start with Llama only. Then we can add more in separate bounties. Assigned - good luck!
Hey @AlexCheema, I wanted to check if there is any update on this. @brightprogrammer, I would love to help you out on this bounty; can you update me on the progress of the issue?