Unable to directly compute parameters and FLOPs for Mamba models due to Triton and CUDA implementations
I'm encountering an issue when working with the Mamba project (both Mamba1 and Mamba2). Since the implementation of Mamba relies on both Triton and CUDA kernels, it's not possible to directly use standard PyTorch tools (e.g., torchsummary or torchinfo) to compute the number of parameters and FLOPs for the model. My questions are: How should the number of parameters and FLOPs be calculated for a Mamba block? Could you provide the specific calculation methods? Any guidance or references would be greatly appreciated! Thank you!
The model parameters are still defined in PyTorch, so counting them is just `sum(p.numel() for p in model.parameters())`.
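For concreteness, a minimal sketch of this, assuming the `mamba_ssm` package is installed and using the `Mamba` block it exposes (the hyperparameters below are just illustrative placeholders, not a recommended config):

```python
from mamba_ssm import Mamba  # Mamba2 is also available as mamba_ssm.Mamba2

# Illustrative hyperparameters; substitute your own model config.
block = Mamba(d_model=768, d_state=16, d_conv=4, expand=2)

# Parameter counting works regardless of the Triton/CUDA kernels,
# because the weights are ordinary nn.Parameter tensors.
n_params = sum(p.numel() for p in block.parameters())
print(f"parameters: {n_params:,}")
```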
For FLOPs you can calculate them by hand, or search the issues on this repo.
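As a rough back-of-the-envelope sketch (my own estimate, not an official formula from the repo), you can count the dominant matmul work in the projections at 2 FLOPs per multiply-accumulate and add the selective-scan term, which scales as O(seq_len * d_inner * d_state). The `dt_rank` default and the constant used for the scan's elementwise work are assumptions; treat the result as an order-of-magnitude estimate for a Mamba-1 block:

```python
import math

def mamba1_block_flops(seq_len, d_model, d_state=16, d_conv=4, expand=2):
    """Approximate forward-pass FLOPs for one Mamba-1 block.

    Counts each multiply-accumulate as 2 FLOPs and uses a small constant
    for the elementwise work inside the selective scan.
    """
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)  # Mamba-1 default dt_rank="auto"

    per_token = 0
    per_token += 2 * d_model * 2 * d_inner               # in_proj (x and z branches)
    per_token += 2 * d_inner * d_conv                    # depthwise causal conv1d
    per_token += 2 * d_inner * (dt_rank + 2 * d_state)   # x_proj -> (dt, B, C)
    per_token += 2 * dt_rank * d_inner                   # dt_proj
    per_token += 9 * d_inner * d_state                   # selective scan: discretize + recurrence + output (assumed constant)
    per_token += 2 * d_inner * d_model                   # out_proj

    return per_token * seq_len

# Example: one block of a d_model=768 model over a 2048-token sequence.
print(f"{mamba1_block_flops(seq_len=2048, d_model=768):,} FLOPs")
```

Multiply by the number of layers for a whole model, and add the embedding/LM-head matmuls separately if you need end-to-end numbers.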