RWKV-LM
CUDA compilation error with Ctx Length>2000
Hello, I am trying out RWKV with the audio modality, and when I set T_MAX >> 1000, it throws this error:
Emitting ninja build file /root/.cache/torch_extensions/py39_cu116/timex/build.ninja...
Building extension module timex...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/2] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=timex -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/TH -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /root/anaconda3/envs/surya-env/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' --use_fast_math --extra-device-vectorization -DTmax=10000 -DBF=8 -DBB=2 -std=c++14 -c cuda/timex_cuda.cu -o timex_cuda.cuda.o
FAILED: timex_cuda.cuda.o
/usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=timex -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1013\" -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/TH -isystem /root/anaconda3/envs/surya-env/lib/python3.9/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /root/anaconda3/envs/surya-env/include/python3.9 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' --use_fast_math --extra-device-vectorization -DTmax=10000 -DBF=8 -DBB=2 -std=c++14 -c cuda/timex_cuda.cu -o timex_cuda.cuda.o
ptxas error : Entry function '_Z15kernel_backwardIfEvPKT_S2_S2_PS0_S3_iii' uses too much shared data (0x30d40 bytes, 0xc000 max)
ptxas error : Entry function '_Z14kernel_forwardIfEvPKT_S2_PS0_S0_iii' uses too much shared data (0x57e40 bytes, 0xc000 max)
ninja: build stopped: subcommand failed.
GPU: A100, VRAM: 42GB, CUDA 11.6
I am okay if training takes a bit longer, but I need this to work. I don't know any CUDA. Can you suggest some workarounds?
Thanks for the incredible work btw!
Also, it seems FP16 doesn't work out-of-the-box. Could you suggest changes to make it work?
> Hello, I am trying out RWKV with audio modality and when I set T_MAX>>1000, it throws this error:
Reduce B_GROUP_FORWARD and B_GROUP_BACKWARD.
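For intuition on why Tmax=10000 blows past the limit, here is a rough model of the kernels' shared-memory use. The `(bf + 1)` and `(1 + 2*bb)` layouts are inferred from the ptxas byte counts in the log above, not from reading timex_cuda.cu, so treat this as a sketch:

```python
# Rough model of the timex kernels' shared-memory footprint, inferred
# from the ptxas byte counts in the log above (NOT from the kernel source).
FLOAT_BYTES = 4
SMEM_LIMIT = 0xC000  # 49152 bytes per block, the cap ptxas reports

def forward_smem(tmax, bf):
    # forward appears to stage one Tmax-float buffer plus BF more
    return (1 + bf) * tmax * FLOAT_BYTES

def backward_smem(tmax, bb):
    # backward appears to stage one Tmax-float buffer plus 2*BB more
    return (1 + 2 * bb) * tmax * FLOAT_BYTES

print(hex(forward_smem(10000, 8)))   # 0x57e40 -- matches the error
print(hex(backward_smem(10000, 2)))  # 0x30d40 -- matches the error

# Under this model, even BF=1 needs 2 * 10000 * 4 = 80000 bytes at
# Tmax=10000, so shrinking the group sizes alone is not enough here;
# Tmax itself also has to come down.
print(forward_smem(10000, 1) > SMEM_LIMIT)  # True
```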
> Also, it seems FP16 doesn't work out-of-the-box. Could you suggest changes to make it work?
You can move the FFN to FP16 first :)
I did that, but now it gives an illegal memory access at k.contiguous() in the forward of TimeMix. It works fine in fp32.
The CUDA code assumes each tensor element is 4 bytes, so it is only correct for fp32.
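To see why a 4-byte assumption breaks on fp16 tensors, here is a toy byte-offset calculation (plain Python pointer arithmetic, no RWKV code involved): a kernel that strides 4 bytes per element over a half-precision buffer walks past the end of the allocation, which is exactly the kind of illegal access reported above.

```python
# Toy illustration of the element-size mismatch: a kernel compiled for
# 4-byte floats reads byte offsets index * 4, but an fp16 tensor only
# allocates 2 bytes per element.
def last_byte_touched(n_elems, stride_bytes):
    # one past the highest byte a kernel reading n_elems elements touches
    return (n_elems - 1) * stride_bytes + stride_bytes

n = 1024                               # elements in the tensor
fp16_alloc = n * 2                     # bytes actually allocated for fp16
fp32_kernel = last_byte_touched(n, 4)  # bytes the fp32 kernel reads

print(fp32_kernel)               # 4096
print(fp16_alloc)                # 2048
print(fp32_kernel > fp16_alloc)  # True: reads 2 KB past the buffer
```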
@BlinkDL Could you please point out where the code needs to change to handle 2-byte tensor elements instead of 4-byte ones? Thanks a lot!
And the current design will overflow under FP16 :) Wait for my new kernels.
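As a rough illustration of the overflow (a numpy sketch, not RWKV code): fp16 saturates at 65504, so a long running sum of exponentiated weights, like the kind of accumulation the time-mix performs over the context, hits inf quickly, while the same sum in fp32 stays finite.

```python
import numpy as np

# fp16 maxes out at 65504; a running sum of exp'd weights over a long
# context saturates to inf, while the same sum in fp32 stays finite.
w = np.full(4096, 4.0, dtype=np.float16)  # modest positive weights
acc = np.float16(0.0)
for v in np.exp(w):                       # exp(4) ~= 54.6 per step
    acc = np.float16(acc + v)

print(np.isinf(acc))                                 # True in fp16
print(np.isinf(np.exp(w.astype(np.float32)).sum()))  # False in fp32
```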
Now the new RWKV-4 can compile ctxlen=4096 kernels :)