extension-cpp

`TORCH_LIBRARY` and `m.def` Not Working as Documented

Open • andylizf opened this issue 9 months ago • 1 comment

I encountered an issue where registering operators with `TORCH_LIBRARY` alone, without the dispatcher API, does not work as expected. According to the PyTorch documentation, the `TORCH_LIBRARY` macro defines a function that registers custom operators. However, when I follow this approach, I get the following error at runtime:

$ python test/benchmark.py cuda
Traceback (most recent call last):
  File "/home/lizhifei/extension-cpp/test/benchmark.py", line 48, in <module>
    new_h, new_C = LLTM(X, W, b, h, C)
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 11, in lltm
    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 17, in forward
    outputs = torch.ops.extension_cpp.lltm_forward.default(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/_ops.py", line 921, in __getattr__
    raise AttributeError(
AttributeError: '_OpNamespace' object has no attribute 'lltm_forward'
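
As far as I understand, `torch.ops.extension_cpp.<name>` resolves operators against PyTorch's operator registry, and the registration code generated by `TORCH_LIBRARY` runs only when the compiled shared library containing it is loaded into the process. The toy Python model below is purely illustrative: `ToyOpNamespace`, `REGISTRY`, and `load_extension_library` are hypothetical names, and this is not PyTorch's implementation. It shows how the same `AttributeError` arises when that load step never happens:

```python
class ToyOpNamespace:
    """Hypothetical stand-in for torch.ops.<namespace> (not PyTorch's code)."""

    def __init__(self, name, registry):
        self.name = name
        self._registry = registry

    def __getattr__(self, op_name):
        # Only called when normal attribute lookup fails, so `name` and
        # `_registry` (set in __init__) never reach this method.
        key = f"{self.name}::{op_name}"
        if key not in self._registry:
            # Message mimics the one seen in the traceback above.
            raise AttributeError(
                f"'_OpNamespace' object has no attribute '{op_name}'"
            )
        return self._registry[key]


# Hypothetical global table, standing in for the operator registry.
REGISTRY = {}


def load_extension_library():
    """Hypothetical: models the TORCH_LIBRARY block running at library load."""
    REGISTRY["extension_cpp::lltm_forward"] = lambda *args: "forward called"


ops = ToyOpNamespace("extension_cpp", REGISTRY)

try:
    ops.lltm_forward  # library never loaded, so nothing was registered
except AttributeError as err:
    print(err)  # '_OpNamespace' object has no attribute 'lltm_forward'

load_extension_library()
print(ops.lltm_forward())  # forward called
```

In this model the lookup itself is fine; what is missing is the registration side effect that only loading the library would trigger.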

Here is a link to my modified repository where this issue can be reproduced: andylizf/extension-cpp.

Could you please help me understand why this is happening and how to resolve it? Thank you.

Environment Information
  • OS: Windows 11 23H2 22631.3527
  • PyTorch version: 2.3.0
  • How you installed PyTorch: conda
  • Python version: 3.12.3
  • CUDA/cuDNN version: CUDA 12.1, cuDNN 8.9.2
  • GPU models and configuration: NVIDIA GeForce RTX 3090
  • Conda Env:
# packages in environment at /home/lizhifei/miniconda3/envs/extension-cpp:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
blas                      1.0                         mkl    conda-forge
brotli-python             1.1.0           py312h30efb56_1    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
cuda                      12.1.0                        0    nvidia
cuda-cccl                 12.1.109                      0    nvidia/label/cuda-12.1.1
cuda-command-line-tools   12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-compiler             12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-cudart               12.1.105                      0    nvidia
cuda-cudart-dev           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-cudart-static        12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-cuobjdump            12.1.111                      0    nvidia/label/cuda-12.1.1
cuda-cupti                12.1.105                      0    nvidia
cuda-cupti-static         12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-cuxxfilt             12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-demo-suite           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-documentation        12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-driver-dev           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-gdb                  12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-libraries            12.1.0                        0    nvidia
cuda-libraries-dev        12.1.0                        0    nvidia
cuda-libraries-static     12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-nsight               12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nsight-compute       12.1.1                        0    nvidia/label/cuda-12.1.1
cuda-nvcc                 12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvdisasm             12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvml-dev             12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvprof               12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvprune              12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvrtc-dev            12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvrtc-static         12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-nvtx                 12.1.105                      0    nvidia
cuda-nvvp                 12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-opencl               12.4.127                      0    nvidia
cuda-opencl-dev           12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-profiler-api         12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-runtime              12.1.0                        0    nvidia
cuda-sanitizer-api        12.1.105                      0    nvidia/label/cuda-12.1.1
cuda-toolkit              12.1.0                        0    nvidia
cuda-tools                12.1.0                        0    nvidia
cuda-version              12.4                 h3060b56_3    conda-forge
cuda-visual-tools         12.1.0                        0    nvidia
extension-cpp             0.0.1                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
filelock                  3.14.0             pyhd8ed1ab_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fsspec                    2024.3.1                 pypi_0    pypi
gds-tools                 1.6.1.9                       0    nvidia/label/cuda-12.1.1
gmp                       6.3.0                h59595ed_1    conda-forge
gnutls                    3.6.13               h85f3911_1    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
idna                      3.7                pyhd8ed1ab_0    conda-forge
intel-openmp              2023.1.0         hdb19cb5_46306  
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
jpeg                      9e                   h166bdaf_2    conda-forge
lame                      3.100             h166bdaf_1003    conda-forge
lcms2                     2.15                 hfd0df8a_0    conda-forge
ld_impl_linux-64          2.40                 h55db66e_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libblas                   3.9.0           1_h86c2bf4_netlib    conda-forge
libcblas                  3.9.0           5_h92ddd45_netlib    conda-forge
libcublas                 12.1.0.26                     0    nvidia
libcublas-dev             12.1.0.26                     0    nvidia
libcublas-static          12.4.5.8             hd3aeb46_1    conda-forge
libcufft                  11.0.2.4                      0    nvidia
libcufft-dev              11.0.2.4                      0    nvidia
libcufft-static           11.2.1.3             hd3aeb46_1    conda-forge
libcufile                 1.9.1.3                       0    nvidia
libcufile-dev             1.6.1.9                       0    nvidia/label/cuda-12.1.1
libcufile-static          1.6.1.9                       0    nvidia/label/cuda-12.1.1
libcurand                 10.3.5.147                    0    nvidia
libcurand-dev             10.3.2.106                    0    nvidia/label/cuda-12.1.1
libcurand-static          10.3.2.106                    0    nvidia/label/cuda-12.1.1
libcusolver               11.4.4.55                     0    nvidia
libcusolver-dev           11.4.4.55                     0    nvidia
libcusolver-static        11.6.1.9             hd3aeb46_1    conda-forge
libcusparse               12.0.2.55                     0    nvidia
libcusparse-dev           12.0.2.55                     0    nvidia
libcusparse-static        12.3.1.170           hd3aeb46_1    conda-forge
libdeflate                1.17                 h0b41bf4_0    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.2.0               h77fa898_7    conda-forge
libgfortran-ng            13.2.0               h69a702a_7    conda-forge
libgfortran5              13.2.0               hca663fb_7    conda-forge
libgomp                   13.2.0               h77fa898_7    conda-forge
libhwloc                  2.10.0          default_h2fb2949_1000    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
liblapack                 3.9.0           5_h92ddd45_netlib    conda-forge
libnpp                    12.0.2.50                     0    nvidia
libnpp-dev                12.0.2.50                     0    nvidia
libnpp-static             12.2.5.30            hd3aeb46_1    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.1.105                      0    nvidia
libnvjitlink-dev          12.1.105                      0    nvidia/label/cuda-12.1.1
libnvjitlink-static       12.4.127             hd3aeb46_1    conda-forge
libnvjpeg                 12.1.1.14                     0    nvidia
libnvjpeg-dev             12.1.1.14                     0    nvidia
libnvjpeg-static          12.3.1.117           ha770c72_1    conda-forge
libnvvm-samples           12.1.105                      0    nvidia/label/cuda-12.1.1
libpng                    1.6.43               h2797004_0    conda-forge
libsqlite                 3.45.3               h2797004_0    conda-forge
libstdcxx-ng              13.2.0               hc0a3c3a_7    conda-forge
libtiff                   4.5.0                h6adf6a1_2    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp-base              1.4.0                hd590300_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.6               h232c23b_2    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               15.0.7               h0cdce71_0    conda-forge
markupsafe                2.1.5           py312h98912ed_0    conda-forge
mkl                       2023.1.0         h213fc3f_46344  
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
ncurses                   6.5                  h59595ed_0    conda-forge
nettle                    3.6                  he412f7d_0    conda-forge
networkx                  3.3                pyhd8ed1ab_1    conda-forge
ninja                     1.11.1.1                 pypi_0    pypi
nsight-compute            2023.1.1.4                    0    nvidia/label/cuda-12.1.1
numpy                     1.26.4          py312heda63a1_0    conda-forge
openh264                  2.1.1                h780b84a_0    conda-forge
openjpeg                  2.5.0                hfec8fc6_2    conda-forge
openssl                   3.3.0                hd590300_0    conda-forge
pillow                    10.3.0          py312h5eee18b_0  
pip                       24.0               pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.12.3          hab00c5b_0_cpython    conda-forge
python_abi                3.12                    4_cp312    conda-forge
pytorch                   2.3.0           py3.12_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pyyaml                    6.0.1           py312h98912ed_1    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
setuptools                69.5.1             pyhd8ed1ab_0    conda-forge
sympy                     1.12               pyh04b8f61_3    conda-forge
tbb                       2021.12.0            h00ab1b0_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
torchaudio                2.3.0               py312_cu121    pytorch
torchvision               0.18.0              py312_cu121    pytorch
typing_extensions         4.11.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
urllib3                   2.2.1              pyhd8ed1ab_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_1    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge

andylizf • May 14 '24 03:05