[Feature] Batch Decoding
Hi! @tridao
Thanks to @agiwave for the base code that got me started (a CPU implementation that isn't fully working).
Below is a GPU implementation that works.
There are 3 scripts:
- mamba_simple.py - contains all the logic for decoding N tokens at once.
- test_mamba_ssm_state.py - shows that batch decoding works.
- time_measure.py - measures the timing of batch decoding with N=1 versus N>1 (the speedup is at least 10x).
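As a rough illustration of the property a test like test_mamba_ssm_state.py checks, here is a toy sketch in pure Python: decoding N tokens in one call from a cached SSM state should give the same outputs and final state as stepping one token at a time. It assumes a single channel of a simplified diagonal SSM with an input-dependent step size; `A`, `B`, `C`, `D`, `delta_of`, `step` and `decode_chunk` are all hypothetical names, not code from mamba_simple.py or the mamba_ssm API. In real Mamba, `B` and `C` are also input-dependent.

```python
import math

# Hypothetical toy parameters for ONE channel of a diagonal SSM
# (a simplified stand-in for a Mamba channel, not the mamba_ssm API).
A = [-1.0, -0.5, -0.25]  # diagonal state matrix (negative for stability)
B = [1.0, 0.5, 0.25]     # input projection (fixed here; input-dependent in Mamba)
C = [0.3, -0.2, 0.1]     # output projection (fixed here; input-dependent in Mamba)
D = 0.5                  # skip connection


def softplus(v):
    return math.log1p(math.exp(v))


def delta_of(x):
    # Input-dependent step size, loosely mimicking Mamba's selective delta.
    return softplus(0.7 * x - 0.1)


def step(h, x):
    """Decode one token: update cached state h, return (new_h, y)."""
    dt = delta_of(x)
    h = [math.exp(dt * a) * hi + dt * b * x for a, b, hi in zip(A, B, h)]
    y = sum(c * hi for c, hi in zip(C, h)) + D * x
    return h, y


def decode_chunk(h0, xs):
    """Decode N tokens in one call from cached state h0.

    Uses the unrolled recurrence
      h_t = exp(A*S_t)*h0 + sum_{s<=t} exp(A*(S_t - S_s)) * dt_s * B * x_s,
    where S_t is the running sum of step sizes dt_1..dt_t.
    """
    dts = [delta_of(x) for x in xs]
    cum, total = [], 0.0
    for dt in dts:
        total += dt
        cum.append(total)

    h, ys = list(h0), []
    for t in range(len(xs)):
        h = []
        for a, b, h0i in zip(A, B, h0):
            acc = math.exp(a * cum[t]) * h0i
            for s in range(t + 1):
                acc += math.exp(a * (cum[t] - cum[s])) * dts[s] * b * xs[s]
            h.append(acc)
        ys.append(sum(c * hi for c, hi in zip(C, h)) + D * xs[t])
    return h, ys


# Usage: compare token-by-token decoding with one chunked call.
h0 = [0.1, -0.2, 0.05]          # pretend this state came from the prompt
tokens = [0.4, -1.2, 0.8, 0.3]  # toy scalar inputs for the next N tokens

h_seq, ys_seq = h0, []
for x in tokens:
    h_seq, y = step(h_seq, x)
    ys_seq.append(y)

h_chunk, ys_chunk = decode_chunk(h0, tokens)
print(max(abs(p - q) for p, q in zip(ys_seq, ys_chunk)))
```

This sketch only checks equivalence of the two decoding paths; it says nothing about speed, and the O(N^2) inner loop is there for readability. On a GPU, a chunked formulation is what allows a parallel scan kernel to replace N sequential single-token steps.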