hiddenlayer
Shape information is not plotted
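```python
import hiddenlayer as hl
import matplotlib.pyplot as plt
import torch

def plot_network(self, model, hltrans, input_size=(1, 1, 36)):
    # note: hltrans is not used in this function
    # Trace the model with a dummy float64 input (the model must also be double)
    graph = hl.build_graph(model, torch.zeros(input_size).double())
    #model = torchvision.models.vgg16()
    #graph = hl.build_graph(model, torch.zeros([1, 3, 224, 224]))
    # Render the graph to a PNG file and load it back as an image array
    dot = graph.build_dot()
    dot.format = "png"
    im = dot.render(cleanup=True)
    net_img = plt.imread(im)
    return net_img
```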
I have a custom PyTorch model that I plot with the function above. With my model, the shape information is not shown on the edges connecting the blocks. If I uncomment the two VGG16 lines, which overwrite `model` and `graph`, the plot does contain the shape information.
Can you help me pinpoint the difference? Here are the two output images (cropped).
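One visible difference between the two cases is the input dtype: the custom model is traced with a `.double()` (float64) dummy tensor, while the VGG16 line uses the default float32 `torch.zeros`. A possible explanation, which is an assumption and not confirmed in this thread, is that hiddenlayer's PyTorch builder recovers each edge's shape by regex-matching a `Float(...)` type in the traced node's string form, so a node printed as `Double(...)` yields no shape. The sketch below only illustrates that mechanism: the node strings, the patterns `FLOAT_ONLY` and `ANY_DTYPE`, and the helper `parse_shape` are all hypothetical and are not hiddenlayer's actual code.

```python
import re

# Hypothetical node strings, loosely modeled on how a traced PyTorch graph
# prints an output value; the real format varies between PyTorch versions.
FLOAT_NODE = "%12 : Float(1, 64, 36) = onnx::Conv(%input, %weight)"
DOUBLE_NODE = "%12 : Double(1, 64, 36) = onnx::Conv(%input, %weight)"

# Hypothetical pattern that only recognizes float32 tensors.
FLOAT_ONLY = re.compile(r".*Float\(([\d\s,]+)\).*")
# Hypothetical pattern that also accepts other tensor types such as Double.
ANY_DTYPE = re.compile(r".*(?:Float|Double|Half|Long|Int)\(([\d\s,]+)\).*")


def parse_shape(node_repr, pattern):
    """Return the shape as a tuple of ints, or None if the pattern does not match."""
    m = pattern.match(node_repr)
    if m is None:
        return None
    # int() ignores the surrounding whitespace left by split(",")
    return tuple(int(dim) for dim in m.group(1).split(","))


print(parse_shape(FLOAT_NODE, FLOAT_ONLY))   # (1, 64, 36)
print(parse_shape(DOUBLE_NODE, FLOAT_ONLY))  # None
print(parse_shape(DOUBLE_NODE, ANY_DTYPE))   # (1, 64, 36)
```

If this is the cause, tracing a float32 copy of the model (for example `model.float()` with a float32 dummy input) would be one quick way to check it.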
@paland3: have you solved this problem? If so, could you please share how you solved it? @waleedka, can you please help us?
This solution works for me. I hope this helps.