tensorboardX

Graph visualization not supported for models with non-standard inputs (dictionaries, tuples)

Status: Open. vincentalbouy opened this issue 6 years ago • 6 comments

Hi, I have been trying to visualize graphs for PyTorch models that take dictionaries or tuples as input. Has anybody succeeded, or found a trick to make it work? I would like to do the pull request myself, but torch.jit.get_trace_graph seems to be the limitation here, as it only takes tensors and tuples as inputs. Any help is welcome. Many thanks!

vincentalbouy · Sep 18 '18 05:09

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter


class RNN2(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_categories):
        super(RNN2, self).__init__()
        self.hidden_size = hidden_size
        # n_categories is passed in instead of being read from a module-level global
        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, data):
        # Previous positional signature: forward(self, category, input, hidden)
        category = data['cat']
        inp = data['input']
        hidden = data['h']
        input_combined = torch.cat((category, inp, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


if __name__ == "__main__":
    print("Hello main")

    n_letters = 100
    n_hidden = 128
    n_categories = 10
    rnn = RNN2(n_letters, n_hidden, n_categories, n_categories)
    # torch.zeros instead of torch.Tensor(...), which returns uninitialized memory
    cat = torch.zeros(1, n_categories)
    dummy_input = torch.zeros(1, n_letters)
    hidden = rnn.initHidden()

    data = {'cat': cat, 'input': dummy_input, 'h': hidden}
    with SummaryWriter(comment='RNN2') as w:
        # This is the call that fails: the tracer does not accept the dict input
        w.add_graph(rnn, data, verbose=False)
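One common workaround, sketched here under assumptions rather than as a tensorboardX feature: wrap the dict-taking model in an adapter module whose forward() takes plain positional tensors and rebuilds the dict internally, so the tracer only ever sees a tuple of tensors. `DictInputNet` and `DictToPositional` are hypothetical names; the toy model stands in for RNN2 to keep the example short.

```python
import torch
import torch.nn as nn


class DictInputNet(nn.Module):
    """Hypothetical toy model whose forward() takes a dict, like RNN2."""

    def __init__(self, n_cat=10, n_in=100, n_out=10):
        super().__init__()
        self.fc = nn.Linear(n_cat + n_in, n_out)

    def forward(self, data):
        return self.fc(torch.cat((data['cat'], data['input']), 1))


class DictToPositional(nn.Module):
    """Hypothetical adapter: takes plain tensors, rebuilds the dict, calls the model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, cat, inp):
        # Reassemble the dict the wrapped model expects
        return self.model({'cat': cat, 'input': inp})


net = DictInputNet()
wrapped = DictToPositional(net)
example = (torch.zeros(1, 10), torch.randn(1, 100))

# The tracer sees a tuple of tensors, never the dict itself
traced = torch.jit.trace(wrapped, example)
out = traced(*example)
```

With tensorboardX, the adapter would then presumably be passed as `w.add_graph(wrapped, example)`, with the tuple of tensors as `input_to_model`. I have not checked this against the tensorboardX and PyTorch versions discussed in this issue.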

vincentalbouy · Sep 18 '18 05:09

I don't have a solution yet, but I might have encountered a similar problem. When trying to do writer.add_graph(model, input), I get output like:

/Users/zamparol/anaconda36/lib/python3.6/site-packages/torch/onnx/utils.py:365: UserWarning: ONNX export failed on ATen operator norm because torch.onnx.symbolic.norm does not exist
  .format(op_name, op_name))
/Users/zamparol/anaconda36/lib/python3.6/site-packages/torch/onnx/utils.py:365: UserWarning: ONNX export failed on ATen operator sort because torch.onnx.symbolic.sort does not exist
  .format(op_name, op_name))
/Users/zamparol/anaconda36/lib/python3.6/site-packages/torch/onnx/utils.py:365: UserWarning: ONNX export failed on ATen operator gather because torch.onnx.symbolic.gather does not exist
  .format(op_name, op_name))

My custom K-max pooling layer breaks the parsing:

import numpy as np
import torch.nn as nn


class DynamicKmaxPooling(nn.Module):
    def __init__(self, k_top, L, l):
        """
        k_top := smallest possible number of pooled elements
        L := number of convolutional layers in the network
        l := the index of this pooling layer in the network
        """
        super(DynamicKmaxPooling, self).__init__()
        self.k_top = k_top
        self.L = L
        self.l = l

    def forward(self, x, dim=2):
        # Length of the pooled dimension (previously x.size()[2], ignoring dim)
        s = x.size(dim)
        # Dynamic k: shrinks with depth, but never below k_top
        k_ll = ((self.L - self.l) / self.L) * s
        pool_size = round(max(self.k_top, int(np.ceil(k_ll))))
        # Indices of the top-k values, re-sorted so the original order is kept
        index = x.topk(pool_size, dim)[1].sort(dim)[0]
        return x.gather(dim, index)
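Reading the forward() code above, the number of elements kept (k_ll rounded up, then pool_size) works out to the following, with s the length of the pooled dimension:

```latex
k_l = \max\left(k_{\text{top}},\ \left\lceil \frac{L - l}{L}\, s \right\rceil\right)
```

For example, with k_top = 3, L = 3 and s = 10, the first pooling layer (l = 1) keeps max(3, ⌈6.67⌉) = 7 elements, and the second (l = 2) keeps max(3, ⌈3.33⌉) = 4.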

lzamparo · Oct 01 '18 22:10

You have operators in your model that the ONNX exporter does not recognize (norm, sort and gather, according to the warnings). That is why you get these messages. Adding support for an operator is feasible but quite a tricky process. What type is x in your model? A tuple? A tensor?

vincentalbouy · Oct 01 '18 22:10

x is a tensor (topk, sort and gather are all Tensor methods). For the purposes of visualizing the network, it would be enough to substitute a standard pooling layer. Would that be easier than implementing a new operator?
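A rough sketch of the substitution idea, assuming a visualization-only copy of the network is acceptable. The hypothetical `KMaxPoolStandIn` keeps DynamicKmaxPooling's constructor arguments and output length, but swaps in a different technique, adaptive max pooling (`torch.nn.functional.adaptive_max_pool1d`). It produces the same output shape, not the same values: it takes the max of fixed windows rather than the global top k, so it is only suitable for drawing the graph.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class KMaxPoolStandIn(nn.Module):
    """Hypothetical visualization-only replacement for DynamicKmaxPooling.

    Same constructor arguments and output length, but pools fixed windows
    with adaptive max pooling instead of doing true k-max pooling.
    """

    def __init__(self, k_top, L, l):
        super().__init__()
        self.k_top = k_top
        self.L = L
        self.l = l

    def forward(self, x):
        # Same dynamic-k rule as DynamicKmaxPooling, computed as a plain Python int
        s = int(x.size(2))
        pool_size = max(self.k_top, math.ceil((self.L - self.l) / self.L * s))
        # Built-in op, so the tracer records it as a single node
        return F.adaptive_max_pool1d(x, pool_size)


x = torch.randn(2, 8, 10)
layer = KMaxPoolStandIn(k_top=3, L=2, l=1)
net = nn.Sequential(nn.Conv1d(8, 8, kernel_size=3, padding=1), layer)
traced = torch.jit.trace(net, x)
```

Building a second copy of the network with this class in place of DynamicKmaxPooling, and passing only that copy to add_graph, would leave the training model untouched.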

lzamparo · Oct 02 '18 14:10

Because x is a tensor, the classic way to visualize should work for you (get trace -> get graph from trace -> send graph to TensorBoard). This method does not require converting your model to an ONNX model. ONNX is only a backup method here, and I don't see why your model would fail on this method. However, if your model fails with all methods, you can still use https://github.com/szagoruyko/pytorchviz, which is really robust.
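A minimal sketch of the "trace, then read the graph off the trace" step, assuming a recent PyTorch where `torch.jit.trace` is the documented public tracing entry point (the `get_trace_graph` call mentioned at the top of this thread is from the 2018 API). `TopKGather` is a hypothetical module that uses the same topk/sort/gather ops as DynamicKmaxPooling. Tracing records them as `aten::` nodes without going through the ONNX exporter.

```python
import torch
import torch.nn as nn


class TopKGather(nn.Module):
    """Hypothetical minimal module using the same ops as DynamicKmaxPooling."""

    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, x):
        # Top-k indices along dim 2, re-sorted to keep the original order
        index = x.topk(self.k, 2)[1].sort(2)[0]
        return x.gather(2, index)


module = TopKGather(3)
x = torch.randn(1, 4, 10)

# Step 1: trace. Step 2: read the graph recorded by the trace.
traced = torch.jit.trace(module, x)
graph_text = str(traced.graph)
print(graph_text)
```

Printing `traced.graph` for the full model is one way to check whether tracing itself succeeds before involving add_graph.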

vincentalbouy · Oct 02 '18 20:10

> This method does not require converting your model to an ONNX model. ONNX is only a backup method here, and I don't see why your model would fail on this method.

Thanks for the background. I'll step through execution more carefully to see where the classic way fails and falls back to the ONNX backup. Also, thanks for the pointer to pytorchviz!

lzamparo · Oct 05 '18 15:10