
Cannot Infer Shapes from pretrained Models.

Open · mycpuorg opened this issue 2 years ago · 0 comments

Hi, I have a relatively straightforward situation: I need to validate my input shapes when torch.jit.load-ing a saved model, and I can't seem to find a solution for it.

Save a traced model and verify the traced input shapes:
import torch
import torchvision

test_model = torchvision.models.resnet.resnet18(pretrained=True).eval()
with torch.no_grad():
    inp_224 = torch.rand(1, 3, 224, 224, dtype=torch.float)
    script_module_224 = torch.jit.trace(test_model, inp_224)
    graph_inputs = list(script_module_224.graph.inputs())
    graph_inputs = graph_inputs[1:]  # ignore self
    print(graph_inputs[0].type().sizes())  # [1, 3, 224, 224]
    script_module_224.save("saved_model_224.pt")
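As an aside, the only workaround I can think of is to stash the traced shape next to the model myself at save time, so it can be checked again after loading. Below is a rough, untested sketch using the _extra_files argument of torch.jit.save; the file name input_shape.json is arbitrary.

# Rough sketch (untested): store the traced input shape as a sidecar file
# inside the saved archive via torch.jit.save's _extra_files argument.
import json

extra = {"input_shape.json": json.dumps(list(inp_224.shape))}
torch.jit.save(script_module_224, "saved_model_224_with_shape.pt", _extra_files=extra)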

After loading the same model, the traced input shapes cannot be validated:

# ... later elsewhere load my saved model
loaded_224 = torch.jit.load("saved_model_224.pt")
# there's nothing preventing me from sending incorrect input shapes
# i.e., traced with 224 but called with 500
inp_500 = torch.rand(1, 3, 500, 500, dtype=torch.float)
loaded_224(inp_500)

I'd like to prevent feeding incorrectly shaped inputs to a model after loading it. In particular, I would like to be able to do something like:

traced_input_shape = loaded_224.get_input_shape()  # desired (currently nonexistent) API
if inp_500.shape != traced_input_shape:
    print("Error: Trying to run inference with incorrectly shaped inputs!")
    # die
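The closest manual approximation I can come up with (rough, untested sketch below, continuing the _extra_files idea from the save step above) is to ask torch.jit.load to hand the sidecar file back and compare shapes by hand before calling the model; a real get_input_shape() on the loaded module would be much nicer.

# Rough sketch (untested), continuing the save-time workaround above: recover
# the sidecar file while loading and compare shapes manually.
import json

extra = {"input_shape.json": ""}
loaded_224 = torch.jit.load("saved_model_224_with_shape.pt", _extra_files=extra)
traced_input_shape = json.loads(extra["input_shape.json"])  # e.g. [1, 3, 224, 224]

if list(inp_500.shape) != traced_input_shape:
    raise ValueError("Error: Trying to run inference with incorrectly shaped inputs!")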

I tried using torchlayers to help with this situation by:

import torchlayers as tl
tl.build(loaded_224, inp_224)

This failed (reasonably) with:

PickleError: ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. Mixed serialization of script and non-script modules is not supported. For purely script modules use my_script_module.save(<filename>) instead.

Any recommendations?

mycpuorg · Apr 03 '22 17:04