pytorch-tree-lstm
_label_node_index is giving different nodes the same label
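The `_label_node_index` function in `example_usage.py` gives different nodes the same label. The cause is the recursion: the counter value reached inside a child's subtree is not carried back to the caller, so the next sibling reuses an index that was already assigned. I think it should be something like the code below, which keeps the last value assigned to a node: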
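```python
n = 0  # global counter; reset it to 0 before labelling another tree

def _label_node_index(node):
    global n
    node['index'] = n
    for child in node['children']:
        n += 1  # each child (and then its subtree) takes the next free index
        _label_node_index(child)
```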
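A minimal reproduction of the duplicate labels, for illustration. It assumes (this is a guess, not the repository's exact code) that the faulty version passes the counter down as an argument but ignores what the recursive call did with it. `label_buggy`, `make_tree` and `collect` are hypothetical helpers; the `children`/`index` keys follow the snippets in this thread.

```python
def label_buggy(node, n=0):
    """Hypothetical faulty labeller: the counter advanced inside a
    subtree is lost when the recursive call returns."""
    node['index'] = n
    for child in node['children']:
        n += 1
        label_buggy(child, n)  # nothing is returned, so the caller's n falls behind


def make_tree():
    # root -> (a -> (a1), b)
    return {'children': [
        {'children': [{'children': []}]},
        {'children': []},
    ]}


def collect(node):
    """Return every node's index in pre-order."""
    out = [node['index']]
    for child in node['children']:
        out.extend(collect(child))
    return out


tree = make_tree()
label_buggy(tree)
print(collect(tree))  # [0, 1, 2, 2]: a1 and b share index 2
```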
I just ran into the same issue with duplicate node indexes. Here is my solution, which I find a little more elegant because it doesn't require a global variable:
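```python
def _label_node_index(node, n=0):
    """Label nodes in pre-order starting at n; return the last index assigned."""
    node['index'] = n
    for child in node['children']:
        n += 1
        # continue counting from the last index used in the child's subtree
        n = _label_node_index(child, n)
    return n
```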
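A quick check that the return-value version gives unique, consecutive pre-order indices and needs no reset between trees. The function is repeated so the snippet runs on its own; `collect` and the sample trees are hypothetical test helpers.

```python
def _label_node_index(node, n=0):
    """Label nodes in pre-order starting at n; return the last index assigned."""
    node['index'] = n
    for child in node['children']:
        n += 1
        n = _label_node_index(child, n)  # pick up where the subtree stopped
    return n


def collect(node):
    """Return every node's index in pre-order."""
    out = [node['index']]
    for child in node['children']:
        out.extend(collect(child))
    return out


# root -> (a -> (a1), b)
tree = {'children': [{'children': [{'children': []}]}, {'children': []}]}
last = _label_node_index(tree)
print(collect(tree), last)  # [0, 1, 2, 3] 3

# No global state: a second tree starts again from 0.
other = {'children': [{'children': []}]}
other_last = _label_node_index(other)
print(collect(other), other_last)  # [0, 1] 1
```

Because the function returns the last index it assigned, the number of nodes in a tree is the return value plus one, which could be handy when sizing per-node arrays.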