
[Feature] Batch Decoding

Status: Open · opened by AnaRhisT94 7 months ago · 0 comments

Hi! @tridao

Thanks to @agiwave for the base code that got me started (a CPU implementation that isn't fully working).

Below is a working GPU implementation.

There are three scripts:

  1. mamba_simple.py - contains all the logic for decoding N tokens.
  2. test_mamba_ssm_state.py - shows that batch decoding works.
  3. time_measure.py - measures the timing of batch decoding with N=1 versus N>1 (the speedup is at least 10x).
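The scripts themselves are not reproduced here. As a rough illustration of what "decoding N tokens" from a cached SSM state involves, below is a minimal pure-Python sketch for a single channel of a diagonal selective SSM using the simplified discretization h_t = exp(Δ_t·A)·h_{t-1} + Δ_t·B_t·x_t, y_t = C_t·h_t + D·x_t. The names `step` and `decode_n` are hypothetical and are not the API of mamba_simple.py or of `mamba_ssm`. `decode_n` evaluates all N outputs in closed form from prefix sums of Δ·A, so each output depends only on the cached state and the new tokens rather than on the previous call's result. Comparing it against N calls of `step` is one way to check that batch decoding reproduces token-by-token decoding.

```python
import math
import random


def step(h, x, delta, A, B, C, D):
    """One recurrent decoding step (N = 1) for a single channel (hypothetical helper)."""
    # Discretize per state dimension: decay exp(delta * a), input term delta * b * x.
    h_new = [math.exp(delta * a) * hi + delta * b * x for a, hi, b in zip(A, h, B)]
    y = sum(c * hi for c, hi in zip(C, h_new)) + D * x
    return y, h_new


def decode_n(h0, xs, deltas, Bs, Cs, A, D):
    """Decode N tokens from cached state h0 in closed form (hypothetical helper)."""
    # prefix[i][t] = sum_{k <= t} deltas[k] * A[i], i.e. the log of the cumulative decay.
    prefix = []
    for a in A:
        acc, row = 0.0, []
        for d in deltas:
            acc += d * a
            row.append(acc)
        prefix.append(row)

    ys, h_t = [], list(h0)
    for t in range(len(xs)):
        h_t = []
        for i in range(len(A)):
            # Cached state decayed through tokens 0..t.
            val = math.exp(prefix[i][t]) * h0[i]
            # Each earlier input j, decayed through tokens j+1..t.
            for j in range(t + 1):
                val += math.exp(prefix[i][t] - prefix[i][j]) * deltas[j] * Bs[j][i] * xs[j]
            h_t.append(val)
        ys.append(sum(c * v for c, v in zip(Cs[t], h_t)) + D * xs[t])
    return ys, h_t


if __name__ == "__main__":
    random.seed(0)
    d_state, n = 4, 5
    A = [-(i + 1.0) for i in range(d_state)]  # negative real diagonal
    h0 = [random.uniform(-1, 1) for _ in range(d_state)]
    xs = [random.uniform(-1, 1) for _ in range(n)]
    deltas = [random.uniform(0.01, 0.5) for _ in range(n)]
    Bs = [[random.uniform(-1, 1) for _ in range(d_state)] for _ in range(n)]
    Cs = [[random.uniform(-1, 1) for _ in range(d_state)] for _ in range(n)]
    D = 0.3

    # Token-by-token decoding: N separate single-step calls.
    h, ys_seq = list(h0), []
    for t in range(n):
        y, h = step(h, xs[t], deltas[t], A, Bs[t], Cs[t], D)
        ys_seq.append(y)

    # Batch decoding: all N tokens in one call.
    ys_batch, h_batch = decode_n(h0, xs, deltas, Bs, Cs, A, D)
    print("max |y diff|:", max(abs(a - b) for a, b in zip(ys_seq, ys_batch)))
    print("max |h diff|:", max(abs(a - b) for a, b in zip(h, h_batch)))
```

This sketch is O(N²) per state dimension and makes no performance claim; the point is only that the N outputs no longer have to be produced by N dependent calls, which is what makes processing them together on a GPU possible.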

AnaRhisT94 · Jul 18 '24 12:07