Pytorch-Correlation-extension
Error when saving a JIT-traced model
When I try to save a traced model with TorchScript (JIT), saving fails.
Minimal code to reproduce:
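```python
import torch
import torch.nn as nn
from spatial_correlation_sampler import spatial_correlation_sample

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    def forward(self, input1, input2):
        return spatial_correlation_sample(input1, input2, kernel_size=3, patch_size=1, stride=2, padding=0, dilation=2, dilation_patch=1)

input1 = torch.randn(1, 16, 32, 32)  # illustrative input shapes
input2 = torch.randn(1, 16, 32, 32)
net = Net()
trace_model = torch.jit.trace(net, (input1, input2))
torch.jit.save(trace_model, "trace_model.pt")  # fails with the RuntimeError below
```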
The error is:
```
RuntimeError:
Could not export Python function call 'SpatialCorrelationSamplerFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.7/dist-packages/spatial_correlation_sampler/spatial_correlation_sampler.py(42): spatial_correlation_sample
```
Any help or suggestions are appreciated!
Full error message:
```
/usr/local/lib/python3.7/dist-packages/torch/jit/_script.py in save(self, f, **kwargs)
594 See :func:`torch.jit.save <torch.jit.save>` for details.
595 """
--> 596 return self._c.save(str(f), **kwargs)
597
598 def _save_for_lite_interpreter(self, *args, **kwargs):
RuntimeError:
Could not export Python function call 'SpatialCorrelationSamplerFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.7/dist-packages/spatial_correlation_sampler/spatial_correlation_sampler.py(42): spatial_correlation_sample
```
An issue in another project (AnyNet) reports that saving and loading without JIT helped, but that did not work for me: https://github.com/mileyan/AnyNet/issues/34
I got the same error when using `traced_model.save()`:
```
RuntimeError:
Could not export Python function call 'SpatialCorrelationSamplerFunction'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/opt/conda/lib/python3.7/site-packages/spatial_correlation_sampler/spatial_correlation_sampler.py(42): spatial_correlation_sample
/data/repositories/PWCNet/PyTorch/PWC_net.py(44): Correlation
/data/repositories/PWCNet/PyTorch/PWC_net.py(212): forward
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py(516): _slow_forward
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py(530): __call__
/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py(1034): trace_module
/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py(882): trace
/tmp/ipykernel_2475/1985920461.py(1): <module>
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3441): run_code
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3361): run_ast_nodes
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3170): run_cell_async
/opt/conda/lib/python3.7/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2944): _run_cell
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2899): run_cell
/opt/conda/lib/python3.7/site-packages/ipykernel/zmqshell.py(532): run_cell
/opt/conda/lib/python3.7/site-packages/ipykernel/ipkernel.py(335): do_execute
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(647): execute_request
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(352): dispatch_shell
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(445): process_one
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(456): dispatch_queue
/opt/conda/lib/python3.7/asyncio/events.py(88): _run
/opt/conda/lib/python3.7/asyncio/base_events.py(1771): _run_once
/opt/conda/lib/python3.7/asyncio/base_events.py(534): run_forever
/opt/conda/lib/python3.7/site-packages/tornado/platform/asyncio.py(199): start
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelapp.py(668): start
/opt/conda/lib/python3.7/site-packages/traitlets/config/application.py(664): launch_instance
/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py(16): <module>
/opt/conda/lib/python3.7/runpy.py(85): _run_code
/opt/conda/lib/python3.7/runpy.py(193): _run_module_as_main
```
Any solutions?
Hi, I believe that extensions built with the `CUDAExtension` API in this way are deprecated and not compatible with the JIT saver. I still maintain this package so that it works with the latest PyTorch version, but it is not meant to be used on modern architectures.
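The same failure can be reproduced without this package. The traceback points at `SpatialCorrelationSamplerFunction`, a Python-level wrapper around the compiled kernel. The minimal sketch below assumes that wrapper behaves like any custom `torch.autograd.Function`: `torch.jit.trace` records such a call as an opaque Python function call, and `torch.jit.save` then refuses to serialize it. `Square` and `Net` are hypothetical stand-ins, not code from this package.

```python
import io

import torch
import torch.nn as nn


class Square(torch.autograd.Function):
    """Hypothetical custom autograd Function standing in for SpatialCorrelationSamplerFunction."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


class Net(nn.Module):
    def forward(self, x):
        return Square.apply(x)


# Tracing succeeds, but the graph contains a call to the Python Function `Square`.
traced = torch.jit.trace(Net(), torch.randn(3))

error = None
try:
    # Serialization is where the Python call cannot be exported.
    torch.jit.save(traced, io.BytesIO())
except RuntimeError as exc:
    error = str(exc)

print(error)
```

Because tracing itself succeeds, the problem only surfaces when the module is saved, which matches the reports above.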
If you want to use this CUDA code in a JIT-compiled model, you might want to use a more modern alternative, such as TorchScript custom operators: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html
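If building a custom operator is not an option, another route (not suggested in this thread and not part of this package) is to compute the cost volume with standard tensor ops only, so that tracing records nothing but serializable aten ops. The hypothetical `Correlation` module below is a sketch that assumes `kernel_size=1`, `stride=1`, `padding=0` and an odd `patch_size`, the setting commonly used for optical-flow cost volumes. It does not cover the `kernel_size=3, stride=2, dilation=2` call from the original report. Its output is laid out as `(B, patch_size, patch_size, H, W)`, which is assumed, not verified, to match `spatial_correlation_sample` numerically.

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F


class Correlation(nn.Module):
    """Hypothetical pure-PyTorch correlation, not part of spatial_correlation_sampler.

    Assumes kernel_size=1, stride=1, padding=0 and an odd patch_size.
    Output shape: (B, patch_size, patch_size, H, W).
    """

    def __init__(self, patch_size: int = 9, dilation_patch: int = 1):
        super().__init__()
        self.patch_size = patch_size
        self.dilation_patch = dilation_patch

    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        b, _, h, w = input1.shape
        # Largest displacement in pixels; zero-pad input2 by that amount on every side.
        r = (self.patch_size // 2) * self.dilation_patch
        padded = F.pad(input2, [r, r, r, r])
        out = []
        for i in range(self.patch_size):
            for j in range(self.patch_size):
                dy = i * self.dilation_patch
                dx = j * self.dilation_patch
                # input2 shifted by ((i - patch_size // 2) * dilation_patch,
                #                    (j - patch_size // 2) * dilation_patch)
                shifted = padded[:, :, dy:dy + h, dx:dx + w]
                # Dot product over channels at every pixel.
                out.append((input1 * shifted).sum(dim=1))
        return torch.stack(out, dim=1).reshape(b, self.patch_size, self.patch_size, h, w)


x1 = torch.randn(2, 16, 8, 8)
x2 = torch.randn(2, 16, 8, 8)
corr = Correlation(patch_size=3)

# Only standard tensor ops are recorded, so the traced module can be saved and reloaded.
traced = torch.jit.trace(corr, (x1, x2))
buf = io.BytesIO()
torch.jit.save(traced, buf)
buf.seek(0)
loaded = torch.jit.load(buf)
```

Tracing unrolls the Python double loop into `patch_size**2` slice-and-sum operations, so graph size and memory grow with the patch size. A dedicated kernel registered as a custom operator will generally be faster.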