Does forward/eval from a trained Mamba model require CUDA as well?
The code in selective_scan_fwd() in selective_scan.cpp seems to suggest that even a forward pass from a trained model requires CUDA, which could be inconvenient when running models in production environments. Any idea how to do a model forward pass on a CPU-only machine? Thanks
Yup, it's only implemented for CUDA for now. You can look at selective_scan_ref for a pure PyTorch implementation that should run on CPU (though probably quite slowly).
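For reference, a minimal sketch of calling the reference scan on CPU. The import path and tensor shapes are assumptions based on the repo layout and may vary by version:

```python
# Minimal sketch (assumed import path; check your mamba_ssm version).
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_ref

batch, dim, dstate, seqlen = 1, 16, 4, 32

u = torch.randn(batch, dim, seqlen)      # input sequence
delta = torch.rand(batch, dim, seqlen)   # per-step timestep
A = -torch.rand(dim, dstate)             # state matrix (negative for stability)
B = torch.randn(batch, dstate, seqlen)   # input-dependent (selective) B
C = torch.randn(batch, dstate, seqlen)   # input-dependent (selective) C
D = torch.randn(dim)                     # skip connection

# Runs on CPU since it is plain PyTorch, just much slower than the CUDA kernel.
y = selective_scan_ref(u, delta, A, B, C, D=D, delta_softplus=True)
print(y.shape)  # (batch, dim, seqlen)
```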
thanks, will look into it
You can check this fork. It works on CPU.
@kroggen Thanks for the CPU version. It would be nice if you added this as a PR; I'm currently using your code for debugging.
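For anyone landing here, a hedged sketch of what CPU-only inference looks like, assuming the fork keeps the upstream mamba_ssm API (MambaLMHeadModel.from_pretrained etc.); names may differ in the fork itself:

```python
# Hedged sketch: forward pass on a CPU-only machine, assuming the fork
# exposes the same API as upstream mamba_ssm.
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cpu")
model.eval()

input_ids = torch.randint(0, 50277, (1, 16))  # dummy token ids
with torch.no_grad():
    out = model(input_ids)
print(out.logits.shape)  # (1, 16, vocab_size)
```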
Inference of Mamba models in pure C
https://github.com/kroggen/mamba.c
Recurrent mode only, for simplicity (see the sketch below)
Faster than PyTorch (in default mode) on CPU
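To illustrate what recurrent mode means here: instead of scanning the whole sequence at once, the state is updated one token at a time, so memory stays constant in sequence length. A minimal PyTorch sketch of that update, using the standard Mamba recurrence with illustrative names (not mamba.c's actual identifiers):

```python
# Sketch of "recurrent mode": per-token state update instead of a full scan.
import torch

def ssm_step(h, x, dt, A, B, C, D):
    """One recurrent update. h: (dim, dstate) state, x: (dim,) input,
    dt: (dim,) timestep, A: (dim, dstate), B, C: (dstate,), D: (dim,)."""
    dA = torch.exp(dt[:, None] * A)                      # discretized transition
    h = dA * h + dt[:, None] * B[None, :] * x[:, None]   # state update
    y = (h * C[None, :]).sum(-1) + D * x                 # readout + skip
    return h, y

dim, dstate = 16, 4
A = -torch.rand(dim, dstate)                 # negative for stability
B, C = torch.randn(dstate), torch.randn(dstate)
D, dt = torch.randn(dim), torch.rand(dim) * 0.1

h = torch.zeros(dim, dstate)
for x in torch.randn(10, dim):               # feed tokens one at a time
    h, y = ssm_step(h, x, dt, A, B, C, D)
```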