accelerated-scan
Integrate with Mamba
@proger
Awesome work! Always appreciate the wonderful contributions of OSS advancing the frontiers of research.
I know you've done a number of experiments comparing various scan implementations in your other repo nanokitchen -- would it make sense to integrate accelerated-scan as an alternative backend to Mamba? Would be happy to work on this if you think it makes sense.
@jeromeku thank you for the kind words! Glad you checked out nanokitchen as well.
It would indeed be possible to use Accelerated Scan for Mamba as is, but it would work best for experimental purposes: the Mamba kernel is already designed to achieve the best performance for that architecture.
Concretely, the Mamba kernel fuses cub::BlockScan with the SSM state expansion operations: the A matrix expands every gate dimension (a gate is called delta in Mamba, and those deltas are expected to be stored in log space) into a 16-dimensional SSM, B correspondingly expands every input token dimension to match the gate expansion, and C collapses every SSM back. Accelerated Scan would have to accept the expanded SSMs as inputs and waste precious memory bandwidth.
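For intuition, here is a plain NumPy sketch of the recurrence described above. The function name, shapes, and the exact discretization are illustrative assumptions based on this description, not Mamba's actual kernel code:

```python
import numpy as np

# Illustrative shapes (assumptions, not Mamba's real layout):
#   x:         (L, D) input tokens
#   log_delta: (L, D) gates, stored in log space
#   A:         (D, N) expands each gate dimension into an N-dim SSM (N = 16 in Mamba)
#   B:         (L, N) expands each input token dimension
#   C:         (L, N) collapses each SSM state back to a scalar
def selective_scan_reference(x, log_delta, A, B, C):
    L, D = x.shape
    N = A.shape[1]
    delta = np.exp(log_delta)                 # deltas arrive in log space
    h = np.zeros((D, N))                      # the expanded SSM state
    y = np.empty((L, D))
    for t in range(L):
        a_t = np.exp(delta[t][:, None] * A)   # per-step decay, shape (D, N)
        b_t = delta[t][:, None] * B[t]        # per-step input map, shape (D, N)
        h = a_t * h + b_t * x[t][:, None]     # first-order linear recurrence
        y[t] = (h * C[t]).sum(axis=1)         # collapse the state dimension
    return y
```

Feeding a generic scan library the materialized `a_t` and `b_t * x[t]` tensors of shape (L, D, N) is exactly the memory-bandwidth waste mentioned above, which the fused kernel avoids.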
Mamba's review gives a hint that the memory footprint of that kernel could be improved; a good direction would be to understand why that is the case. Reference: https://openreview.net/forum?id=AL1fq05o7H&noteId=T6WJZb30sz
I found that @srush has done this exact fusion of the SSM bits into the Triton forward kernel here: https://github.com/srush/annotated-mamba/issues/1#issuecomment-1885866368
Yeah thanks! Your repo was super helpful for that, we couldn't figure out how to do the two value scan.
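For readers following along, the "two value scan" for a first-order linear recurrence h_t = a_t * h_{t-1} + b_t uses the standard associative operator on (a, b) pairs. This is a generic sketch of that operator, not either repo's actual kernel code:

```python
# Composing two affine maps h -> a1*h + b1, then h -> a2*h + b2, gives
# h -> (a1*a2)*h + (a2*b1 + b2). This composition is associative, which is
# what lets a parallel scan compute h_t = a_t*h_{t-1} + b_t over pairs.
def combine(left, right):
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

def sequential_scan(a, b):
    # Plain left-to-right evaluation (h_0 = 0), for checking the operator.
    h = 0.0
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out
```

Folding `combine` over the (a_t, b_t) pairs in any grouping yields the same final (a, b), whose b component equals the last element of `sequential_scan`; that associativity is what the parallel scan exploits.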
Unfortunately I'm now stuck on the backward pass, which needs to run the scan right-to-left. I see that you do it by loading values in reverse order, but in our case we would need to reverse the tensor in local memory (or repeat a lot of computation).
Any ideas? I think I might try making an LxL matrix and doing a dot? It seems like overkill, but I'm stuck for other methods.
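For reference, the identity behind doing the backward scan with a flip rather than a dedicated reverse scan can be shown in NumPy. This sketch assumes the backward scan has the same first-order form, g_t = a_t * g_{t+1} + b_t with g_L = 0, which is an assumption about the structure rather than the actual kernel:

```python
import numpy as np

def scan_ltr(a, b):
    # Ordinary left-to-right scan: h_t = a_t * h_{t-1} + b_t, h_{-1} = 0.
    h = np.zeros_like(b[0])
    out = np.empty_like(b)
    for t in range(len(b)):
        h = a[t] * h + b[t]
        out[t] = h
    return out

def scan_rtl(a, b):
    # Flip-scan-flip: a right-to-left recurrence g_t = a_t * g_{t+1} + b_t
    # equals a left-to-right scan over the reversed inputs, reversed again.
    return scan_ltr(a[::-1], b[::-1])[::-1]
```

The flip still has to happen somewhere (in memory, in registers, or inside the scan primitive itself), which is why a reverse option in the scan is attractive when the intermediate tensor is too large to materialize.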
@proger @srush will take a closer look and report back...
There's some discussion about making a reverse tl.associative_scan in https://github.com/openai/triton/issues/2930
Yes, that issue is from me as well.
The trick of reading memory in reverse in your codebase is nice. The problem is that for Mamba you need to run the backward scan on an intermediate tensor that is too large to store, so you need to either reverse it in memory or have a reverse associative scan.
@proger
Any luck integrating a reverse option to the Triton backend?
Trying to get up to speed with MLIR :)
I sent them a PR for a flip function at the Triton level, which should be okay: https://github.com/openai/triton/pull/2954. Although it would be interesting to do something more low-level.
@srush
Thanks -- saw that PR and agree that a more low-level approach would be a worthwhile exercise. It always helps to understand how things work under the hood. MLIR is a bit of a beast.
FYI, this series of tutorials is a great intro to MLIR. Also, NVIDIA's CUTLASS library has abstractions (i.e., the GEMM hierarchy) similar to Triton's, though Triton is clearly more extensible to a wider variety of problems and backends.
In most cases it might be something similar to convert_layout from one distributed layout to another. Feel free to take a look at the relevant code.
I do think the current python solution is more elegant though.