
ROCm support

Wintoplay opened this issue 1 year ago · 7 comments

Please consider adding ROCm support for AMD GPUs.

Wintoplay avatar Dec 19 '23 03:12 Wintoplay

What library compatibility would be needed to make this work successfully on ROCm?

j-dominguez9 avatar Jan 28 '24 18:01 j-dominguez9

We do not have experience with ROCm, but of course we'd welcome a community contribution on this.

tridao avatar Jan 28 '24 18:01 tridao

Hi @tridao, thanks for answering. I think there is enough demand in the ROCm community to make this work. I went through the files, but just to make sure I'm not missing anything, is it the case that these are the only dependencies:

    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
        "einops",
        "triton",
        "transformers",
        "causal_conv1d>=1.1.0",
    ],

If that's the case, this should be relatively easy to port.

j-dominguez9 avatar Jan 28 '24 18:01 j-dominguez9
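As a quick sanity check for a port like this, note that ROCm builds of PyTorch expose the HIP version via `torch.version.hip` (it is `None` on CUDA builds), while `torch.cuda.is_available()` returns `True` on both backends because ROCm reuses the `cuda` device type. A minimal probe:

    import torch

    # torch.version.hip is a version string on ROCm builds and None on CUDA
    # builds; torch.cuda.is_available() is True on both, since ROCm reuses
    # the "cuda" device type.
    if torch.version.hip is not None:
        print(f"ROCm/HIP build: {torch.version.hip}")
    elif torch.version.cuda is not None:
        print(f"CUDA build: {torch.version.cuda}")
    else:
        print("CPU-only build of PyTorch")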

There's CUDA code in causal_conv1d, but that's optional; we can use torch's conv1d instead. There's also CUDA code in this repo for the selective_scan operation (in csrc), and it may work with HIP.

tridao avatar Jan 28 '24 19:01 tridao
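To illustrate the fallback mentioned above: a causal depthwise conv1d can be emulated with stock `torch.nn.functional.conv1d` by padding and trimming the overhang. A minimal sketch (the function name is illustrative, not code from the repo; shapes follow the `(batch, dim, seqlen)` layout Mamba uses):

    import torch
    import torch.nn.functional as F

    def causal_conv1d_ref(x, weight, bias=None):
        """Illustrative pure-PyTorch causal depthwise conv1d.

        x:      (batch, dim, seqlen)
        weight: (dim, width) -- one short filter per channel (depthwise)
        bias:   (dim,) or None
        """
        dim, width = weight.shape
        # groups=dim makes the convolution depthwise; padding=width-1 pads
        # both ends, so output position t only depends on inputs <= t ...
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
        # ... and the width-1 surplus outputs produced by the right-side pad
        # are trimmed to restore the original sequence length.
        return out[..., : x.shape[-1]]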

I found a simple implementation written in pure PyTorch that is compatible with ROCm:

https://github.com/alxndrTL/mamba.py/issues/22

supersonictw avatar Apr 19 '24 19:04 supersonictw

We have a working version of Mamba on ROCm. We've been able to run generation on AMD's MI210, and the unit tests are passing for the port. It uses PyTorch's cpp extensions to selectively build kernel code depending on whether the system is running CUDA or ROCm, so the port can be built on both. https://github.com/EmbeddedLLM/mamba-rocm

kliuae avatar Apr 22 '24 09:04 kliuae
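For readers curious about the mechanism described above: `torch.utils.cpp_extension.CUDAExtension` runs PyTorch's hipify pass over `.cu` sources when the installed torch is a ROCm build, so a single `setup.py` can serve both backends and branch only on per-compiler flags. A rough sketch of the pattern (file names and flags here are illustrative, not taken from the actual port):

    # setup.py (sketch): one extension that builds under both CUDA and ROCm.
    import torch
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    IS_ROCM = torch.version.hip is not None

    # Flags under the "nvcc" key go to the device compiler (nvcc or hipcc);
    # guard any flags that only one of them understands.
    device_flags = ["-O3"]
    if not IS_ROCM:
        device_flags.append("--expt-relaxed-constexpr")  # nvcc-only flag

    setup(
        name="selective_scan_example",
        ext_modules=[
            CUDAExtension(
                name="selective_scan_cuda",
                # On ROCm builds, cpp_extension hipifies these CUDA sources
                # automatically before compiling them with hipcc.
                sources=[
                    "csrc/selective_scan/selective_scan.cpp",
                    "csrc/selective_scan/selective_scan_fwd.cu",
                ],
                extra_compile_args={"cxx": ["-O3"], "nvcc": device_flags},
            )
        ],
        cmdclass={"build_ext": BuildExtension},
    )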

The official PR from AMD has been merged: https://github.com/state-spaces/mamba/pull/359

supersonictw avatar Jul 18 '24 01:07 supersonictw