
Windows Support

Open Phylliida opened this issue 1 year ago • 11 comments

I'm able to compile causal-conv1d by adding

                        "-DWIN32_LEAN_AND_MEAN",

To the nvcc flags.

When compiling mamba, after adding -DWIN32_LEAN_AND_MEAN to nvcc flags, I find I need to add

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif

To selective_scan_common.h
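
As a quick sanity check (not part of the build), the literal in that define should be log2(e). A minimal sketch using only Python's standard library to confirm it; the tolerance is an arbitrary choice:

```python
import math

# M_LOG2E is log2(e); this literal is the value used in the #define above.
M_LOG2E = 1.4426950408889634074

# Compare against the standard library with a small tolerance.
difference = abs(math.log2(math.e) - M_LOG2E)
print(difference < 1e-12)
```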

Then it can get a little further, however it raises the following errors:

Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kIsEvenLen_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(26): note: see declaration of 'kIsEvenLen_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(521): note: see reference to function template instantiation 'void selective_scan_bwd_launch<32,4,input_t,weight_t>(SSMParamsBwd &,cudaStream_t)' being compiled
        with
        [
            input_t=c10::BFloat16,
            weight_t=c10::complex<float>
        ]
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_bf16_complex.cu(9): note: see reference to function template instantiation 'void selective_scan_bwd_cuda<c10::BFloat16,c10::complex<float>>(SSMParamsBwd &,cudaStream_t)' being compiled
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kIsVariableB_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(26): note: see declaration of 'kIsVariableB_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kIsVariableC_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(26): note: see declaration of 'kIsVariableC_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kDeltaSoftplus_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(27): note: see declaration of 'kDeltaSoftplus_'
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(493): error C2975: 'kHasZ_': invalid template argument for 'Selective_Scan_bwd_kernel_traits', expected compile-time constant expression
Y:\prog\python\thirdparty\mamba\csrc\selective_scan\selective_scan_bwd_kernel.cuh(27): note: see declaration of 'kHasZ_'

This might be related to this issue, something about the Windows compiler (MSVC) being stricter. However, the fix is probably going to be a little more involved, and I haven't had much luck yet.

Phylliida avatar Dec 05 '23 21:12 Phylliida

Unfortunately we've never tested windows paths, and it's not on the roadmap right now.

albertfgu avatar Dec 05 '23 23:12 albertfgu

Sorry if you've already covered this, @Phylliida, but have you checked that you are building the code as C++20? (I'm just guessing, from the way constexpr and lambdas are used, that it needs that version of the language.)

EDIT: also, the comment you link to, which points to a Stack Overflow post, appears to be unrelated to either issue thread; it's talking about something completely different. (I'd hazard a guess that the commenter remembered a #define being useful for array declarations and shared it, even though it doesn't relate to the specific defines you mentioned there.)

EDIT2: per https://learn.microsoft.com/en-us/cpp/c-runtime-library/math-constants it might be better to define _USE_MATH_DEFINES so that constants like M_LOG2E are defined

EDIT3: actually, it looks like the code was updated a day or so ago to ask that it be compiled with C++17 (not C++20 as I had guessed). Maybe check whether you have this recent commit too? https://github.com/state-spaces/mamba/commit/023c25d47bb8f0b048db52c282fd4226a4035b0d

nat42 avatar Dec 06 '23 14:12 nat42

Nice, adding

                        "-D_USE_MATH_DEFINES",

to nvcc flags is a better alternative
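
For anyone who wants to keep one setup.py for both Linux and Windows, here is a minimal sketch of adding the two defines only on Windows. The helper name with_windows_defines is made up for illustration and is not part of mamba or causal-conv1d; the returned list would be what gets passed as the "nvcc" entry of extra_compile_args.

```python
import sys

# Defines discussed in this thread; assumed to be needed only for MSVC builds.
WINDOWS_NVCC_DEFINES = ["-DWIN32_LEAN_AND_MEAN", "-D_USE_MATH_DEFINES"]


def with_windows_defines(nvcc_flags, platform=sys.platform):
    """Return a copy of nvcc_flags with the Windows-only defines prepended."""
    if platform != "win32":
        return list(nvcc_flags)
    # Skip defines that the caller already added by hand.
    extra = [flag for flag in WINDOWS_NVCC_DEFINES if flag not in nvcc_flags]
    return extra + list(nvcc_flags)


print(with_windows_defines(["-O3", "--use_fast_math"], platform="win32"))
```

Passing platform explicitly is only there so the behaviour can be checked from any OS; in setup.py the default would normally be enough.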

Compiling with C++17 isn't enough; I still get the errors listed above. Right now I'm trying to get C++20 working, no success yet.

Edit: OK, it looks like triton is a dependency. I'm trying out the prebuilt wheels from here (scroll down to the bottom, extract the Windows build, then pip install ___.whl for your version of Python; I'm using 3.10 and CUDA 12.1).

Phylliida avatar Dec 06 '23 17:12 Phylliida

Okay, I've successfully run inference on Windows. I'm on Python 3.9 and CUDA 12.1. I had to do the following things:

(do all of the following in x64 Native Tools Command Prompt for VS 2019)

compile causal-conv1d by adding

                        "-DWIN32_LEAN_AND_MEAN",

To the nvcc flags in setup.py

(you may also need to run

SET DISTUTILS_USE_SDK=1

)

Next, we need to install triton.

Download the triton wheel from here: scroll down to the bottom and download triton-dist windows-latest.

Extract it, then run

pip3 install triton-2.1.0-cp39-cp39-win_amd64.whl

If you have a different version of Python and CUDA 11.8, you can use one from here instead, though I haven't tested that.
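
To work out which wheel file matches your interpreter, here is a minimal sketch that builds the expected filename from the running Python version. It assumes the other wheels follow the same naming pattern as the one above; triton_wheel_name is a made-up helper, and whether a wheel actually exists for a given version is not guaranteed.

```python
import sys


def triton_wheel_name(triton_version="2.1.0", python_version=None):
    """Build the expected wheel filename for CPython on 64-bit Windows."""
    major, minor = python_version or sys.version_info[:2]
    tag = f"cp{major}{minor}"
    return f"triton-{triton_version}-{tag}-{tag}-win_amd64.whl"


# With no arguments, the tag comes from the interpreter running this script.
print(triton_wheel_name())
```

For Python 3.9 this gives the same filename as in the pip3 command above.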

Next, you need to get the compiled libraries triton needs. You can download them from here, then add the bin directory to your PATH.

If you prefer to compile them yourself you can see the command here, but be warned that it takes about 1-2 hours.

Finally, I just modified ops/selective_scan_interface.py to:

  1. Remove this line:
import selective_scan_cuda
  2. Replace
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)

with

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)

It would be better to use the kernel, but until we can get it compiling on Windows we can use the reference implementation in pure Python instead.
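
Rather than deleting the import outright, another option is to make it optional and fall back to the reference path only when the compiled extension is missing. This is a minimal sketch of that pattern, not the project's own code; load_optional and pick_scan are made-up helper names.

```python
import importlib


def load_optional(module_name):
    """Import module_name, returning None if it is missing or fails to load."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


# None on machines where the CUDA extension has not been built.
selective_scan_cuda = load_optional("selective_scan_cuda")


def pick_scan(kernel_fn, reference_fn, kernel_module=selective_scan_cuda):
    """Use the compiled kernel path when available, else the reference path."""
    return kernel_fn if kernel_module is not None else reference_fn
```

Inside selective_scan_fn this would be used roughly as pick_scan(SelectiveScanFn.apply, selective_scan_ref)(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state), so the same file works with or without the kernel.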

With this setup I'm able to run inference using the 2.8b model (at fp16 or fp32) on a 3090.

For example:

Prompt:

User: What is the answer to life the universe and everything? Oracle:

Answer:

I don't know. I'm just a computer.

Phylliida avatar Dec 06 '23 21:12 Phylliida

I think I found a workaround for compiling this package for windows (however, I have not tested the impact on performance). MSVC has a problem with constexpr and can't handle passing them to templates as arguments (see this and this). The workaround is to replace constexpr with const static.

diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
index 440a209..b3ef2a8 100644
--- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -306,14 +306,14 @@ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
 void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
     // processing 1 row.
-    constexpr int kNRows = 1;
+    const static int kNRows = 1;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
         BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
             BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                 BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
                     using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
-                    // constexpr int kSmemSize = Ktraits::kSmemSize;
-                    constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+                    // const static int kSmemSize = Ktraits::kSmemSize;
+                    const static int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                     // printf("smem_size = %d\n", kSmemSize);
                     dim3 grid(params.batch, params.dim / kNRows);
                     auto kernel = &selective_scan_fwd_kernel<Ktraits>;
diff --git a/csrc/selective_scan/static_switch.h b/csrc/selective_scan/static_switch.h
index 7920ac0..87493ef 100644
--- a/csrc/selective_scan/static_switch.h
+++ b/csrc/selective_scan/static_switch.h
@@ -16,10 +16,10 @@
 #define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
     [&] {                                                                            \
         if (COND) {                                                                  \
-            constexpr bool CONST_NAME = true;                                        \
+            const static bool CONST_NAME = true;                                     \
             return __VA_ARGS__();                                                    \
         } else {                                                                     \
-            constexpr bool CONST_NAME = false;                                       \
+            const static bool CONST_NAME = false;                                    \
             return __VA_ARGS__();                                                    \
         }                                                                            \
     }()

With those changes I can compile the package. It seems to work in PyTorch, but like I mentioned, I haven't tested performance or correctness. 😅
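
If you would rather script the edit than apply the diff by hand, here is a minimal sketch that performs the same textual substitutions on a source string. It only touches the three declarations that appear in the diff above; apply_msvc_workaround is a made-up helper, and reading and writing the files is left out.

```python
import re

# Patterns taken from the diff above; any other constexpr is left untouched.
_SUBSTITUTIONS = [
    (re.compile(r"\bconstexpr bool CONST_NAME\b"), "const static bool CONST_NAME"),
    (re.compile(r"\bconstexpr int kNRows\b"), "const static int kNRows"),
    (re.compile(r"\bconstexpr int kSmemSize\b"), "const static int kSmemSize"),
]


def apply_msvc_workaround(source):
    """Replace the constexpr declarations from the diff with const static."""
    for pattern, replacement in _SUBSTITUTIONS:
        source = pattern.sub(replacement, source)
    return source


print(apply_msvc_workaround("    constexpr int kNRows = 1;"))
```

Note that the trailing backslashes in the BOOL_SWITCH macro will no longer line up after the substitution; that is cosmetic only.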

Grzego avatar Dec 10 '23 02:12 Grzego

@Grzego's constexpr to const static workaround above is a working solution (it compiled, but I haven't trained with it).

python 3.11.7 windows 10

Jacky56 avatar Feb 01 '24 20:02 Jacky56

@Phylliida Hello, thanks for your method. But I don't understand what should be added after removing "import selective_scan_cuda". In the class SelectiveScanFn, the forward function calls "out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus)" and the backward function calls "du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(u, delta, A, B, C, D, delta_bias, dout, x, None, ctx.delta_softplus, )". Please help me.

RiceBunny1990 avatar Feb 03 '24 06:02 RiceBunny1990

@RiceBunny1990 You can skip any modifications to ops/selective_scan_interface.py once you successfully compile the mamba kernels on Windows, which should be possible after making the changes I posted previously.

Grzego avatar Feb 03 '24 18:02 Grzego

Is there a simple way to get the training and inference (without recompiling the CUDA kernels) working on Windows without using WSL?

F286 avatar Feb 04 '24 05:02 F286

@Phylliida @Grzego Thank you for the information. I have compiled causal_conv1d 1.1.3.post1 and mamba 1.1.3.post1 successfully with Python 3.10 + Windows 11 x64 + torch 2.2 + CUDA 12.1. However, when I try to import mamba, it crashes on import causal_conv1d_cuda with:

ImportError: DLL load failed while importing causal_conv1d_cuda: The specified module could not be found.

I have checked the dependencies of causal_conv1d_cuda.cp310-win_amd64.pyd (AFAIK a .pyd is a DLL on Windows), and they all exist. Any idea what causes the failure?
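
Two things that may be worth checking, offered as guesses rather than a confirmed diagnosis. First, compiled PyTorch extensions usually depend on DLLs shipped inside the torch package, so running import torch before importing the extension can matter. Second, since Python 3.8, extension modules on Windows no longer resolve dependent DLLs through PATH, so extra directories have to be registered with os.add_dll_directory. A minimal sketch of the second point follows; dll_search_dirs is a made-up helper and the CUDA path is a placeholder.

```python
import os
import sys


def dll_search_dirs(candidates):
    """Register each existing directory for DLL lookup and return those paths.

    On non-Windows platforms nothing is registered; the existing paths are
    still returned so the helper can be exercised anywhere.
    """
    usable = []
    for path in candidates:
        full = os.path.abspath(path)
        if not os.path.isdir(full):
            continue
        if sys.platform == "win32":
            os.add_dll_directory(full)  # available since Python 3.8
        usable.append(full)
    return usable


# Placeholder path; substitute wherever the CUDA runtime DLLs actually live.
print(dll_search_dirs([r"C:\path\to\cuda\bin"]))
```

The call would need to run before the import that fails.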

lyhyl avatar Feb 05 '24 15:02 lyhyl

I have managed to build mamba-ssm, but for the life of me I cannot compile causal-conv1d. @Phylliida, the "-DWIN32_LEAN_AND_MEAN" flag goes right in here, correct?

    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": append_nvcc_threads(
            [
                "-DWIN32_LEAN_AND_MEAN",
                "-O3",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
                "--ptxas-options=-v",
                "-lineinfo",
            ]

ramzeez88 avatar Apr 09 '24 16:04 ramzeez88