GTS
Question about Gumbel sampling
I read that you apply bivariate Gumbel sampling in your paper and use the generalized form of the Gumbel-Softmax.
Gumbel-Softmax takes logits (log-probabilities) as input, but you pass the learned structure theta in directly:
adj = gumbel_softmax(x, temperature=temp, hard=True)
(line 234 of GTS/model/pytorch/model.py)
Why does this work? Thank you.
Hi, thanks for your great question. Here we treat the output of the neural network as logits; this implementation is the same as in the NRI code. Alternatively, we could obtain explicit logits as follows:
x = torch.softmax(x, dim=-1)
logits = torch.log(x + 1e-20)
If you have additional questions, please let me know. : )
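A small NumPy sketch (example scores are made up) of why the two options behave the same: log-softmax differs from the raw scores only by a per-row constant (the logsumexp), and Gumbel-Softmax sampling is invariant to adding a constant to every logit in a row.

```python
import numpy as np

# Hypothetical raw network outputs ("theta"), used directly as logits in the repo.
scores = np.array([[2.0, 0.5], [0.1, 1.5]])

# "Proper" logits: log of the softmax-normalized scores (as in the answer above).
probs = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
logits = np.log(probs + 1e-20)

# The two differ only by a per-row constant, so perturbing either with the
# same Gumbel noise yields the same argmax (i.e., the same sample distribution).
diff = scores - logits
assert np.allclose(diff, diff[:, :1])
```

So feeding raw scores into `gumbel_softmax` amounts to treating them as unnormalized log-probabilities, which is why the original line works.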