CompressAI
CompressAI copied to clipboard
Request ONNX export support for Cheng2020 model
Feature
Support exporting Cheng2020 model to onnx format.
Motivation
To deploy the model on various hardwares.
Additional context
This is my convertion code:
import torch
from compressai.zoo import models
# net = models["bmshj2018-factorized"](quality=1, metric="mse", pretrained=True)
# net = cheng2020_anchor(quality=5, pretrained=True).to(device)
net = models["cheng2020-anchor"](quality=1, metric="mse", pretrained=True)
# Some dummy input
x = torch.randn(1, 3, 224, 224, requires_grad=True)
# Export the model
torch.onnx.export(net, # model being run
x, # model input (or a tuple for multiple inputs)
"cheng2020.onnx", # where to save the model (can be a file or file-like object)
export_params=True, # store the trained parameter weights inside the model file
opset_version=11, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names = ['input'], # the model's input names
output_names = ['output'], # the model's output names
dynamic_axes={'input': {0 : 'batch_size'}, # variable length axes
'output': {0 : 'batch_size'}}
)
Error occurs:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
[<ipython-input-48-bddb317a9b45>](https://localhost:8080/#) in <cell line: 12>()
10
11 # Export the model
---> 12 torch.onnx.export(net, # model being run
13 x, # model input (or a tuple for multiple inputs)
14 "cheng2020.onnx", # where to save the model (can be a file or file-like object)
15 frames
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions, autograd_inlining)
514 """
515
--> 516 _export(
517 model,
518 args,
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, onnx_shape_inference, export_modules_as_functions, autograd_inlining)
1610 _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
1611
-> 1612 graph, params_dict, torch_out = _model_to_graph(
1613 model,
1614 args,
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes)
1132
1133 model = _pre_trace_quant_model(model, args)
-> 1134 graph, params, torch_out, module = _create_jit_graph(model, args)
1135 params_dict = _get_named_param_dict(graph, params)
1136
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _create_jit_graph(model, args)
1008 return graph, params, torch_out, None
1009
-> 1010 graph, torch_out = _trace_and_get_graph_from_model(model, args)
1011 _C._jit_pass_onnx_lint(graph)
1012 state_dict = torch.jit._unique_state_dict(model)
[/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py](https://localhost:8080/#) in _trace_and_get_graph_from_model(model, args)
912 prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
913 torch.set_autocast_cache_enabled(False)
--> 914 trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
915 model,
916 args,
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
449 prior = set_eval_frame(callback)
450 try:
--> 451 return fn(*args, **kwargs)
452 finally:
453 set_eval_frame(prior)
[/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py](https://localhost:8080/#) in inner(*args, **kwargs)
34 @functools.wraps(fn)
35 def inner(*args, **kwargs):
---> 36 return fn(*args, **kwargs)
37
38 return inner
[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states)
1308 if not isinstance(args, tuple):
1309 args = (args,)
-> 1310 outs = ONNXTracedModule(
1311 f, strict, _force_outplace, return_inputs, _return_inputs_states
1312 )(*args, **kwargs)
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in forward(self, *args)
136 return tuple(out_vars)
137
--> 138 graph, out = torch._C._create_graph_by_tracing(
139 wrapper,
140 in_vars + module_state,
[/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py](https://localhost:8080/#) in wrapper(*args)
127 if self._return_inputs_states:
128 inputs_states.append(_unflatten(in_args, in_desc))
--> 129 outs.append(self.inner(*trace_inputs))
130 if self._return_inputs_states:
131 inputs_states[0] = (inputs_states[0], trace_inputs)
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _slow_forward(self, *input, **kwargs)
1520 recording_scopes = False
1521 try:
-> 1522 result = self.forward(*input, **kwargs)
1523 finally:
1524 if recording_scopes:
[/usr/local/lib/python3.10/dist-packages/compressai/models/google.py](https://localhost:8080/#) in forward(self, x)
543 ctx_params = self.context_prediction(y_hat)
544 gaussian_params = self.entropy_parameters(
--> 545 torch.cat((params, ctx_params), dim=1)
546 )
547 scales_hat, means_hat = gaussian_params.chunk(2, 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 16 but got size 14 for tensor number 1 in the list.
Related ISSUE #87.
Hello! I met the same issue and it seems that the reason is due to the ONNX lib rather than the compressai lib. I am wondering if you fixed this issue?