extension-cpp
`TORCH_LIBRARY` and `m.def` Not Working as Documented
I encountered an issue where using `TORCH_LIBRARY` alone, without the dispatcher API, does not work as expected. According to the PyTorch documentation, the `TORCH_LIBRARY` macro should create a function that registers custom operators. However, when I follow this approach, I get the following error at runtime:
$ python test/benchmark.py cuda
Traceback (most recent call last):
  File "/home/lizhifei/extension-cpp/test/benchmark.py", line 48, in <module>
    new_h, new_C = LLTM(X, W, b, h, C)
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 11, in lltm
    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 17, in forward
    outputs = torch.ops.extension_cpp.lltm_forward.default(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/_ops.py", line 921, in __getattr__
    raise AttributeError(
AttributeError: '_OpNamespace' object has no attribute 'lltm_forward'
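The failing lookup can also be probed in isolation, without running the whole benchmark. What follows is a minimal sketch and not a confirmed diagnosis. `missing_ops` is a hypothetical helper, and the namespace `extension_cpp`, the package name `extension_cpp`, and the operator name `lltm_forward` are all taken from the traceback. As far as I understand, `TORCH_LIBRARY` registrations only run once the compiled shared library has been loaded into the process, so comparing the result before and after importing the package might show whether the library is being loaded at all.

```python
import importlib


def missing_ops(namespace, op_names):
    """Return the names in op_names that are not attributes of namespace."""
    # Per the traceback, torch.ops.<namespace> raises AttributeError for an
    # operator that is not registered, so hasattr() is enough to probe for it.
    return [name for name in op_names if not hasattr(namespace, name)]


if __name__ == "__main__":
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        torch = None

    if torch is not None:
        wanted = ["lltm_forward"]  # operator name taken from the traceback
        # Assumption: the operators are registered under "extension_cpp".
        print("missing before import:", missing_ops(torch.ops.extension_cpp, wanted))
        try:
            # Assumption: importing the package loads the compiled library.
            importlib.import_module("extension_cpp")
        except ImportError as exc:
            print("could not import extension_cpp:", exc)
        print("missing after import:", missing_ops(torch.ops.extension_cpp, wanted))
```

If `lltm_forward` is reported as missing both before and after the import, that would suggest the shared library containing the `TORCH_LIBRARY` block is never loaded, or that the block was not compiled into it.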
Here is a link to my modified repository where this issue can be reproduced: andylizf/extension-cpp.
Could you please help me understand why this is happening and how to resolve it? Thank you.
Environment Information
- OS: Windows 11 23H2 22631.3527
- PyTorch version: 2.3.0
- How you installed PyTorch: conda
- Python version: 3.12.3
- CUDA/cuDNN version: CUDA 12.1, cuDNN 8.9.2
- GPU models and configuration: NVIDIA GeForce RTX 3090
- Conda Env:
# packages in environment at /home/lizhifei/miniconda3/envs/extension-cpp:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_gnu conda-forge
blas 1.0 mkl conda-forge
brotli-python 1.1.0 py312h30efb56_1 conda-forge
bzip2 1.0.8 hd590300_5 conda-forge
ca-certificates 2024.2.2 hbcca054_0 conda-forge
certifi 2024.2.2 pyhd8ed1ab_0 conda-forge
charset-normalizer 3.3.2 pyhd8ed1ab_0 conda-forge
cuda 12.1.0 0 nvidia
cuda-cccl 12.1.109 0 nvidia/label/cuda-12.1.1
cuda-command-line-tools 12.1.1 0 nvidia/label/cuda-12.1.1
cuda-compiler 12.1.1 0 nvidia/label/cuda-12.1.1
cuda-cudart 12.1.105 0 nvidia
cuda-cudart-dev 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-cudart-static 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-cuobjdump 12.1.111 0 nvidia/label/cuda-12.1.1
cuda-cupti 12.1.105 0 nvidia
cuda-cupti-static 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-cuxxfilt 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-demo-suite 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-documentation 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-driver-dev 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-gdb 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-libraries 12.1.0 0 nvidia
cuda-libraries-dev 12.1.0 0 nvidia
cuda-libraries-static 12.1.1 0 nvidia/label/cuda-12.1.1
cuda-nsight 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nsight-compute 12.1.1 0 nvidia/label/cuda-12.1.1
cuda-nvcc 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nvdisasm 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nvml-dev 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nvprof 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nvprune 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nvrtc 12.1.105 0 nvidia
cuda-nvrtc-dev 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nvrtc-static 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-nvtx 12.1.105 0 nvidia
cuda-nvvp 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-opencl 12.4.127 0 nvidia
cuda-opencl-dev 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-profiler-api 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-runtime 12.1.0 0 nvidia
cuda-sanitizer-api 12.1.105 0 nvidia/label/cuda-12.1.1
cuda-toolkit 12.1.0 0 nvidia
cuda-tools 12.1.0 0 nvidia
cuda-version 12.4 h3060b56_3 conda-forge
cuda-visual-tools 12.1.0 0 nvidia
extension-cpp 0.0.1 pypi_0 pypi
ffmpeg 4.3 hf484d3e_0 pytorch
filelock 3.14.0 pyhd8ed1ab_0 conda-forge
freetype 2.12.1 h267a509_2 conda-forge
fsspec 2024.3.1 pypi_0 pypi
gds-tools 1.6.1.9 0 nvidia/label/cuda-12.1.1
gmp 6.3.0 h59595ed_1 conda-forge
gnutls 3.6.13 h85f3911_1 conda-forge
icu 73.2 h59595ed_0 conda-forge
idna 3.7 pyhd8ed1ab_0 conda-forge
intel-openmp 2023.1.0 hdb19cb5_46306
jinja2 3.1.4 pyhd8ed1ab_0 conda-forge
jpeg 9e h166bdaf_2 conda-forge
lame 3.100 h166bdaf_1003 conda-forge
lcms2 2.15 hfd0df8a_0 conda-forge
ld_impl_linux-64 2.40 h55db66e_0 conda-forge
lerc 4.0.0 h27087fc_0 conda-forge
libblas 3.9.0 1_h86c2bf4_netlib conda-forge
libcblas 3.9.0 5_h92ddd45_netlib conda-forge
libcublas 12.1.0.26 0 nvidia
libcublas-dev 12.1.0.26 0 nvidia
libcublas-static 12.4.5.8 hd3aeb46_1 conda-forge
libcufft 11.0.2.4 0 nvidia
libcufft-dev 11.0.2.4 0 nvidia
libcufft-static 11.2.1.3 hd3aeb46_1 conda-forge
libcufile 1.9.1.3 0 nvidia
libcufile-dev 1.6.1.9 0 nvidia/label/cuda-12.1.1
libcufile-static 1.6.1.9 0 nvidia/label/cuda-12.1.1
libcurand 10.3.5.147 0 nvidia
libcurand-dev 10.3.2.106 0 nvidia/label/cuda-12.1.1
libcurand-static 10.3.2.106 0 nvidia/label/cuda-12.1.1
libcusolver 11.4.4.55 0 nvidia
libcusolver-dev 11.4.4.55 0 nvidia
libcusolver-static 11.6.1.9 hd3aeb46_1 conda-forge
libcusparse 12.0.2.55 0 nvidia
libcusparse-dev 12.0.2.55 0 nvidia
libcusparse-static 12.3.1.170 hd3aeb46_1 conda-forge
libdeflate 1.17 h0b41bf4_0 conda-forge
libexpat 2.6.2 h59595ed_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-ng 13.2.0 h77fa898_7 conda-forge
libgfortran-ng 13.2.0 h69a702a_7 conda-forge
libgfortran5 13.2.0 hca663fb_7 conda-forge
libgomp 13.2.0 h77fa898_7 conda-forge
libhwloc 2.10.0 default_h2fb2949_1000 conda-forge
libiconv 1.17 hd590300_2 conda-forge
libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
liblapack 3.9.0 5_h92ddd45_netlib conda-forge
libnpp 12.0.2.50 0 nvidia
libnpp-dev 12.0.2.50 0 nvidia
libnpp-static 12.2.5.30 hd3aeb46_1 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.1.105 0 nvidia
libnvjitlink-dev 12.1.105 0 nvidia/label/cuda-12.1.1
libnvjitlink-static 12.4.127 hd3aeb46_1 conda-forge
libnvjpeg 12.1.1.14 0 nvidia
libnvjpeg-dev 12.1.1.14 0 nvidia
libnvjpeg-static 12.3.1.117 ha770c72_1 conda-forge
libnvvm-samples 12.1.105 0 nvidia/label/cuda-12.1.1
libpng 1.6.43 h2797004_0 conda-forge
libsqlite 3.45.3 h2797004_0 conda-forge
libstdcxx-ng 13.2.0 hc0a3c3a_7 conda-forge
libtiff 4.5.0 h6adf6a1_2 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libwebp-base 1.4.0 hd590300_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.6 h232c23b_2 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
llvm-openmp 15.0.7 h0cdce71_0 conda-forge
markupsafe 2.1.5 py312h98912ed_0 conda-forge
mkl 2023.1.0 h213fc3f_46344
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
ncurses 6.5 h59595ed_0 conda-forge
nettle 3.6 he412f7d_0 conda-forge
networkx 3.3 pyhd8ed1ab_1 conda-forge
ninja 1.11.1.1 pypi_0 pypi
nsight-compute 2023.1.1.4 0 nvidia/label/cuda-12.1.1
numpy 1.26.4 py312heda63a1_0 conda-forge
openh264 2.1.1 h780b84a_0 conda-forge
openjpeg 2.5.0 hfec8fc6_2 conda-forge
openssl 3.3.0 hd590300_0 conda-forge
pillow 10.3.0 py312h5eee18b_0
pip 24.0 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 conda-forge
python 3.12.3 hab00c5b_0_cpython conda-forge
python_abi 3.12 4_cp312 conda-forge
pytorch 2.3.0 py3.12_cuda12.1_cudnn8.9.2_0 pytorch
pytorch-cuda 12.1 ha16c6d3_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pyyaml 6.0.1 py312h98912ed_1 conda-forge
readline 8.2 h8228510_1 conda-forge
requests 2.31.0 pyhd8ed1ab_0 conda-forge
setuptools 69.5.1 pyhd8ed1ab_0 conda-forge
sympy 1.12 pyh04b8f61_3 conda-forge
tbb 2021.12.0 h00ab1b0_0 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
torchaudio 2.3.0 py312_cu121 pytorch
torchvision 0.18.0 py312_cu121 pytorch
typing_extensions 4.11.0 pyha770c72_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
urllib3 2.2.1 pyhd8ed1ab_0 conda-forge
wheel 0.43.0 pyhd8ed1ab_1 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
zlib 1.2.13 hd590300_5 conda-forge
zstd 1.5.6 ha6fb4c9_0 conda-forge