Pytorch-Correlation-extension

Error when JIT save

Open qqpann opened this issue 4 years ago • 4 comments

When I try to save a traced model with `torch.jit.save`, it fails.

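Minimal sample code:

```python
import torch
import torch.nn as nn
from spatial_correlation_sampler import spatial_correlation_sample


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

    def forward(self, input1, input2):
        return spatial_correlation_sample(input1, input2, kernel_size=3, patch_size=1, stride=2, padding=0, dilation=2, dilation_patch=1)


# Example inputs (shapes chosen here for illustration, not from the original report)
input1 = torch.randn(1, 16, 32, 32)
input2 = torch.randn(1, 16, 32, 32)

net = Net()
trace_model = torch.jit.trace(net, (input1, input2))  # tracing itself succeeds

torch.jit.save(trace_model, "trace_model.pt")  # this raises the RuntimeError below
```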

The error is:

```text
RuntimeError:
Could not export Python function call 'SpatialCorrelationSamplerFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.7/dist-packages/spatial_correlation_sampler/spatial_correlation_sampler.py(42): spatial_correlation_sample
```

Any help or suggestions are appreciated!

<details>
<summary>Full error message</summary>

```text
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
 in ()
----> 1 torch.jit.save(trace_model, "trace_model.pt")

1 frames
/usr/local/lib/python3.7/dist-packages/torch/jit/_script.py in save(self, f, **kwargs)
    594         See :func:torch.jit.save <torch.jit.save> for details.
    595         """
--> 596         return self._c.save(str(f), **kwargs)
    597 
    598     def _save_for_lite_interpreter(self, *args, **kwargs):

RuntimeError: Could not export Python function call 'SpatialCorrelationSamplerFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/usr/local/lib/python3.7/dist-packages/spatial_correlation_sampler/spatial_correlation_sampler.py(42): spatial_correlation_sample
(18): forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(1039): _slow_forward
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py(1051): _call_impl
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(959): trace_module
/usr/local/lib/python3.7/dist-packages/torch/jit/_trace.py(744): trace
(2):
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py(2882): run_code
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py(2822): run_ast_nodes
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py(2718): run_cell
/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py(537): run_cell
/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py(208): do_execute
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py(399): execute_request
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py(233): dispatch_shell
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py(283): dispatcher
/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py(300): null_wrapper
/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py(434): _run_callback
/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py(480): _handle_recv
/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py(451): _handle_events
/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py(300): null_wrapper
/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py(122): _handle_events
/usr/lib/python3.7/asyncio/events.py(88): _run
/usr/lib/python3.7/asyncio/base_events.py(1786): _run_once
/usr/lib/python3.7/asyncio/base_events.py(541): run_forever
/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py(132): start
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py(499): start
/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py(845): launch_instance
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py(16):
/usr/lib/python3.7/runpy.py(85): _run_code
/usr/lib/python3.7/runpy.py(193): _run_module_as_main
```

</details>

qqpann • Jun 20 '21
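The error message points at the underlying cause: the traced graph still contains a call to a Python function. A minimal sketch that reproduces this without the extension, assuming a recent PyTorch release: `torch.jit.trace` records a Python `torch.autograd.Function` (here a hypothetical `ScaleFunction`, standing in for `SpatialCorrelationSamplerFunction`) as an opaque `prim::PythonOp` node. The traced module still runs, but `torch.jit.save` cannot serialize that node.

```python
import io

import torch
import torch.nn as nn


class ScaleFunction(torch.autograd.Function):
    # Hypothetical stand-in for SpatialCorrelationSamplerFunction: a plain
    # Python autograd.Function, which is the pattern the error complains about.

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class Net(nn.Module):
    def forward(self, x):
        return ScaleFunction.apply(x)


# Tracing succeeds: the Python function is recorded as an opaque prim::PythonOp node.
traced = torch.jit.trace(Net(), torch.randn(2, 3))
node_kinds = [node.kind() for node in traced.graph.nodes()]
print(node_kinds)

# Saving fails, because a prim::PythonOp node cannot be serialized.
try:
    torch.jit.save(traced, io.BytesIO())
    save_error = None
except RuntimeError as err:
    save_error = str(err)
print(save_error)
```

Since `torch.jit.save` delegates to the module's own `save` method, calling `traced_model.save()` directly goes through the same serialization step and fails the same way.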

An issue on another model's repository (AnyNet) says that saving and loading without JIT helped, but that didn't work for me: https://github.com/mileyan/AnyNet/issues/34

qqpann • Jun 20 '21

I got the same error when using `traced_model.save()`:

```text
RuntimeError:
Could not export Python function call 'SpatialCorrelationSamplerFunction'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
/opt/conda/lib/python3.7/site-packages/spatial_correlation_sampler/spatial_correlation_sampler.py(42): spatial_correlation_sample
/data/repositories/PWCNet/PyTorch/PWC_net.py(44): Correlation
/data/repositories/PWCNet/PyTorch/PWC_net.py(212): forward
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py(516): _slow_forward
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py(530): __call__
/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py(1034): trace_module
/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py(882): trace
/tmp/ipykernel_2475/1985920461.py(1): <module>
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3441): run_code
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3361): run_ast_nodes
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(3170): run_cell_async
/opt/conda/lib/python3.7/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2944): _run_cell
/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py(2899): run_cell
/opt/conda/lib/python3.7/site-packages/ipykernel/zmqshell.py(532): run_cell
/opt/conda/lib/python3.7/site-packages/ipykernel/ipkernel.py(335): do_execute
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(647): execute_request
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(352): dispatch_shell
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(445): process_one
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py(456): dispatch_queue
/opt/conda/lib/python3.7/asyncio/events.py(88): _run
/opt/conda/lib/python3.7/asyncio/base_events.py(1771): _run_once
/opt/conda/lib/python3.7/asyncio/base_events.py(534): run_forever
/opt/conda/lib/python3.7/site-packages/tornado/platform/asyncio.py(199): start
/opt/conda/lib/python3.7/site-packages/ipykernel/kernelapp.py(668): start
/opt/conda/lib/python3.7/site-packages/traitlets/config/application.py(664): launch_instance
/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py(16): <module>
/opt/conda/lib/python3.7/runpy.py(85): _run_code
/opt/conda/lib/python3.7/runpy.py(193): _run_module_as_main
```

xing-w • Jul 14 '21

Any solutions?

DaskiSnow • Mar 12 '24

Hi, I believe that extensions built with the `CUDAExtension` class API are deprecated and not compatible with the JIT saver. I still maintain this package so that it works with the latest PyTorch version, but it is not meant to be used in modern architectures.

If you want to use this CUDA code in a JIT-compiled model, you might want to use a more modern alternative, such as a TorchScript custom operator: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html

ClementPinard • Mar 12 '24
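The custom-operator route suggested above can be sketched as follows. This is an illustration under stated assumptions, not this package's code: `corr_demo::scale` is a hypothetical operator, and instead of compiling C++ it registers a Python CPU kernel through `torch.library.Library` so the example runs as-is. Because the traced graph then calls a registered operator rather than a Python `autograd.Function`, `torch.jit.save` can serialize it.

```python
import io

import torch
import torch.library
import torch.nn as nn

# Hypothetical operator namespace/name, for illustration only. A real port would
# register the compiled CUDA kernel from C++ (TORCH_LIBRARY) instead.
_lib = torch.library.Library("corr_demo", "DEF")
_lib.define("scale(Tensor x) -> Tensor")


def _scale_cpu(x):
    # Python stand-in for a native CPU kernel, so the sketch runs without compiling.
    return x * 2


_lib.impl("scale", _scale_cpu, "CPU")


class Net(nn.Module):
    def forward(self, x):
        # Calls a registered operator rather than a Python autograd.Function.
        return torch.ops.corr_demo.scale(x)


traced = torch.jit.trace(Net(), torch.randn(2, 3))
print(traced.graph)  # contains a corr_demo::scale node instead of prim::PythonOp

# Saving now succeeds; loading works here because the op is registered in this process.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(2, 3)
out = loaded(x)
```

In a real port, the kernel would be registered from C++ as in the linked tutorial, and any process that calls `torch.jit.load` on the saved file must register the operator first (for example by loading the compiled library with `torch.ops.load_library`).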