tensorboardX
Graph visualization not supported for models with non-standard inputs (dictionaries, tuples)
Hi, I have been trying to visualize graphs for PyTorch models that take dictionaries or tuples as input. I would like to know if anybody has ever succeeded or found a trick to make it work. I would like to do the pull request myself, but the method torch.jit.get_trace_graph seems to be the limitation here, as it only takes Tensors and Tuples as inputs. Any help is welcome. Many thanks.
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter


class RNN2(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN2, self).__init__()
        self.hidden_size = hidden_size
        # n_categories is a module-level global set in the __main__ block below
        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, data={}):
        # def forward(self, category, input, hidden):
        category = data['cat']
        input = data['input']
        hidden = data['h']
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


if __name__ == "__main__":
    print("Hello main")
    n_letters = 100
    n_hidden = 128
    n_categories = 10
    rnn = RNN2(n_letters, n_hidden, n_categories)
    cat = torch.Tensor(1, n_categories)
    dummy_input = torch.Tensor(1, n_letters)
    hidden = rnn.initHidden()
    # The dict input is what add_graph cannot handle
    data = {'cat': cat, 'input': dummy_input, 'h': hidden}
    with SummaryWriter(comment='RNN2') as w:
        w.add_graph(rnn, data, verbose=False)
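One workaround that might be worth trying is to wrap the model in a thin module whose forward takes the tensors positionally and rebuilds the dictionary before calling the real model. The following is only a sketch under assumptions: `TinyDictModel` and `DictInputWrapper` are hypothetical names made up for illustration (they are not part of tensorboardX or PyTorch), and the sketch only checks that `torch.jit.trace` accepts the wrapped module. It has not been verified against `add_graph` itself.

```python
import torch
import torch.nn as nn


class TinyDictModel(nn.Module):
    """Hypothetical stand-in for a model whose forward() expects a dict."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2 + 4 + 3, 5)

    def forward(self, data):
        combined = torch.cat((data["cat"], data["input"], data["h"]), 1)
        return self.linear(combined)


class DictInputWrapper(nn.Module):
    """Hypothetical adapter: positional tensors in, dict passed to the model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, cat, input, h):
        # Rebuild the dict the wrapped model expects from plain tensors.
        return self.model({"cat": cat, "input": input, "h": h})


model = TinyDictModel().eval()
wrapped = DictInputWrapper(model)

# A tuple of tensors is an input type the tracer accepts.
example = (torch.zeros(1, 2), torch.zeros(1, 4), torch.zeros(1, 3))
traced = torch.jit.trace(wrapped, example)

traced_out = traced(*example)
direct_out = model({"cat": example[0], "input": example[1], "h": example[2]})
```

If this carries over, the wrapped module and the tuple of tensors would be passed to `w.add_graph(...)` in place of `rnn` and `data`. The graph would then show the wrapper as an extra outer scope around the real model.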
I don't have a solution yet, but I might have encountered a similar problem. When trying to do writer.add_graph(model, input), I get output like:
/Users/zamparol/anaconda36/lib/python3.6/site-packages/torch/onnx/utils.py:365: UserWarning: ONNX export failed on ATen operator norm because torch.onnx.symbolic.norm does not exist
  .format(op_name, op_name))
/Users/zamparol/anaconda36/lib/python3.6/site-packages/torch/onnx/utils.py:365: UserWarning: ONNX export failed on ATen operator sort because torch.onnx.symbolic.sort does not exist
  .format(op_name, op_name))
/Users/zamparol/anaconda36/lib/python3.6/site-packages/torch/onnx/utils.py:365: UserWarning: ONNX export failed on ATen operator gather because torch.onnx.symbolic.gather does not exist
  .format(op_name, op_name))
My custom K-max pooling layer breaks the parsing:
import numpy as np
import torch.nn as nn


class DynamicKmaxPooling(nn.Module):
    def __init__(self, k_top, L, l):
        """
        k_top := smallest possible number of pooled elements
        L := number of convolutional layers in the network
        l := the index of this pooling layer in the network
        """
        super(DynamicKmaxPooling, self).__init__()
        self.k_top = k_top
        self.L = L
        self.l = l

    def forward(self, x, dim=2):
        s = x.size()[2]
        # Pool size shrinks with depth, but never drops below k_top
        k_ll = ((self.L - self.l) / self.L) * s
        pool_size = round(max(self.k_top, int(np.ceil(k_ll))))
        # Take the top-k indices, then sort them to keep the original order
        index = x.topk(pool_size, dim)[1].sort(dim)[0]
        return x.gather(dim, index)
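To make clear what the three problematic operators do together, here is a small worked example of the same topk / sort / gather chain with a fixed k. The helper name `kmax_pool` is made up for this illustration and is not part of the layer above.

```python
import torch


def kmax_pool(x, k, dim=2):
    # Indices of the k largest values, re-sorted into their original order.
    index = x.topk(k, dim)[1].sort(dim)[0]
    # Pick the values at those indices, so the input ordering is preserved.
    return x.gather(dim, index)


x = torch.tensor([[[1.0, 5.0, 3.0, 4.0, 2.0]]])
pooled = kmax_pool(x, k=3)
print(pooled.tolist())  # [[[5.0, 3.0, 4.0]]]
```

The three largest values are 5, 4, and 3 at positions 1, 3, and 2. Sorting the positions gives 1, 2, 3, so the output keeps the values in input order: 5, 3, 4.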
You have an operator in your model that ONNX does not recognize. That is why you get this error message. Adding this operator is feasible but quite a tricky process. What type is x in your model? Tuple? Tensor?
x is a tensor (topk, sort, and gather are all Tensor methods). For the purposes of visualizing the network, it would be sufficient to substitute a pooling layer. Would that be easier than implementing a new operator?
Because x is a tensor, the classic way to visualize should work for you (get the trace -> get the graph from the trace -> send the graph to TensorBoard). This method does not require converting your model to an ONNX model. ONNX is only a backup method here, and I don't see why your model would fail with this method. However, if your model fails with all methods, you can still use https://github.com/szagoruyko/pytorchviz, which is really robust.
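To check the trace step on its own, without tensorboardX or ONNX involved, something like the following sketch may help. It assumes a PyTorch version where `torch.jit.trace` is available and uses a simplified, hypothetical `FixedKmaxPooling` layer with a constant k rather than the dynamic pool size. It only shows that the three operators can be traced. Whether `add_graph` then avoids the ONNX fallback depends on the tensorboardX and PyTorch versions in use.

```python
import torch
import torch.nn as nn


class FixedKmaxPooling(nn.Module):
    """Hypothetical simplified layer: k-max pooling with a constant k."""

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, x):
        index = x.topk(self.k, 2)[1].sort(2)[0]
        return x.gather(2, index)


layer = FixedKmaxPooling(k=3)
x = torch.randn(1, 4, 10)

# Trace the layer with a plain tensor input and inspect the recorded graph.
traced = torch.jit.trace(layer, x)
graph_text = str(traced.graph)

# Report which of the operators from the warnings show up in the trace.
wanted = ("aten::topk", "aten::sort", "aten::gather")
found = [name for name in wanted if name in graph_text]
print(found)
```

If all three names are listed, the tracer itself handles the layer, and the failure is more likely in a later step such as the ONNX export.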
> This method does not require converting your model to an ONNX model. ONNX is only a backup method here, and I don't see why your model would fail with this method.
Thanks for the background, I'll step through execution more carefully to see where the classic way fails and leads to attempting the ONNX backup. Also thanks for the pointer to pytorchviz!