RVRT
RVRT copied to clipboard
Dependency issue in using deform_attn for RVRT
There seems to be a dependency issue of cuda kernel while using deform_attn while training if the current dependencies were installed.
RuntimeError: Error building extension 'deform_attn': [1/2] c++ -MMD -MF deform_attn_cuda_pt110.o.d -DTORCH_EXTENSION_NAME=deform_attn -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/lib/python3.7/site-packages/torch/include -isystem /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.7/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.7/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /home/jeyamariajose/projects/RVRT/KAIR/models/op/deform_attn_cuda_pt110.cpp -o deform_attn_cuda_pt110.o
FAILED: deform_attn_cuda_pt110.o
c++ -MMD -MF deform_attn_cuda_pt110.o.d -DTORCH_EXTENSION_NAME=deform_attn -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /opt/conda/lib/python3.7/site-packages/torch/include -isystem /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -isystem /opt/conda/lib/python3.7/site-packages/torch/include/TH -isystem /opt/conda/lib/python3.7/site-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /opt/conda/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /home/jeyamariajose/projects/RVRT/KAIR/models/op/deform_attn_cuda_pt110.cpp -o deform_attn_cuda_pt110.o
/home/jeyamariajose/projects/RVRT/KAIR/models/op/deform_attn_cuda_pt110.cpp: In function ‘void deform_attn_cuda_backward(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, int, int, int, int, int, int, int, int, int, int, int)’:
/home/jeyamariajose/projects/RVRT/KAIR/models/op/deform_attn_cuda_pt110.cpp:187:68: error: invalid initialization of reference of type ‘const at::Tensor&’ from expression of type ‘const c10::ScalarType’
grad_attns = at::_softmax_backward_data(grad_attns, attns, -1, dtype);
^~~~~
In file included from /opt/conda/lib/python3.7/site-packages/torch/include/ATen/ATen.h:15,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include/torch/types.h:3,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader_options.h:4,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/base.h:3,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h:3,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include/torch/data/dataloader.h:3,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include/torch/data.h:3,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include/torch/all.h:8,
from /opt/conda/lib/python3.7/site-packages/torch/include/torch/extension.h:4,
from /home/jeyamariajose/projects/RVRT/KAIR/models/op/deform_attn_cuda_pt110.cpp:4:
/opt/conda/lib/python3.7/site-packages/torch/include/ATen/Functions.h:5243:29: note: in passing argument 4 of ‘at::Tensor at::_softmax_backward_data(const at::Tensor&, const at::Tensor&, int64_t, const at::Tensor&)’
TORCH_API inline at::Tensor _softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) {
^~~~~~~~~~~~~~~~~~~~~~
ninja: build stopped: subcommand failed.
I had the same problem.
Which cuda and pytorch versions do you use? I currently use CUDA11.3+Pytorch1.12
.
which gcc version do you use? @JingyunLiang
Did you solve this problem? I also got the same error. @jeya-maria-jose
pip3 install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
These versions solved the dependency issue for me.