Torch-Pruning
Can't build dependency graph for model with multiple inputs.
I have a model that takes two inputs: an image and embeddings. Here are simple example inputs:
import torch

in1 = torch.rand(size=(1, 3, 256, 256))  # image input
in2 = torch.rand(size=(512, 1))          # embedding input
out = model(in1, in2)                    # both are passed as separate arguments
This is how I pass the two inputs. To build the dependency graph, I tried the following:
import torch_pruning as tp

strategy = tp.strategy.L1Strategy()
example_inputs = (Xt, embeds)
DG = tp.DependencyGraph()
DG.build_dependency(G, example_inputs=example_inputs)
I ran into the same problem with ONNX and TensorFlow, and solved it by putting "*" in front of the inputs:
out = G(*example_inputs)
This works.
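In a function call, "*" unpacks a tuple into separate positional arguments, which is why G(*example_inputs) behaves like G(Xt, embeds). The sketch below illustrates this in plain Python so it runs without PyTorch; fake_model is a hypothetical stand-in for the two-input network, and the strings are placeholders for the tensors.

```python
# Hypothetical stand-in for a two-input model: it just returns what it receives.
def fake_model(image, embeddings):
    return (image, embeddings)

in1 = "image-tensor"       # placeholder for torch.rand(size=(1, 3, 256, 256))
in2 = "embedding-tensor"   # placeholder for torch.rand(size=(512, 1))
example_inputs = (in1, in2)

# "*" unpacks the tuple into two positional arguments,
# so this call is identical to fake_model(in1, in2).
out = fake_model(*example_inputs)
print(out == fake_model(in1, in2))  # True

# Without "*", the whole tuple arrives as the first argument and the
# second parameter is missing, which raises a TypeError.
try:
    fake_model(example_inputs)
    single_arg_failed = False
except TypeError as err:
    single_arg_failed = True
    print("TypeError:", err)
```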
However, the same trick fails here:
DG.build_dependency(G, example_inputs=*example_inputs)
This line raises an error. Let me know if anything is unclear.
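One likely reason for the error: Python allows a starred expression as a positional argument, but not as the value of a keyword argument, so example_inputs=*example_inputs is rejected with a SyntaxError while the line is being parsed, before build_dependency runs at all. The sketch below checks this with the built-in compile(). It then shows a hypothetical adapter, TupleInputAdapter (not part of Torch-Pruning), that accepts one tuple and unpacks it for the wrapped model. It only helps under the assumption that the library forwards example_inputs to the model as a single argument; whether that is the case depends on the installed Torch-Pruning version, which this issue does not show.

```python
# A starred expression is accepted as a positional argument ...
compile("build_dependency(G, *example_inputs)", "<demo>", "eval")

# ... but not as the value of a keyword argument: this fails while
# Python parses the source, before any function is called.
try:
    compile("build_dependency(G, example_inputs=*example_inputs)", "<demo>", "eval")
    keyword_star_ok = True
except SyntaxError:
    keyword_star_ok = False
print(keyword_star_ok)  # False


# Hypothetical adapter (assumption, not a Torch-Pruning API): takes ONE
# argument, a tuple of inputs, and unpacks it for the wrapped model.
# With PyTorch, the same idea would be a torch.nn.Module subclass that
# stores the model as a submodule and unpacks the tuple in forward().
class TupleInputAdapter:
    def __init__(self, model):
        self.model = model

    def __call__(self, inputs):
        return self.model(*inputs)


def fake_model(image, embeddings):  # hypothetical stand-in for G
    return (image, embeddings)


wrapped = TupleInputAdapter(fake_model)
example_inputs = ("image-tensor", "embedding-tensor")
print(wrapped(example_inputs) == fake_model(*example_inputs))  # True
```

Storing the original model inside the adapter means that any pruning applied through the wrapper would act on the same underlying layers.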