Can't build dependency graph for a model with multiple inputs

Open · ketan4373 opened this issue 3 years ago · 1 comment

I have a model that takes two inputs: an image and embeddings. Here are simple example inputs:

import torch

in1 = torch.rand(size=(1, 3, 256, 256))  # image batch
in2 = torch.rand(size=(512, 1))          # embeddings
out = model(in1, in2)                    # model is called with two positional arguments

This is how I pass the two inputs. When building the dependency graph, here is what I tried:

import torch_pruning as tp

strategy = tp.strategy.L1Strategy()
example_inputs = (Xt, embeds)  # (image, embeddings)

DG = tp.DependencyGraph()
DG.build_dependency(G, example_inputs=example_inputs)
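One common workaround when a tool forwards `example_inputs` to the model as a single object is to wrap the model in a module that unpacks the tuple itself. This is a minimal sketch, assuming `build_dependency` calls the model with `example_inputs` as one argument; I have not verified that for the installed Torch-Pruning version. `TwoInputNet` and `PackedInputs` are hypothetical names, with `TwoInputNet` standing in for the model `G` from this issue.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for G: an image branch joined with an embedding."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8 + 512, 10)

    def forward(self, img, emb):
        x = self.pool(torch.relu(self.conv(img))).flatten(1)  # (N, 8)
        e = emb.reshape(x.shape[0], -1)                       # (512, 1) -> (1, 512) for N = 1
        return self.fc(torch.cat([x, e], dim=1))              # (N, 10)


class PackedInputs(nn.Module):
    """Hypothetical adapter: takes ONE tuple and unpacks it into the wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.model = model  # registered as a submodule, so its layers stay reachable

    def forward(self, inputs):
        return self.model(*inputs)


G = TwoInputNet().eval()
example_inputs = (torch.rand(1, 3, 256, 256), torch.rand(512, 1))
wrapped = PackedInputs(G)
with torch.no_grad():
    out = wrapped(example_inputs)  # same computation as G(*example_inputs)
print(out.shape)  # torch.Size([1, 10])

# Assumption, not verified against Torch-Pruning:
# DG.build_dependency(wrapped, example_inputs=example_inputs)
```

Because `PackedInputs` holds `G` as a registered submodule, any layer pruned through the wrapper is the same object as the corresponding layer in `G`, so `G` itself is modified in place.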

I faced the same problem when exporting to ONNX and TensorFlow, and there I solved it by unpacking the inputs with "*":

out = G(*example_inputs)  # this works

But this gives an error:

DG.build_dependency(G, example_inputs=*example_inputs)

Let me know if anything is unclear.
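For context on that last error: in Python, "*" unpacking is allowed in positional-argument position, but not as the value of a keyword argument, so `example_inputs=*example_inputs` is rejected by the parser as a SyntaxError before Torch-Pruning runs at all. The pure-Python sketch below uses a hypothetical function `f` (not part of any library) to show the three cases.

```python
import ast


def f(*args, example_inputs=None):
    """Hypothetical function: reports what it received positionally and by keyword."""
    return args, example_inputs


inputs = ("image", "embedding")

# Valid: "*" spreads the tuple into separate positional arguments.
print(f(*inputs))                  # (('image', 'embedding'), None)

# Valid: the whole tuple is passed as ONE keyword value.
print(f(example_inputs=inputs))    # ((), ('image', 'embedding'))

# Invalid: `keyword=*value` fails at parse time, before anything is called.
try:
    ast.parse("f(example_inputs=*inputs)")
except SyntaxError:
    print("SyntaxError: '*' unpacking is not allowed in a keyword argument")
```

So the keyword form can only ever hand `build_dependency` the tuple itself; whether the library then unpacks it into the model depends on how it calls the model internally.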

ketan4373 · Mar 09 '21 11:03