Minimal reference implementation

Open EelcoHoogendoorn opened this issue 1 year ago • 5 comments
Thanks so much for providing this code; looks very useful and reproducible.

As I understand it, the custom scan kernel can be quite important for performance, so it is great to see it here as well.

However, as a suggestion, I think it'd be super neat to also see a minimal Mamba reference implementation with minimal dependencies, simply for clarity of exposition: something that could be unit tested to behave the same as the custom kernel, at least on small datasets. Would that be a lot of work? Does it already exist somewhere? If a torch version exists, I'd be happy to port it to JAX as well.

EelcoHoogendoorn avatar Dec 05 '23 20:12 EelcoHoogendoorn

There's a reference implementation of the selective scan in PyTorch here. That's the main primitive that requires CUDA.

tridao avatar Dec 05 '23 20:12 tridao

Right; I did see there is a reference implementation; I just wondered how close we should consider it to being 'minimal'.

How close could Mamba get to this kind of minimalism?

EelcoHoogendoorn avatar Dec 05 '23 20:12 EelcoHoogendoorn

The core is actually just a for loop. The code will simplify a lot if you only take the path where B/C are input-dependent.

tridao avatar Dec 06 '23 08:12 tridao
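
For readers following along, here is a rough sketch of what "just a for loop" could look like when only the input-dependent B/C path is taken. It is not taken from the Mamba repository: the function name `selective_scan_loop`, the argument layout, and the use of plain nested lists are assumptions made for illustration, and the discretization (exp(delta * A) for A, delta * B for B) follows my understanding of the reference scan rather than a confirmed reading of it.

```python
import math

def selective_scan_loop(x, delta, A, B, C, D):
    """Sequential selective scan for a single sequence (illustrative only).

    Shapes, as nested lists (hypothetical layout):
      x, delta: (seq_len, d_inner)   inputs and per-step step sizes
      A:        (d_inner, d_state)   state matrix (diagonal per channel)
      B, C:     (seq_len, d_state)   input-dependent projections
      D:        (d_inner,)           skip connection
    Returns y with shape (seq_len, d_inner).
    """
    seq_len, d_inner, d_state = len(x), len(A), len(A[0])
    h = [[0.0] * d_state for _ in range(d_inner)]  # hidden state starts at zero
    y = []
    for t in range(seq_len):  # the for loop over time steps
        y_t = []
        for d in range(d_inner):
            for n in range(d_state):
                a_bar = math.exp(delta[t][d] * A[d][n])  # discretized A
                b_bar = delta[t][d] * B[t][n]            # simplified discretized B
                h[d][n] = a_bar * h[d][n] + b_bar * x[t][d]
            # Read out the state with C, then add the skip term D * x
            out = sum(C[t][n] * h[d][n] for n in range(d_state)) + D[d] * x[t][d]
            y_t.append(out)
        y.append(y_t)
    return y
```

A loop like this is far too slow for real models, but it is the kind of thing one could compare against the custom kernel on tiny inputs.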

Hi all, I wrote a minimal implementation here: https://github.com/johnma2006/mamba-minimal/tree/master. Hope it helps!

johnma2006 avatar Dec 20 '23 11:12 johnma2006

Thanks, that looks really clean, and should be trivial to port to JAX. From what I understand, using JAX scan also isn't competitive for LLM-scale models, but my intuition is it'd be fine for some of the smaller stuff I'd want to try it on.

EelcoHoogendoorn avatar Dec 21 '23 06:12 EelcoHoogendoorn