torchlayers
Cannot Infer Shapes from pretrained Models.
Hi,
I have a relatively straightforward situation where I need to validate my input shapes when loading a saved model with torch.jit.load, and I can't seem to find a solution.
Save a traced model and verify the input shapes:
import torch
import torchvision

test_model = torchvision.models.resnet.resnet18(pretrained=True).eval()
with torch.no_grad():
    inp_224 = torch.rand(1, 3, 224, 224, dtype=torch.float)
    script_module_224 = torch.jit.trace(test_model, inp_224)
graph_inputs = list(script_module_224.graph.inputs())
graph_inputs = graph_inputs[1:]  # ignore self
print(graph_inputs[0].type().sizes())  # [1, 3, 224, 224]
script_module_224.save("saved_model_224.pt")
After loading the same model, I cannot validate the traced input shapes:
# ... later elsewhere load my saved model
loaded_224 = torch.jit.load("saved_model_224.pt")
# there's nothing preventing me from sending incorrect input shapes
# i.e., traced with 224 but called with 500
inp_500 = torch.rand(1, 3, 500, 500, dtype=torch.float)
loaded_224(inp_500)
I'd like to prevent feeding incorrectly shaped inputs to a model after loading it. In particular, I would like to be able to do something like:
traced_input_shape = loaded_224.get_input_shape()  # hypothetical method, does not exist today
if inp_500.shape != traced_input_shape:
    print("Error: Trying to run inference with incorrectly shaped inputs!")
    # die
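To make the desired check concrete, here is a minimal sketch in plain Python. It assumes the traced shape is recorded somewhere at save time (for example as a JSON string stored alongside the model file). The helper names `encode_shape` and `check_input_shape` are hypothetical, not part of torch or torchlayers, and how the recorded string travels with the model is left open.

```python
import json


def encode_shape(shape):
    """Serialize a shape (any iterable of ints, e.g. a torch.Size) to a JSON string."""
    return json.dumps([int(d) for d in shape])


def check_input_shape(actual_shape, encoded_expected):
    """Raise ValueError if actual_shape differs from the recorded traced shape."""
    expected = tuple(json.loads(encoded_expected))
    actual = tuple(int(d) for d in actual_shape)
    if actual != expected:
        raise ValueError(f"expected input shape {expected}, got {actual}")


# At save time: record the shape the model was traced with (hypothetical step).
recorded = encode_shape((1, 3, 224, 224))

# At load time: validate the input before calling the model.
check_input_shape((1, 3, 224, 224), recorded)  # matching shape, no error
try:
    check_input_shape((1, 3, 500, 500), recorded)
except ValueError as err:
    print(err)
```

With tensors, the call would presumably look like `check_input_shape(inp_500.shape, recorded)`, since a tensor's shape iterates as plain integers.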
I tried using torchlayers to help with this situation by:
import torchlayers as tl
tl.build(loaded_224, inp_224)
This failed (reasonably) with:
PickleError: ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. Mixed serialization of script and non-script modules is not supported. For purely script modules use my_script_module.save(<filename>) instead.
Any recommendations?