TreeLSTMSentiment

Doubts about the BinaryTreeLSTM implementation

Status: Open · ksolaiman opened this issue 6 years ago · 7 comments

In your implementation of the BinaryTreeLSTM,

  1. Can you explain the leaf (base) case? self.ox just passes the input through a linear layer, which is understandable since leaves have no hidden state, but why are there no weight parameters for the input, update, or forget gates, and why is the cell state just the input passed through a linear layer?

  2. For non-leaf nodes, on the other hand, the input is never passed through a linear layer in any of the gating units. Is there an explanation for that? In ChildSum you have weight parameters for x_j; why not in the N-ary LSTM?

    self.ix = nn.Linear(self.in_dim, self.mem_dim)
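For reference, the N-ary Tree-LSTM transition equations in the paper (https://arxiv.org/abs/1503.00075) have the following form, with N = 2 for the binary model. Every gate carries its own W x_j term, and the forget gate f_{jk} for child k has a separate matrix U^(f)_{kℓ} for each child ℓ:

$$
\begin{aligned}
i_j &= \sigma\Big(W^{(i)} x_j + \sum_{\ell=1}^{N} U^{(i)}_{\ell} h_{j\ell} + b^{(i)}\Big)\\
f_{jk} &= \sigma\Big(W^{(f)} x_j + \sum_{\ell=1}^{N} U^{(f)}_{k\ell} h_{j\ell} + b^{(f)}\Big)\\
o_j &= \sigma\Big(W^{(o)} x_j + \sum_{\ell=1}^{N} U^{(o)}_{\ell} h_{j\ell} + b^{(o)}\Big)\\
u_j &= \tanh\Big(W^{(u)} x_j + \sum_{\ell=1}^{N} U^{(u)}_{\ell} h_{j\ell} + b^{(u)}\Big)\\
c_j &= i_j \odot u_j + \sum_{\ell=1}^{N} f_{j\ell} \odot c_{j\ell}\\
h_j &= o_j \odot \tanh(c_j)
\end{aligned}
$$

The inner-node forward quoted later in this thread keeps only the U h terms (lflh, lfrh, rflh, rfrh play the role of the four U^(f)_{kℓ}) and has no output gate, so h = tanh(c) there.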

ksolaiman, Jul 17 '18 22:07

I followed the implementation in https://github.com/stanfordnlp/treelstm and successfully reproduced the result of the paper https://arxiv.org/abs/1503.00075 for binary classification.

ttpro1995, Jul 18 '18 02:07

Here is the author's original implementation, which I am trying to port 1:1: https://github.com/stanfordnlp/treelstm/blob/master/models/BinaryTreeLSTM.lua

ttpro1995, Jul 18 '18 02:07

I will check again later; I did this a long time ago and hardly remember a thing :3. In the meantime, please help me check the math. I think I had already confirmed that the formulas match the paper.

ttpro1995, Jul 18 '18 02:07

"and why is cell state just passed through a linear layer"

In the paper https://arxiv.org/abs/1503.00075, this linear layer is referred to as the "softmax classifier" (section 4.1) and the "sentiment classifier" (section 5.3).
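For concreteness, the classifier described in section 4.1 of the paper takes a node's hidden state h_j and computes p(y | {x}_j) = softmax(W^(s) h_j + b^(s)), predicting the argmax class. A minimal NumPy sketch, assuming toy sizes and random weights; W_s, b_s and classify_node are illustrative names, not taken from the repo:

```python
import numpy as np


def softmax(z):
    # Subtracting the max is for numerical stability; the result is unchanged.
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()


def classify_node(h, W_s, b_s):
    """Softmax classifier on a node's hidden state h (sketch of
    softmax(W_s h + b_s)); returns (class probabilities, predicted label)."""
    p = softmax(W_s @ h + b_s)
    return p, int(np.argmax(p))


rng = np.random.default_rng(0)
mem_dim, num_classes = 4, 2  # toy sizes (assumed), 2 classes for binary sentiment
h = rng.standard_normal(mem_dim)
W_s = rng.standard_normal((num_classes, mem_dim))
b_s = np.zeros(num_classes)
p, label = classify_node(h, W_s, b_s)
print(p, label)
```

In this formulation the classifier reads the hidden state h_j, not the cell state c_j.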

ttpro1995, Jul 18 '18 02:07

I understand you did this a long time ago and it reproduces the result in the paper, but I have to tweak it for my problem domain, so I am trying to understand everything under the hood.

Let's just think about the leaf first. Leaf nodes have no child hidden states, so there are no f_j's, but what happens to i_j and u_j, and as a result to c_j? The second term of

$$c_j = i_j \odot u_j + \sum_{\ell=1}^{N} f_{j\ell} \odot c_{j\ell}$$

is 0 here, which is understandable, but what happens to the rest?

Is it the case that, theoretically, there is nothing to forget or update in a leaf cell, so there are no u_j and f_j? And is i_j just the word/input passed through a linear layer? I am pretty new to this field, so I am trying to understand. Any help is appreciated.
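A worked reading of the leaf case, assuming the paper's N-ary equations are applied literally with the absent children's states taken to be zero (so every U h term and the forget-gate sum vanish):

$$
\begin{aligned}
i_j &= \sigma\big(W^{(i)} x_j + b^{(i)}\big), \qquad u_j = \tanh\big(W^{(u)} x_j + b^{(u)}\big), \qquad o_j = \sigma\big(W^{(o)} x_j + b^{(o)}\big),\\
c_j &= i_j \odot u_j + \sum_{\ell=1}^{N} f_{j\ell} \odot c_{j\ell} = i_j \odot u_j \qquad (c_{j\ell} = 0),\\
h_j &= o_j \odot \tanh(c_j).
\end{aligned}
$$

Under that reading i_j and u_j do not disappear at a leaf; only the forget-gate term does. The leaf module quoted later in this thread instead computes c as a single linear map of the input (c = cx(input)), which is not of the form i_j ⊙ u_j, so it looks like a simplification in the reference implementation rather than something the equations force.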

ksolaiman, Jul 18 '18 02:07

I think it is because a leaf node connects to the embedding layer, so it has no children. Please look at https://github.com/ttpro1995/TreeLSTMSentiment/blob/master/model.py

Inner node (line 97):

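    # Inner (composer) node: combines the (c, h) states of the left and right child.
    # Assumes import torch at module level (torch.sigmoid/tanh replace deprecated F.sigmoid/F.tanh).
    def forward(self, lc, lh, rc, rh):
        i = torch.sigmoid(self.ilh(lh) + self.irh(rh))      # input gate
        lf = torch.sigmoid(self.lflh(lh) + self.lfrh(rh))   # forget gate for the left child
        rf = torch.sigmoid(self.rflh(lh) + self.rfrh(rh))   # forget gate for the right child
        update = torch.tanh(self.ulh(lh) + self.urh(rh))    # candidate update u
        c = i * update + lf * lc + rf * rc                  # new memory cell
        h = torch.tanh(c)                                   # no output gate at inner nodes here
        return c, h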

and the leaf node (line 34):

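    # Leaf node: no children, so the state is computed from the word embedding alone.
    def forward(self, input):
        c = self.cx(input)                  # cell state: a linear map of the input
        o = torch.sigmoid(self.ox(input))   # output gate
        h = o * torch.tanh(c)               # hidden state
        return c, h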

You can see that at an inner node these formulas need lh, lc and rh, rc (the h and c states of the left and right child).

In a leaf node there is no left or right child, so lh, lc, rh, rc disappear :3 We only have the input, which comes from the embedding layer.
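To follow the data flow without PyTorch, here is a NumPy mirror of the two forward methods quoted above, combining two leaves into one parent. The Linear class is a hypothetical stand-in for nn.Linear, and the dimensions and random initialization are arbitrary:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class Linear:
    """Hypothetical stand-in for torch.nn.Linear: y = W x + b."""

    def __init__(self, in_dim, out_dim, rng):
        self.W = rng.standard_normal((out_dim, in_dim)) * 0.1
        self.b = np.zeros(out_dim)

    def __call__(self, x):
        return self.W @ x + self.b


class Leaf:
    # Mirrors the leaf forward: only the input x is available.
    def __init__(self, in_dim, mem_dim, rng):
        self.cx = Linear(in_dim, mem_dim, rng)
        self.ox = Linear(in_dim, mem_dim, rng)

    def __call__(self, x):
        c = self.cx(x)
        o = sigmoid(self.ox(x))
        h = o * np.tanh(c)
        return c, h


class Composer:
    # Mirrors the inner-node forward: only the children's states are used.
    def __init__(self, mem_dim, rng):
        for name in ["ilh", "irh", "lflh", "lfrh", "rflh", "rfrh", "ulh", "urh"]:
            setattr(self, name, Linear(mem_dim, mem_dim, rng))

    def __call__(self, lc, lh, rc, rh):
        i = sigmoid(self.ilh(lh) + self.irh(rh))
        lf = sigmoid(self.lflh(lh) + self.lfrh(rh))
        rf = sigmoid(self.rflh(lh) + self.rfrh(rh))
        update = np.tanh(self.ulh(lh) + self.urh(rh))
        c = i * update + lf * lc + rf * rc
        h = np.tanh(c)
        return c, h


rng = np.random.default_rng(0)
in_dim, mem_dim = 3, 4
leaf, composer = Leaf(in_dim, mem_dim, rng), Composer(mem_dim, rng)
lc, lh = leaf(rng.standard_normal(in_dim))   # left word
rc, rh = leaf(rng.standard_normal(in_dim))   # right word
c_root, h_root = composer(lc, lh, rc, rh)
print(c_root.shape, h_root.shape)  # (4,) (4,)
```

The composer never sees a word vector: its only inputs are the children's (c, h) pairs, which is why no W x_j terms appear there.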

Here is where I call the leaf node (line 151):

        if tree.num_children == 0:
            # leaf case: tree.idx is 1-based, hence the - 1 when indexing embs
            tree.state = self.leaf_module(embs[tree.idx - 1])  # calling the module runs forward() plus hooks
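The call above sits inside a recursive traversal. A minimal sketch of that pattern, assuming a hypothetical Tree class with idx, children and state fields (only num_children, idx and state appear in the quoted line) and toy leaf/compose functions in place of the LSTM cells, so the result is easy to check by hand:

```python
import numpy as np
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple

State = Tuple[np.ndarray, np.ndarray]  # (c, h)


@dataclass
class Tree:
    # Hypothetical minimal node: idx is the 1-based word position (leaves only).
    idx: int = 0
    children: List["Tree"] = field(default_factory=list)
    state: Optional[State] = None

    @property
    def num_children(self) -> int:
        return len(self.children)


def tree_forward(tree: Tree, embs: np.ndarray,
                 leaf_fn: Callable, compose_fn: Callable) -> State:
    """Post-order traversal of a binary tree: children first, then the parent."""
    if tree.num_children == 0:
        # leaf case: idx is 1-based, hence the - 1
        tree.state = leaf_fn(embs[tree.idx - 1])
    else:
        for child in tree.children:
            tree_forward(child, embs, leaf_fn, compose_fn)
        (lc, lh), (rc, rh) = tree.children[0].state, tree.children[1].state
        tree.state = compose_fn(lc, lh, rc, rh)
    return tree.state


# Toy stand-ins for the leaf module and composer (NOT LSTM cells).
def toy_leaf(x):
    return x, np.tanh(x)


def toy_compose(lc, lh, rc, rh):
    c = lc + rc
    return c, np.tanh(c)


embs = np.array([[1.0], [2.0], [3.0]])  # 3 words, 1-dim embeddings
# Tree shape: ((w1 w2) w3)
root = Tree(children=[Tree(children=[Tree(idx=1), Tree(idx=2)]), Tree(idx=3)])
c, h = tree_forward(root, embs, toy_leaf, toy_compose)
print(c)  # [6.]
```

Because children are computed before their parent, every inner node finds lc, lh, rc, rh already stored on its children.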

ttpro1995, Jul 18 '18 03:07

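Yes. And for the non-leaf nodes there is no W*x_j term in any of the gates (in the paper, e.g. $i_j = \sigma\big(W^{(i)} x_j + \sum_{\ell=1}^{N} U^{(i)}_{\ell} h_{j\ell} + b^{(i)}\big)$). Is that because, with constituency parsers, no word vectors are passed as inputs to the internal nodes, so that term is 0?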

But if, say, I want to do this with a dependency parser, where each node in the tree takes the vector corresponding to its head word as input, would there then be a W*x_j term, i.e. another linear layer applied to the input (in PyTorch specifically)?
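If internal nodes do receive an input x_j (for example a head-word vector), one way to restore the paper's W x_j terms is to give every gate its own input projection; in PyTorch that would be one more nn.Linear(in_dim, mem_dim) per gate, similar in spirit to self.ix in ChildSum. A minimal NumPy sketch of such a binary node with all five gates, using hypothetical names (init_params, binary_node) and arbitrary sizes. Note that the paper itself pairs dependency trees with the Child-Sum variant, which handles any number of children:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


GATES = ("i", "fl", "fr", "o", "u")  # input, left/right forget, output, update


def init_params(in_dim, mem_dim, rng, scale=0.1):
    """Per gate: W (node input), Ul / Ur (left / right child h), b (bias)."""
    return {
        g: {
            "W": rng.standard_normal((mem_dim, in_dim)) * scale,
            "Ul": rng.standard_normal((mem_dim, mem_dim)) * scale,
            "Ur": rng.standard_normal((mem_dim, mem_dim)) * scale,
            "b": np.zeros(mem_dim),
        }
        for g in GATES
    }


def pre(params, g, x, lh, rh):
    # W x_j + U_l h_l + U_r h_r + b for gate g
    q = params[g]
    return q["W"] @ x + q["Ul"] @ lh + q["Ur"] @ rh + q["b"]


def binary_node(params, x, lc, lh, rc, rh):
    """Binary Tree-LSTM node that also reads its own input x (paper-style gates)."""
    i = sigmoid(pre(params, "i", x, lh, rh))
    fl = sigmoid(pre(params, "fl", x, lh, rh))
    fr = sigmoid(pre(params, "fr", x, lh, rh))
    o = sigmoid(pre(params, "o", x, lh, rh))
    u = np.tanh(pre(params, "u", x, lh, rh))
    c = i * u + fl * lc + fr * rc
    h = o * np.tanh(c)
    return c, h


rng = np.random.default_rng(0)
in_dim, mem_dim = 3, 4
p = init_params(in_dim, mem_dim, rng)
x = rng.standard_normal(in_dim)
zero = np.zeros(mem_dim)
# With all-zero child states, only the W x_j terms remain: c = i * u.
c, h = binary_node(p, x, zero, zero, zero, zero)
```

With all child states set to zero the node reduces to c = i ⊙ u, which is the paper's equations with every child term removed.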

ksolaiman, Jul 18 '18 03:07