mamba
Minimal reference implementation
Thanks so much for providing this code; looks very useful and reproducible.
As I understand it, the custom scan kernel is quite important for performance, so it is great to see it here as well.
However, as a suggestion, I think it'd be super neat to also see a minimal Mamba reference implementation with minimal dependencies, simply for clarity of exposition; something that could be unit tested to behave the same as the custom kernel, at least on small inputs. Would that be a lot of work? Does it already exist somewhere? If a torch version exists, I'd be happy to port it to a JAX version as well.
There's a reference implementation of the selective scan in PyTorch here. That's the main primitive that requires CUDA.
Right; I did see that there is a reference implementation. I just wondered how close we should consider it to being 'minimal'. How close could Mamba get to that kind of minimalism?
The core is actually just a for loop; the code simplifies a lot if you only take the path where B/C are input-dependent.
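To illustrate what "just a for loop" means here, below is a minimal sketch of the selective-scan recurrence for a single channel, written in plain NumPy. The function name, shapes, and the simplified (Euler) discretization of B are my own illustrative choices, not the repo's API; the real kernel fuses this loop across channels on the GPU.

```python
import numpy as np

def selective_scan(x, dt, A, B, C):
    """Selective-scan recurrence for one channel (illustrative sketch).

    x:  (L,)   input sequence
    dt: (L,)   per-step discretization step (input-dependent)
    A:  (N,)   state transition parameters (diagonal)
    B:  (L, N) input-dependent input matrix
    C:  (L, N) input-dependent output matrix
    """
    L, N = B.shape
    h = np.zeros(N)               # hidden state
    y = np.zeros(L)
    for t in range(L):
        dA = np.exp(dt[t] * A)    # ZOH discretization of A
        dB = dt[t] * B[t]         # simplified discretization of B
        h = dA * h + dB * x[t]    # state update
        y[t] = C[t] @ h           # readout
    return y
```

The whole recurrence is the two lines inside the loop: a state update and a readout, with the per-step dA/dB computed from the input-dependent dt, B, and C.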
Hi all, I wrote a minimal implementation here: https://github.com/johnma2006/mamba-minimal/tree/master. Hope it helps!
Thanks, that looks really clean, and should be trivial to port to JAX. From what I understand, using JAX's scan also isn't competitive for LLM-scale models, but my intuition is it'd be fine for some of the smaller stuff I'd want to try it on.
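For anyone attempting the port, the recurrence maps directly onto a scan primitive: the hidden state is the carry and the per-step output is the scanned value. Here is a rough sketch using a pure-Python stand-in for `jax.lax.scan` (same `f(carry, x) -> (carry, y)` contract); the names and the toy parameters are illustrative assumptions, not anyone's actual code.

```python
import numpy as np

def scan(f, init, xs):
    """Pure-Python stand-in for jax.lax.scan: f(carry, x) -> (carry, y)."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.array(ys)

def ssm_step(h, inputs):
    dA, dBx, Ct = inputs        # per-timestep discretized parameters
    h = dA * h + dBx            # state update (the carry)
    return h, Ct @ h            # (new carry, per-step output)

# Tiny example: state size N=2, sequence length L=3,
# with dA = 1 (no decay) so the result is easy to check by hand.
N, L = 2, 3
dA  = [np.ones(N)] * L
dBx = [np.ones(N)] * L          # discretized B already multiplied by x
Ct  = [np.ones(N)] * L
_, y = scan(ssm_step, np.zeros(N), list(zip(dA, dBx, Ct)))
```

Swapping the helper for `jax.lax.scan` (with the inputs stacked into arrays rather than a Python list) should give the same result under `jit`, which is the usual way to express this recurrence in JAX.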