
How to export an ONNX model with save_memory=True?

Open BearCooike opened this issue 2 years ago • 1 comment

We are trying to convert RevCol to TensorRT format, but the intermediate ONNX export fails when save_memory=True. Here is our conversion test code:

import torch
from models.revcol import revcol_tiny

model = revcol_tiny(save_memory=True, inter_supv=False, drop_path=0.1, num_classes=10, kernel_size=3)

# Workaround: turn off save_memory on every subnet.
# With this loop the export succeeds; without it, the export fails with the error below.
for i in range(model.num_subnet):
    getattr(model, f'subnet{i}').save_memory = False

# Dummy input used for the export.
x = torch.zeros(1, 3, 224, 224)
torch.onnx.export(model, x, './weights/revcol_tiny.onnx', verbose=False, opset_version=17,
                  training=torch.onnx.TrainingMode.EVAL,
                  do_constant_folding=True,
                  input_names=['images'],
                  output_names=['output'],
                  dynamic_axes=None)

With save_memory=True left enabled (that is, without the loop above), the export fails with the following error:

File d:\SoftWare\anaconda3\envs\torch\lib\site-packages\torch\onnx\utils.py:506, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
    188 @_beartype.beartype
    189 def export(
    190     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    206     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
    207 ) -> None:
    208     r"""Exports a model into ONNX format.
    209 
    210     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    503             All errors are subclasses of :class:`errors.OnnxExporterError`.
...
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

RuntimeError: invalid unordered_map<K, T> key

Adding the following code makes the export work, but then the model presumably cannot take advantage of the low memory footprint of the Reversible Net.

for i in range(model.num_subnet):
    getattr(model, f'subnet{i}').save_memory = False

Is there a way to export the model with save_memory=True still enabled?

BearCooike • Dec 20 '23 07:12

Do you use ONNX for inference? You can set save_memory = False when converting the weights, then set save_memory = True again for later inference. The low memory footprint only benefits the training process.

nightsnack • Dec 21 '23 03:12
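
A minimal sketch of the suggestion above: switch save_memory off only for the duration of the export, then restore the previous flags so the PyTorch model behaves as before. It assumes the model exposes num_subnet and attributes named subnet0, subnet1, ... each with a save_memory flag, as in the snippet from the question. The helper names save_memory_disabled and export_onnx are hypothetical and not part of RevCol.

```python
from contextlib import contextmanager


@contextmanager
def save_memory_disabled(model):
    """Temporarily set save_memory = False on every subnet.

    Assumption: the model has `num_subnet` and attributes subnet0, subnet1, ...
    each carrying a `save_memory` flag, as in the question's snippet.
    """
    subnets = [getattr(model, f"subnet{i}") for i in range(model.num_subnet)]
    previous = [subnet.save_memory for subnet in subnets]
    for subnet in subnets:
        subnet.save_memory = False
    try:
        yield model
    finally:
        # Restore the original flags, even if the export raised.
        for subnet, flag in zip(subnets, previous):
            subnet.save_memory = flag


def export_onnx(model, example_input, path):
    """Hypothetical wrapper: export with save_memory off, then restore it."""
    import torch  # imported lazily so the helper above has no hard dependency

    with save_memory_disabled(model):
        torch.onnx.export(
            model,
            example_input,
            path,
            opset_version=17,
            input_names=["images"],
            output_names=["output"],
        )
```

Usage would look like export_onnx(model, torch.zeros(1, 3, 224, 224), './weights/revcol_tiny.onnx'). After the call, every subnet has its original save_memory value again.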