GraphTrans-mean may be wrong

Open LUOyk1999 opened this issue 2 years ago • 0 comments

Hello, thanks for the excellent work. I have found a possible problem. In the paper, the authors write: "In Table 5, we tested several common methods to for sequence classification. The mean operation averages the output embeddings of the transformer to a single graph embedding; the last operation takes the last embedding in the output sequence as the graph embedding."

Table 5:

| Model | Valid | Test |
| --- | --- | --- |
| GraphTrans-mean | 0.1398 | 0.1509 |

However, I read the GraphTrans code and found that the authors' implementation of mean may be wrong. In gnn_transformer.py, lines 116-117:

    elif self.pooling == "mean":
        h_graph = transformer_out.sum(0) / src_padding_mask.sum(-1, keepdim=True)

Here transformer_out.shape = (S, B, h_d) and src_padding_mask.shape = (B, S). transformer_out still contains the outputs at the padded node positions, and the code does not remove them (no unpad_batch) but sums over all positions directly.
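To make the concern concrete, here is a toy illustration in NumPy (chosen only to keep the example lightweight; the same indexing carries over to PyTorch tensors). The values are made up, and the sketch assumes that `True` in `src_padding_mask` marks a padded position, which is not confirmed from the repository. It compares the plain sum over the sequence axis with a sum restricted to real nodes.

```python
import numpy as np

# Toy batch: S=3 positions, B=2 graphs, h_d=1 hidden dimension.
# Graph 0 has 3 real nodes; graph 1 has 1 real node and 2 padded positions.
transformer_out = np.array(
    [[[1.0], [4.0]],
     [[2.0], [9.0]],   # 9.0 is a made-up stand-in for whatever the
     [[3.0], [9.0]]]   # transformer emits at graph 1's padded positions
)                      # shape (S, B, h_d)

# Assumption: True marks a padded position. Shape (B, S).
src_padding_mask = np.array([[False, False, False],
                             [False, True,  True]])

# Plain sum over the sequence axis: padded positions are included.
naive_sum = transformer_out.sum(0)                 # shape (B, h_d)

# Sum over real nodes only: zero out padded positions first.
valid = ~src_padding_mask.T[:, :, None]            # shape (S, B, 1)
masked_sum = (transformer_out * valid).sum(0)      # shape (B, h_d)

print(naive_sum.ravel(), masked_sum.ravel())
```

For graph 1 the two sums differ, because the plain sum picks up the two padded positions.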

I modified the mean operation and reran the experiment, and the result is better than the one reported by the authors.
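The modified code is not shown in this issue. One way a padding-aware mean could look is sketched below in NumPy; this is an illustration, not the actual change that was tested. The function name `masked_mean` is hypothetical, and the sketch again assumes that `True` in `src_padding_mask` marks padding.

```python
import numpy as np

def masked_mean(transformer_out, src_padding_mask):
    """Average over real (non-padded) positions only.

    transformer_out:  array of shape (S, B, h_d)
    src_padding_mask: bool array of shape (B, S), True = padding (assumed)
    Returns an array of shape (B, h_d).
    """
    valid = ~src_padding_mask.T[:, :, None]           # (S, B, 1), True = real node
    summed = (transformer_out * valid).sum(axis=0)    # (B, h_d), padding zeroed out
    counts = valid.sum(axis=0)                        # (B, 1), real nodes per graph
    return summed / np.maximum(counts, 1)             # guard against empty graphs

# Hypothetical toy input: graph 1 has one real node and two padded positions.
example_out = np.array([[[1.0], [4.0]],
                        [[2.0], [100.0]],
                        [[3.0], [100.0]]])
example_mask = np.array([[False, False, False],
                         [False, True,  True]])
example_mean = masked_mean(example_out, example_mask)
```

Dividing by the count of real nodes, rather than by a quantity derived from the raw mask, keeps the result independent of how many padded positions the batch happens to contain.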

LUOyk1999 · Oct 09 '22 08:10