treelstm.pytorch

Does the current TreeLSTM support a batch size greater than 1?

Open · jinfengr opened this issue 6 years ago · 0 comments

From the code, it seems that batch sizes greater than 1 are still not supported. The `forward` function of `ChildSumTreeLSTM` appears to process only a single tree per forward pass:

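```python
import torch

# Method of ChildSumTreeLSTM: each call recursively processes exactly one tree.
def forward(self, tree, inputs):
    # Post-order traversal: compute every child's state before this node.
    for idx in range(tree.num_children):
        self.forward(tree.children[idx], inputs)

    if tree.num_children == 0:
        # Leaf node: use zero-filled (1 x mem_dim) child cell and hidden states.
        child_c = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()
        child_h = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()
    else:
        # Internal node: concatenate the children's (c, h) states along dim 0.
        child_c, child_h = zip(*map(lambda x: x.state, tree.children))
        child_c, child_h = torch.cat(child_c, dim=0), torch.cat(child_h, dim=0)

    # inputs[tree.idx] is this node's input row; only one tree's nodes are ever seen here.
    tree.state = self.node_forward(inputs[tree.idx], child_c, child_h)
    return tree.state
```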
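The quoted `forward` recurses over one tree at a time, so a batch of B trees needs B separate calls. One common way to batch a Child-Sum TreeLSTM is to group nodes by height (leaves at height 0): a node depends only on its children, which always have a lower height, so all nodes of the same height across every tree in the batch can share one set of matrix multiplies. The sketch below assumes this height-wise scheme and is not the repository's code; `Tree`, `BatchedChildSumTreeLSTM` and the layer names `ioux`, `iouh`, `fx`, `fh` are hypothetical. The gate math follows the standard Child-Sum formulation (summed child hidden states feed the input/output/update gates, and each child gets its own forget gate).

```python
from collections import defaultdict

import torch
import torch.nn as nn


class Tree:
    """Hypothetical minimal tree node; `idx` is this node's row in its sentence's input matrix."""

    def __init__(self, idx, children=()):
        self.idx = idx
        self.children = list(children)


class BatchedChildSumTreeLSTM(nn.Module):
    """Sketch: Child-Sum TreeLSTM that batches all nodes of equal height across trees."""

    def __init__(self, in_dim, mem_dim):
        super().__init__()
        self.mem_dim = mem_dim
        self.ioux = nn.Linear(in_dim, 3 * mem_dim)   # input/output/update gates from x_j
        self.iouh = nn.Linear(mem_dim, 3 * mem_dim)  # same gates from the summed child h
        self.fx = nn.Linear(in_dim, mem_dim)         # forget gate contribution from x_j
        self.fh = nn.Linear(mem_dim, mem_dim)        # forget gate from each child's own h_k

    def forward(self, trees, inputs):
        # trees: list of B root nodes; inputs: list of B tensors, each (num_tokens, in_dim).
        nodes, heights = [], []  # per node: (batch index, node, child global ids)

        def visit(b, node):
            kids = [visit(b, ch) for ch in node.children]  # post-order: children first
            nodes.append((b, node, kids))
            heights.append(1 + max(heights[k] for k in kids) if kids else 0)
            return len(nodes) - 1

        root_ids = [visit(b, t) for b, t in enumerate(trees)]
        levels = defaultdict(list)
        for gid, hgt in enumerate(heights):
            levels[hgt].append(gid)

        c_list, h_list = [None] * len(nodes), [None] * len(nodes)
        for hgt in range(max(heights) + 1):  # children are always strictly lower than parents
            ids = levels[hgt]
            x = torch.stack([inputs[nodes[g][0]][nodes[g][1].idx] for g in ids])  # (L, in_dim)
            zeros = x.new_zeros(len(ids), self.mem_dim)
            # Flatten the children of all L nodes and record which row each belongs to.
            parent = [row for row, g in enumerate(ids) for _ in nodes[g][2]]
            child_ids = [k for g in ids for k in nodes[g][2]]
            if child_ids:  # internal nodes
                ch_h = torch.stack([h_list[k] for k in child_ids])  # (K, mem_dim)
                ch_c = torch.stack([c_list[k] for k in child_ids])
                parent = torch.tensor(parent, dtype=torch.long, device=x.device)
                h_sum = zeros.index_add(0, parent, ch_h)  # sum of child h per parent row
                f = torch.sigmoid(self.fh(ch_h) + self.fx(x)[parent])  # one gate per child
                fc_sum = zeros.index_add(0, parent, f * ch_c)
            else:  # leaves: no children, so both sums are zero
                h_sum, fc_sum = zeros, zeros
            i, o, u = torch.split(self.ioux(x) + self.iouh(h_sum), self.mem_dim, dim=1)
            c = torch.sigmoid(i) * torch.tanh(u) + fc_sum
            h = torch.sigmoid(o) * torch.tanh(c)
            for row, g in enumerate(ids):
                c_list[g], h_list[g] = c[row], h[row]

        # Root states for every tree in the batch, each of shape (B, mem_dim).
        return (torch.stack([c_list[r] for r in root_ids]),
                torch.stack([h_list[r] for r in root_ids]))
```

Called as `c, h = model([tree_a, tree_b], [inputs_a, inputs_b])`, it returns one root state per tree. The number of sequential steps becomes the tallest tree's height plus one rather than the total node count; the Python-side flattening is kept deliberately simple and would likely need vectorizing for large batches.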

jinfengr · Jan 16 '19 22:01