RetinaNet3D not serializing to TorchScript
Hi, I have trained a RetinaNet3D model and am trying to serialize it to TorchScript so that I can use it with BentoML. However, I am running into a lot of problems serializing it.
The first problem is that RetinaNet doesn't appear to have the `__getstate__` method that models like this usually have. This causes BentoML's built-in serialization function to fail.
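For context on the `__getstate__` issue: `torch.save` is built on Python's pickle protocol, and BentoML's PyTorch saving path presumably is too (an assumption). Pickle cannot serialize runtime objects such as `threading.Lock` or `threading.local`. The torch-free sketch below shows the failure and the usual fix, assuming the problem comes from attributes like these. `PlainDetector` and `PicklableDetector` are hypothetical classes; only the attribute names `_metrics_lock` and `_thread_local` mirror the ones popped in the script.

```python
import pickle
import threading


class PlainDetector:
    """Hypothetical object holding a lock, with no custom __getstate__."""

    def __init__(self, score_thresh: float) -> None:
        self.score_thresh = score_thresh
        self._metrics_lock = threading.Lock()  # locks cannot be pickled


class PicklableDetector(PlainDetector):
    """Same object plus pickle hooks that drop and rebuild runtime-only state."""

    def __init__(self, score_thresh: float) -> None:
        super().__init__(score_thresh)
        self._thread_local = threading.local()  # thread-local storage cannot be pickled either

    def __getstate__(self):
        # Pickle a copy of the instance dict without the unpicklable attributes.
        state = self.__dict__.copy()
        state.pop("_thread_local", None)
        state.pop("_metrics_lock", None)
        return state

    def __setstate__(self, state):
        # Restore the saved attributes, then recreate the runtime-only ones.
        self.__dict__.update(state)
        self._metrics_lock = threading.Lock()
        self._thread_local = threading.local()


try:
    pickle.dumps(PlainDetector(0.02))
except TypeError as err:
    print("plain object fails:", err)

restored = pickle.loads(pickle.dumps(PicklableDetector(0.02)))
print(restored.score_thresh)  # 0.02
```

Adding `__setstate__` matters as well: without it, the restored object would be missing the lock and thread-local attributes that other methods may expect.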
So instead, I am trying to serialize it with `torch.jit.script`. However, this fails with the following error:
```
Traceback (most recent call last):
  File "torchscript_serialize.py", line 54, in <module>
    traced_script_module = torch.jit.script(model, example.float().cpu())
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_script.py", line 1338, in script
    ast = get_jit_def(obj, obj.__name__)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/frontend.py", line 297, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/frontend.py", line 335, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/lib/python3.8/dist-packages/monai/apps/detection/utils/detector_utils.py", line 164
    size_divisible: Union[int, Sequence[int]],
    mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    **kwargs,
     ~~~~~~~ <--- HERE
) -> Tuple[Tensor, List[List[int]]]:
```
It looks like TorchScript can't compile the `**kwargs` parameter of the `pad_images()` function. Can the MONAI RetinaNet3D detector be serialized to TorchScript at all, or is it too complicated given its runtime steps such as padding, sliding-window inference, etc.?
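The restriction named in the error message can be reproduced without MONAI. A minimal sketch, assuming PyTorch is installed and the code is run from a file (TorchScript needs the function source): `pad_with_kwargs` mimics only the trailing `**kwargs` of `pad_images()`, and `pad_to_divisible` is a hypothetical script-friendly variant that spells out every option as a typed argument. Neither is MONAI code, and `pad_to_divisible` only approximates the idea of padding spatial dimensions up to a multiple of `size_divisible`.

```python
from typing import List

import torch
import torch.nn.functional as F
from torch.jit.frontend import NotSupportedError


def pad_with_kwargs(x: torch.Tensor, size_divisible: int, **kwargs) -> torch.Tensor:
    # Hypothetical: only the trailing **kwargs matters here. TorchScript rejects
    # it while parsing the signature, before the body is ever compiled.
    return x


def pad_to_divisible(x: torch.Tensor, size_divisible: int, value: float = 0.0) -> torch.Tensor:
    # Hypothetical script-friendly variant: every option is an explicit, typed
    # argument. Pads the spatial dims of an (N, C, *spatial) tensor at the end
    # so that each one becomes a multiple of size_divisible.
    pad: List[int] = []
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    for dim in range(x.dim() - 1, 1, -1):
        size = x.shape[dim]
        target = (size + size_divisible - 1) // size_divisible * size_divisible
        pad.append(0)
        pad.append(target - size)
    return F.pad(x, pad, mode="constant", value=value)


try:
    torch.jit.script(pad_with_kwargs)
except NotSupportedError as err:
    print("rejected:", type(err).__name__)

scripted_pad = torch.jit.script(pad_to_divisible)
print(scripted_pad(torch.zeros(1, 1, 5, 7, 9), 4).shape)  # torch.Size([1, 1, 8, 8, 12])
```

The only point of the second function is that a fully explicit signature compiles. It is not a claim about how MONAI does or should restructure `pad_images()`.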
Here is the full script, `torchscript_serialize.py`, that produces the error above. Note that I added custom `__getstate__` functions to a few classes to get past the initial problems caused by the missing `__getstate__`. Here, `example_tensor.pt` is a dummy torch tensor I made with the model's input shape.
```python
import torch
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
from monai.apps.detection.networks.retinanet_network import RetinaNet
from monai.apps.detection.utils.anchor_utils import AnchorGeneratorWithAnchorShape

anchor_generator = AnchorGeneratorWithAnchorShape(
    feature_map_scales=[2**l for l in range(len([1, 2]) + 1)],
    base_anchor_shapes=[[6, 8, 4], [8, 6, 5], [10, 10, 6]],
)

_example_path = r'example_tensor.pt'
_inference_model_path = r'trained_retinanet_model.pt'

example = torch.load(_example_path, map_location='cpu')  # dummy input with the model's input shape
net = torch.jit.load(_inference_model_path, map_location='cpu')  # trained network, already TorchScript
detector = RetinaNetDetector(network=net, anchor_generator=anchor_generator, debug=False)


# Custom __getstate__ that drops attributes pickle cannot handle
def __getstate__(self):
    state = self.__dict__.copy()
    state.pop("_thread_local", None)
    state.pop("_metrics_lock", None)
    return state


RetinaNetDetector.__getstate__ = __getstate__
RetinaNet.__getstate__ = __getstate__
torch.jit._script.RecursiveScriptModule.__getstate__ = __getstate__

patch_size = (192, 192, 96)
detector.set_box_selector_parameters(
    score_thresh=0.02,
    topk_candidates_per_level=1000,
    nms_thresh=0.22,
    detections_per_img=1,
)
detector.set_sliding_window_inferer(
    roi_size=patch_size,
    overlap=0.25,
    sw_batch_size=1,
    mode="gaussian",
    device="cpu",
    # device=device,
)
detector.eval()

# Script the detector module itself; torch.jit.script(detector(...)) would
# try to compile the detector's output instead of the model
compiled_model = torch.jit.script(detector)
```
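If the detector, or a script-friendly part of it, does compile, the serialized artifact a serving tool would typically consume comes from `torch.jit.save` and is read back with `torch.jit.load`, the same call the script uses for `trained_retinanet_model.pt`. A minimal sketch of that round trip, assuming a small hypothetical `TinyNet` stands in for the compiled model; note that `torch.jit.script` receives the module itself, not the result of calling it.

```python
import io

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical stand-in for a model whose forward() is scriptable."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


model = TinyNet().eval()
scripted = torch.jit.script(model)  # compile the module, not its output

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # a file path such as "model.pt" works here too
buffer.seek(0)
restored = torch.jit.load(buffer, map_location="cpu")

x = torch.randn(1, 1, 8, 8, 4)
with torch.no_grad():
    # The reloaded module carries the same weights, so outputs should match.
    print(torch.allclose(scripted(x), restored(x)))
```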