
add feature / fix bug: I fixed the `kNRows` feature in forward

MzeroMiko opened this issue 1 year ago · 2 comments

Thank you for sharing this splendid work!

I found that kNRows is always 1 in the original selective_scan, and I observed that the larger the kNRows I use in the selective scan, the faster the code runs. This is consistent with the behavior of mamba.py: when d_state increases, the time cost stays roughly constant. Though it is not strictly the right trade-off, increasing the work per thread while reducing the number of blocks (since the number of SMs is limited) really helps in most cases.
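For intuition on why kNRows can help: the forward kernel launches roughly one block per (batch, channel group), so folding kNRows channels into a single block divides the grid's second dimension by kNRows. A hypothetical back-of-the-envelope sketch (the function name and sizes are illustrative, not the actual launch code):

```python
# Illustrative only: estimates the launch-grid size if each block
# handles `knrows` channels instead of one (an assumption about the
# kernel's grid layout, not code from the repository).
def num_blocks(batch: int, dim: int, knrows: int) -> int:
    assert dim % knrows == 0, "dim must be divisible by knrows"
    return batch * (dim // knrows)

for k in (1, 2, 4):
    # Fewer blocks per launch as knrows grows; each block does more work.
    print(f"knrows={k}: {num_blocks(batch=8, dim=768, knrows=k)} blocks")
```

With a small grid the GPU's SMs can go underutilized either way; the trade-off only pays off when the saved scheduling overhead outweighs the extra per-thread work, which is consistent with the short-seqlen observations later in this thread.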

So I revived that feature, which may have been deprecated in the original selective_scan, and fixed some bugs related to it. I have tested with `pytest tests/ops/test_selective_scan_.py` (which you may delete later), and all tests pass.

Note that I have only fixed the forward pass, so in the backward pass nrows is still 1.

Before merging: I found that when I uncomment all the alternative parameters, not all tests pass. However, `mamba_ssm-1.1.3.post1+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl` behaves the same way.

MzeroMiko avatar Feb 04 '24 15:02 MzeroMiko

@MzeroMiko could you please share your benchmark numbers and platform? I see that computation slows down quite a bit when using nrows>1. Am I missing something? These are the times I get on an A100 80GB GPU: [image]

My benchmark code can be found here https://github.com/state-spaces/mamba/issues/27#issuecomment-1930747882
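The linked benchmark is not reproduced here, but the general methodology (warm up, repeat, report a robust statistic) can be sketched with nothing but the standard library. For real GPU kernels you would also synchronize the device around each timed call (e.g. `torch.cuda.synchronize()`), which is omitted here to keep the sketch dependency-free:

```python
import statistics
import time

def bench(fn, *args, warmup=10, iters=100):
    """Return the median wall-clock time of fn(*args) over `iters` runs.

    Warmup runs absorb one-time costs (JIT, caches, allocator) so they
    do not skew the measurement; the median is robust to outlier runs.
    """
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return statistics.median(times)

# Example: time a trivial CPU workload.
print(f"median: {bench(sum, range(10_000), warmup=3, iters=20):.2e} s")
```

Without device synchronization, timing an asynchronous CUDA launch this way would measure only the enqueue cost, which is one common way such comparisons go wrong.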

Thanks!

apoorv2904 avatar Feb 09 '24 20:02 apoorv2904

Thank you very much, @apoorv2904. You are right, and I nearly failed to reproduce the results I had observed before. I have been working on this over the past few days. (My environment is a 4090 24GB, with py310+cu121+torch2.2.)

  1. I added the nrows feature to the backward pass, to allow better comparison across different nrows settings.
  2. I compared my code (selective_scan_test here, or selective_scan_core in VMamba) against mamba_ssm rather than selective_scan_ref, and there is no difference (all tests pass with the test file).
  3. I realized that the earlier observation proves nothing here, since raising d_state only affects the FLOPs inside the SSM (nearly equal to the selective scan), while raising d_model or seqlen affects the whole Mamba model. Because the SSM is fast compared to the whole model plus data loading, the speed difference is small and hard to observe (which is one possible explanation for that issue).
  4. I used my newly written simple benchmark and found results consistent with yours: raising nrows only made the code slower. Then I finally realized that ***the test which showed that raising nrows raises the speed was done on 7x7 feature maps, meaning seqlen was 49, which is extremely small!*** So I added seqlen=64 to the tests and found that with some fwdnrows+bwdnrows combinations the speed is indeed higher; see the log for details. I still do not know how the backward code influences the forward pass.
  5. I modified your benchmark, and the results are consistent with test_selective_scan_speed; see the log for details. To conclude: with short seqlen, a bigger nrows may lead to faster speed, but the reason remains unknown.
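Point 3 is easier to see against the recurrence itself: per time step the selective scan does O(d_state) work per channel, so raising d_state scales only the scan's FLOPs, not the rest of the model. A minimal pure-Python reference of the recurrence, assuming the Mamba-style discretization x_t = exp(Δ_t·A)·x_{t-1} + Δ_t·B·u_t and, for brevity only, time-invariant B and C (an assumption; in the real kernel B and C are input-dependent):

```python
import math

def selective_scan_ref(u, delta, A, B, C):
    """Sequential reference of the selective-scan recurrence for a
    single channel with real-valued A (a simplified sketch, not the
    CUDA kernel):
        x_t = exp(delta_t * A) * x_{t-1} + delta_t * B * u_t
        y_t = <C, x_t>
    u, delta: length-L sequences; A, B, C: length-d_state vectors.
    """
    d_state = len(A)
    x = [0.0] * d_state          # hidden state, one entry per SSM state
    ys = []
    for t in range(len(u)):
        for n in range(d_state):  # O(d_state) work per time step
            x[n] = math.exp(delta[t] * A[n]) * x[n] + delta[t] * B[n] * u[t]
        ys.append(sum(C[n] * x[n] for n in range(d_state)))
    return ys

# Tiny example: A=0 makes the state a running sum of delta*B*u.
print(selective_scan_ref([1.0, 1.0], [1.0, 1.0], [0.0], [1.0], [1.0]))
```

The double loop makes the cost structure explicit: total work is L × d_state per channel, so doubling d_state doubles only this inner loop, while doubling seqlen or d_model also doubles the work in every other layer of the model.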

MzeroMiko avatar Feb 17 '24 13:02 MzeroMiko