
Question about Gumbel sampling

Open · ThinkNaive opened this issue 3 years ago · 1 comment

I read that you apply bivariate Gumbel sampling in your paper and use the generalized Gumbel-Softmax. Gumbel-Softmax takes logits (log-probabilities) as input, but you feed the learned structure theta in directly: adj = gumbel_softmax(x, temperature=temp, hard=True) (line 234 of GTS/model/pytorch/model.py). Why does this work? Thank you.

ThinkNaive · Jun 03 '21

Hi, thanks for your great question. Here we treat the output of the neural network as the logits. This implementation is the same as in the NRI code. In addition, we'd like to offer another option: we could also obtain the logits as follows:

x = torch.softmax(x, dim=-1)
logits = torch.log(x + 1e-20)
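For intuition, the two choices give the same sampling distribution: softmax is invariant to adding a per-row constant, and log_softmax(x) = x - logsumexp(x), so the raw scores and the normalized log-probabilities differ only by such a constant. Below is a minimal sketch (an illustration, not the exact code in the repo) that checks this by applying the same Gumbel noise to both inputs:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
temp = 0.5
x = torch.randn(4, 2)                 # raw scores, standing in for the learned structure theta
log_probs = F.log_softmax(x, dim=-1)  # properly normalized log-probabilities

# Draw one shared Gumbel(0, 1) noise sample so both inputs are perturbed identically.
gumbel = -torch.log(-torch.log(torch.rand_like(x) + 1e-20) + 1e-20)

sample_raw = F.softmax((x + gumbel) / temp, dim=-1)
sample_logp = F.softmax((log_probs + gumbel) / temp, dim=-1)

# The two samples match, since the inputs differ only by a per-row constant
# that the softmax cancels out.
print(torch.allclose(sample_raw, sample_logp, atol=1e-6))  # True

So passing theta directly only shifts the logits by a constant per row, which the softmax inside Gumbel-Softmax cancels out.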

If you have additional questions, please let me know. : )

chaoshangcs · Jun 07 '21