RetinaNet3D not serializing to torchscript

Open AceMcAwesome77 opened this issue 1 year ago • 0 comments

Hi, I have trained a RetinaNet3D model and am trying to serialize it to TorchScript so that I can use it with BentoML. However, I am running into several problems when serializing it.

The first problem is that RetinaNet doesn't appear to define a __getstate__ method, which models like this usually have. This causes BentoML's built-in serialization function to fail.

So instead, I am trying to serialize it using torch.jit.script. However, this fails with the following error:

Traceback (most recent call last):
  File "torchscript_serialize.py", line 54, in <module>
    traced_script_module = torch.jit.script(model, example.float().cpu())
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1338, in script
    ast = get_jit_def(obj, obj.__name__)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/frontend.py", line 297, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/frontend.py", line 335, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/lib/python3.8/dist-packages/monai/apps/detection/utils/detector_utils.py", line 164
    size_divisible: Union[int, Sequence[int]],
    mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    **kwargs,
     ~~~~~~~ <--- HERE
) -> Tuple[Tensor, List[List[int]]]:

It looks like torch.jit.script can't compile the **kwargs parameter of the pad_images() function. Can the MONAI RetinaNet3D model be serialized to TorchScript at all, or is it too complicated given its runtime steps such as padding and sliding-window inference?
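The restriction quoted in the error message can be checked ahead of time with a small standard-library helper. The sketch below is only an approximation of that rule, not TorchScript's own check; `unscriptable_params`, `pad_like`, and `pad_like_explicit` are hypothetical names made up for illustration, and `pad_like` merely imitates the shape of the signature shown in the traceback.

```python
import inspect
from typing import Callable, List


def unscriptable_params(fn: Callable) -> List[str]:
    """List parameters that match the restriction quoted in the error message:
    *args, **kwargs, or keyword-only arguments that carry a default."""
    flagged = []
    for name, param in inspect.signature(fn).parameters.items():
        if param.kind is inspect.Parameter.VAR_POSITIONAL:
            flagged.append("*" + name)
        elif param.kind is inspect.Parameter.VAR_KEYWORD:
            flagged.append("**" + name)
        elif (
            param.kind is inspect.Parameter.KEYWORD_ONLY
            and param.default is not inspect.Parameter.empty
        ):
            flagged.append(name + " (keyword-only with default)")
    return flagged


# Hypothetical stand-in with the same signature shape as the function in the traceback.
def pad_like(images, size_divisible, mode="constant", **kwargs):
    return images


# Hypothetical rewrite that names its options explicitly instead of using **kwargs.
def pad_like_explicit(images, size_divisible, mode="constant", value=0.0):
    return images


print(unscriptable_params(pad_like))           # ['**kwargs']
print(unscriptable_params(pad_like_explicit))  # []
```

Running such a helper over the functions a model calls would show how many signatures need explicit arguments before scripting could get past this particular error; it says nothing about other TorchScript restrictions.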

Here is the full script, "torchscript_serialize.py", that produces the above error. Note that I added custom __getstate__ functions to a few classes to get past the initial problems caused by the missing __getstate__. Here, "example_tensor.pt" is a dummy torch tensor with the input shape of the model.

import torch

from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
from monai.apps.detection.networks.retinanet_network import RetinaNet
from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape

anchor_generator = AnchorGeneratorWithAnchorShape(
    feature_map_scales=[2**l for l in range(len([1,2]) + 1)],
    base_anchor_shapes=[[6,8,4],[8,6,5],[10,10,6]]
)


_example_path = r'example_tensor.pt'
_inference_model_path = r'trained_retinanet_model.pt'
example = torch.load(_example_path, map_location='cpu')

net = torch.jit.load(_inference_model_path, map_location='cpu')

detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False)

def __getstate__(self):
    state = self.__dict__.copy()
    state.pop("_thread_local", None)
    state.pop("_metrics_lock", None)
    return state

RetinaNetDetector.__getstate__ = __getstate__
RetinaNet.__getstate__ = __getstate__
torch.jit._script.RecursiveScriptModule.__getstate__ = __getstate__

patch_size = (192,192,96)

detector.set_box_selector_parameters(
    score_thresh=0.02,
    topk_candidates_per_level=1000,
    nms_thresh=0.22,
    detections_per_img=1,
)
detector.set_sliding_window_inferer(
    roi_size=patch_size,
    overlap=0.25,
    sw_batch_size=1,
    mode="gaussian",
    device="cpu",
    # device=device,
)

detector.eval()

# Script the detector module itself rather than the output of a forward pass.
compiled_model = torch.jit.script(detector)
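The __getstate__ patch in the script drops two attributes before the state is copied. A minimal standard-library sketch of why that pattern helps with pickle-based serialization is shown here. It assumes the failure comes from an unpicklable attribute such as a lock, which the issue does not confirm; `Tracker` and `PicklableTracker` are hypothetical classes, not part of MONAI or BentoML.

```python
import pickle
import threading


class Tracker:
    """Hypothetical class that holds an unpicklable lock attribute."""

    def __init__(self):
        self.threshold = 0.02
        self._metrics_lock = threading.Lock()


class PicklableTracker(Tracker):
    def __getstate__(self):
        # Copy the instance dict and drop the attribute pickle cannot handle.
        state = self.__dict__.copy()
        state.pop("_metrics_lock", None)
        return state

    def __setstate__(self, state):
        # Restore the saved attributes and recreate the lock on load.
        self.__dict__.update(state)
        self._metrics_lock = threading.Lock()


try:
    pickle.dumps(Tracker())
    plain_ok = True
except TypeError:
    # Lock objects cannot be pickled, so the unpatched class fails here.
    plain_ok = False

restored = pickle.loads(pickle.dumps(PicklableTracker()))
print(plain_ok, restored.threshold)
```

Note that this only addresses pickling. torch.jit.script compiles the module's code and does not go through __getstate__, so this patch is independent of the NotSupportedError above.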

AceMcAwesome77, Apr 26 '23 16:04