mamba
add feature / fix bug: I fixed the `kNRows` feature in forward
Thank you for sharing this splendid work!

I found that `kNRows` is always 1 in the original selective_scan, and I observed that the greater the `kNRows` I use in selective scan, the faster the code runs. The phenomenon is consistent with mamba.py: when `d_state` grows, the time consumption stays roughly constant. Though it is not strictly right, adding to the burden of each thread while reducing the number of blocks (since the number of SMs is limited) really works in most cases.

So I reopened that feature, which may have been deprecated in the original selective_scan, and fixed some bugs related to it. I have tested with `pytest tests/ops/test_selective_scan_.py` (which you may delete later), and all tests pass.

Note that I have only fixed the forward procedure, so in backward, `nrows` is still 1.
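As a toy illustration of the launch-configuration intuition above (a sketch, assuming hypothetically that each thread block processes `kNRows` channel rows, so the grid shrinks as `kNRows` grows; this is not the real kernel's launch math):

```python
import math

def grid_blocks(dim: int, k_n_rows: int) -> int:
    """Toy model: if each block handles k_n_rows channel rows,
    fewer blocks are launched for a fixed channel dimension."""
    return math.ceil(dim / k_n_rows)

# Raising kNRows shrinks the grid, which can help when the number of
# SMs, rather than per-thread work, is the bottleneck.
for k in (1, 2, 4):
    print(f"kNRows={k}: {grid_blocks(1536, k)} blocks")
```

The trade-off named in the comment is exactly this: each thread does more work, but the scheduler has fewer blocks to place on a limited number of SMs.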
Before merging: I found that when I uncomment all the alternative parameters, not all tests pass. However, `mamba_ssm-1.1.3.post1+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl` behaves the same.
@MzeroMiko could you please share your benchmark numbers and platform? I see that computation slows down quite a bit when using `nrows>1`. Am I missing something? These are the times I get on an A100 80GB GPU.
My benchmark code can be found here https://github.com/state-spaces/mamba/issues/27#issuecomment-1930747882
Thanks!
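For reference, the general shape of a timing harness for this kind of comparison (a minimal sketch; `fn` is any callable, and for CUDA kernels the synchronization noted in the docstring is the part that matters):

```python
import time

def bench(fn, *args, warmup: int = 5, iters: int = 20) -> float:
    """Return mean seconds per call of fn(*args).

    For CUDA kernels, call torch.cuda.synchronize() before each
    perf_counter() read, so queued kernels have actually finished
    when the clock is sampled (omitted here to stay CPU-only).
    """
    for _ in range(warmup):        # warm up caches / autotuning
        fn(*args)
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - t0) / iters
```

Without the synchronization step, GPU timings mostly measure kernel-launch overhead, which is especially misleading at small `seqlen`.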
Thank you very much, @apoorv2904. You are right, and I nearly failed to reproduce the results I observed before. I have been working on it these days. (The environment I use is a 4090 24G, with py310+cu121+torch2.2.)
- I added the `nrow` feature in backward as well, to better compare different `nrow` settings.
- I compared my code (`selective_scan_test` here, or `selective_scan_core` in VMamba) with `mamba_ssm` rather than `selective_scan_ref`, and found no difference (all tests pass with the test file).
- I realised that the issue I cited proves nothing here, since raising `d_state` only increases the FLOPs in the SSM (nearly equal to selective scan), while raising `d_model` or `seqlen` affects the whole Mamba model. As the SSM is fast compared to the whole model plus data loading, the speed difference is small and hard to observe (which is one possible explanation for that issue).
- I used my newly written simple benchmark, and found the results consistent with yours: raising `nrows` seemed to only make the code slower. Then I finally realised that ***the test which showed that raising `nrow` raises the speed was done on 7x7 feature maps, which means seqlen is 49, extremely small!*** I then added `seqlen=64` to the tests, and found that with some `fwdnrow+bwdnrow` patterns the speed is faster; see the log for details. Though I still do not know how the bwd code influences the fwd procedure.
- I modified your benchmark, and the results are consistent with `test_selective_scan_speed`; see the log for details.

To conclude, with short `seqlen`, a bigger `nrow` may lead to faster speed, but the reason remains unknown.
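A back-of-the-envelope check of the point about `d_state`: using a rough ~9·B·L·D·N FLOP estimate for the selective scan and counting only the in/out projections of a Mamba block (both are illustrative assumptions, not exact model accounting), the scan is a small slice of the block's compute, so scan-level speedups are easy to miss at the model level:

```python
def scan_flops(b, l, d_inner, n):
    # rough estimate sometimes used for selective scan: ~9 * B * L * D_inner * N
    return 9 * b * l * d_inner * n

def proj_flops(b, l, d_model, expand=2):
    d_inner = expand * d_model
    # in_proj: D -> 2*d_inner, out_proj: d_inner -> D (2 FLOPs per MAC)
    return 2 * b * l * d_model * 2 * d_inner + 2 * b * l * d_inner * d_model

b, l, d_model, n = 1, 1024, 768, 16
s = scan_flops(b, l, 2 * d_model, n)
p = proj_flops(b, l, d_model)
# the scan is only a few percent of the projection FLOPs at these sizes
print(f"scan / projections = {s / p:.2%}")
```

Under these assumptions the ratio is about `1.5 * d_state / d_model`, so doubling the scan's speed barely moves whole-model time once `d_model` is large.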