RevCol
How to export an ONNX model with save_memory=True?
We are trying to convert RevCol to TensorRT format, but we found that the export to ONNX fails whenever save_memory=True.
Here is our conversion test code:
import torch
from models.revcol import *

model = revcol_tiny(save_memory=True, inter_supv=False, drop_path=0.1, num_classes=10, kernel_size=3)

# Workaround (discussed below): uncommenting these two lines disables the
# memory-saving path on every subnet and lets the export succeed.
# for i in range(model.num_subnet):
#     getattr(model, f'subnet{i}').save_memory = False

x = torch.zeros(1, 3, 224, 224)
torch.onnx.export(model, x, './weights/revcol_tiny.onnx', verbose=False, opset_version=17,
                  training=torch.onnx.TrainingMode.EVAL,
                  do_constant_folding=True,
                  input_names=['images'],
                  output_names=['output'],
                  dynamic_axes=None)
When save_memory=True, the following error occurs:
File "d:\SoftWare\anaconda3\envs\torch\lib\site-packages\torch\onnx\utils.py", line 506, in export
...
    '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    'staticmethod. For more details, please see '
    'https://pytorch.org/docs/master/notes/extending.func.html')
RuntimeError: invalid unordered_map<K, T> key
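
With save_memory=True, RevCol appears to run its forward pass through a custom torch.autograd.Function that recomputes activations for the reversible columns, and the tracing-based ONNX exporter cannot follow custom autograd Functions. Here is a minimal probe to confirm this in your environment (Double and Wrapper below are hypothetical, written only for diagnosis):

import torch

class Double(torch.autograd.Function):
    # Trivial custom autograd.Function, standing in for RevCol's
    # reversible forward/backward.
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, grad):
        return 2 * grad

class Wrapper(torch.nn.Module):
    def forward(self, x):
        return Double.apply(x)

# If this export fails with a similar error, the problem is the custom
# autograd.Function itself, not anything RevCol-specific; replacing
# Double.apply(x) with plain 2 * x exports cleanly.
torch.onnx.export(Wrapper(), torch.zeros(1, 3), 'probe.onnx', opset_version=17)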
If we add the following code, the export works, but then we cannot take advantage of the low memory footprint of the reversible network.

for i in range(model.num_subnet):
    getattr(model, f'subnet{i}').save_memory = False
Is there a recommended solution?
Do you use ONNX for inference? You can set save_memory=False when converting the weights, then set save_memory=True again for later (PyTorch) inference. The low memory footprint only benefits the training process.
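
A minimal sketch of that workflow (the checkpoint path is a placeholder, and onnxruntime is assumed only for the parity check):

import numpy as np
import torch
import onnxruntime as ort
from models.revcol import *

# Build the model with save_memory=False just for the export; the flag only
# controls whether activations are recomputed during the backward pass, so
# it changes neither the weights nor the forward outputs.
model = revcol_tiny(save_memory=False, inter_supv=False, drop_path=0.1, num_classes=10, kernel_size=3)
model.load_state_dict(torch.load('./weights/revcol_tiny.pth', map_location='cpu'))  # placeholder path
model.eval()

x = torch.zeros(1, 3, 224, 224)
torch.onnx.export(model, x, './weights/revcol_tiny.onnx', opset_version=17,
                  input_names=['images'], output_names=['output'])

# Sanity check: the ONNX graph should reproduce the PyTorch forward pass.
sess = ort.InferenceSession('./weights/revcol_tiny.onnx', providers=['CPUExecutionProvider'])
onnx_out = sess.run(['output'], {'images': x.numpy()})[0]
with torch.no_grad():
    torch_out = model(x)
np.testing.assert_allclose(onnx_out, torch_out.numpy(), rtol=1e-3, atol=1e-5)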