ancestral-gumbel-top-k-sampling

How to extend to sampling categorical values?

Status: Open · Jordy-VL opened this issue 4 years ago · 2 comments

Dear authors,

I tested the code in your notebook, which is very clearly explained, thanks for that!

I would now like to extend this to sampling a sequence of categorical variables, so that each node in the probabilistic directed acyclic graph has K possible values instead of 2.

I tried this by changing domain_size to K**L. The model is still generated correctly, but the graph is not, and sampling still produces Bernoulli samples: tensor([1, 0, 1, 1, 0], dtype=torch.uint8).

Could you give any hint on how to change your code to support multi-class? :)

Thanks in advance,

Jordy

Jordy-VL · Apr 02 '21 11:04

Hi!

Thanks for your interest in our paper! This notebook is mainly meant to illustrate the algorithm, which is why we used a very simple model that is not trivial to extend to more than 2 classes (it uses a single parameter per conditional distribution to define a Bernoulli distribution). However, given your own model, the code should work out of the box with K > 2 classes. I'll try to give you some hints.

Short summary

It is completely up to you how to define your model. You must replace the two lines below so that, for a given set S of already-sampled variables with values y[S], they return for the next variable v a length-K vector log_p_y_cond of (normalized!) log-probabilities, i.e. torch.logsumexp(log_p_y_cond, -1) == 0:

conditional_config_index = binary2long(y * graph[v])
log_p_y_cond = log_probs[v, conditional_config_index]
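A minimal sketch of that contract, assuming nodes are visited in the topological order 0..N-1: the two lines are replaced by a call to a user-supplied function `cond_fn(v, y)`. The names `sample_ancestral`, `cond_fn` and the toy `copy_previous` model are hypothetical, not from the notebook, and the loop draws one plain ancestral sample with `torch.multinomial` rather than running the Gumbel-top-k procedure from the paper.

```python
import torch


def sample_ancestral(cond_fn, num_nodes, generator=None):
    """Draw one sample y (values in 0..K-1) by visiting nodes in order 0..num_nodes-1.

    cond_fn(v, y) stands in for the two replaced lines: given the partial
    assignment y (only entries < v are filled in), it must return a length-K
    vector of normalized log-probabilities for node v.
    """
    y = torch.zeros(num_nodes, dtype=torch.long)
    for v in range(num_nodes):
        log_p_y_cond = cond_fn(v, y)
        # Contract check: logsumexp over the K classes must be 0 (probabilities sum to 1).
        assert torch.allclose(torch.logsumexp(log_p_y_cond, -1), torch.zeros(()), atol=1e-5)
        y[v] = torch.multinomial(log_p_y_cond.exp(), 1, generator=generator).item()
    return y


# Hypothetical toy conditional for K = 3: node v prefers to copy the value of node v - 1.
K = 3


def copy_previous(v, y):
    logits = torch.zeros(K)
    if v > 0:
        logits[int(y[v - 1])] = 2.0
    return torch.log_softmax(logits, dim=-1)


y = sample_ancestral(copy_previous, num_nodes=5, generator=torch.Generator().manual_seed(0))
print(y)
```

Any conditional that passes the normalization check can be swapped in; the check is just `torch.logsumexp(log_p_y_cond, -1) == 0` with a small tolerance for floating-point error.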

Long answer

A Bernoulli distribution (a 2-class categorical) can be represented by a single success probability. If there are N nodes, the model variable is an [N, 2^N] matrix that holds, for each of the N nodes, the conditional Bernoulli probability given each of the 2^N possible configurations:

model = torch.rand(num_nodes, domain_size) * alpha[:, None] + (prior * (1 - alpha))[:, None]

The model is then transformed into log_probs for both (K = 2) classes, which has shape [N, 2^N, 2]:

log_probs = torch.stack([1 - model, model], -1).log()
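To make the shapes concrete, here is a sketch that runs the two quoted lines with small made-up inputs. `alpha` and `prior` come from the notebook and are not shown in this thread, so the values below are assumptions (a per-node `alpha` and a scalar `prior`); `domain_size = 2 ** num_nodes` gives one entry per joint binary configuration.

```python
import torch

torch.manual_seed(0)
num_nodes = 4
domain_size = 2 ** num_nodes            # one column per joint binary configuration
alpha = torch.full((num_nodes,), 0.8)   # hypothetical values; the notebook defines its own
prior = 0.5                             # hypothetical value

# [N, 2^N]: P(y_v = 1 | configuration) for every node and configuration
model = torch.rand(num_nodes, domain_size) * alpha[:, None] + (prior * (1 - alpha))[:, None]
# [N, 2^N, 2]: log-probabilities of class 0 and class 1
log_probs = torch.stack([1 - model, model], -1).log()

print(model.shape, log_probs.shape)  # torch.Size([4, 16]) torch.Size([4, 16, 2])
# Each conditional distribution is normalized over its last axis.
print(torch.allclose(torch.logsumexp(log_probs, -1), torch.zeros(num_nodes, domain_size), atol=1e-5))
```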

If a node does not depend on all of the other N-1 nodes, many of these entries are never used. The independence is achieved by setting the variables that node v does not depend on to 0 (multiplying by the mask graph[v]) before looking up the conditional probabilities for a given partial configuration y in the model:

conditional_config_index = binary2long(y * graph[v])
log_p_y_cond = log_probs[v, conditional_config_index]
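A small sketch of the masking trick. `binary2long` is not shown in this thread, so the version below is an assumed reconstruction (bit i gets weight 2^i, in line with the bit-shift remark further down), and the parent mask is made up. Two partial configurations that differ only in a node that is not a parent of v map to the same row of `log_probs[v]`.

```python
import torch


def binary2long(m):
    # Assumed reconstruction: little-endian bits -> integer index (bit i has weight 2**i).
    return (m.long() * (2 ** torch.arange(m.size(-1), device=m.device))).sum(-1)


num_nodes = 4
graph = torch.zeros(num_nodes, num_nodes, dtype=torch.long)
graph[3, 0] = 1  # hypothetical: node 3 depends on node 0 ...
graph[3, 2] = 1  # ... and on node 2, but not on node 1

v = 3
y_a = torch.tensor([1, 0, 1, 0])  # these partial configurations differ only in node 1,
y_b = torch.tensor([1, 1, 1, 0])  # which is not a parent of node 3

idx_a = binary2long(y_a * graph[v])
idx_b = binary2long(y_b * graph[v])
print(idx_a.item(), idx_b.item())  # 5 5
```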

Now, if you want to generalize this to K > 2, you should make sure log_probs has shape [N, K^N, K]. This is probably easiest by directly generating a model of this size, e.g. using a softmax over some random values, rather than a single Bernoulli probability for each conditional distribution. Then binary2long should become a K-ary-to-long conversion. Since a bit shift by i is a multiplication by 2^i, it should probably look something like this (untested!):

import torch

def kary2long(m, k):  # converts base-k digits (little-endian, last dim) to long
    return (m.long() * (k ** torch.arange(m.size(-1), device=m.device))).sum(-1)
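Putting the K > 2 suggestions together, a sketch (not tested against the notebook) that builds a [N, K^N, K] table with a log-softmax over random values, looks up each conditional with a K-ary index, and samples the nodes in order. The strictly lower-triangular `graph`, the sizes, and passing `k` explicitly to `kary2long` are assumptions; even this tiny case (N = 3, K = 4) already needs N·K^(N+1) = 768 table entries.

```python
import torch


def kary2long(m, k):
    # Little-endian base-k digits along the last dim -> integer index (digit i has weight k**i).
    return (m.long() * (k ** torch.arange(m.size(-1), device=m.device))).sum(-1)


torch.manual_seed(0)
N, K = 3, 4
# Hypothetical structure: node v depends on all nodes before it (strictly lower-triangular mask).
graph = torch.tril(torch.ones(N, N, dtype=torch.long), diagonal=-1)
# [N, K^N, K]: one normalized categorical distribution per node and K-ary configuration.
log_probs = torch.log_softmax(torch.randn(N, K ** N, K), dim=-1)

y = torch.zeros(N, dtype=torch.long)
for v in range(N):
    conditional_config_index = kary2long(y * graph[v], K)
    log_p_y_cond = log_probs[v, conditional_config_index]  # length-K vector
    y[v] = torch.multinomial(log_p_y_cond.exp(), 1).item()

print(y)  # values lie in 0..K-1
```

The two lookup lines inside the loop are unchanged apart from `binary2long` becoming `kary2long`; for K = 2 both conversions give the same index.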

Note that this may blow up, because the number of configurations K^N grows exponentially with N. It is therefore better to define your model in some functional form.
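One possible functional form, sketched under our own assumptions rather than taken from the notebook: each node gets a bias and a logit contribution for every (parent, parent value) pair, and the conditional is a log-softmax of their sum. `AdditiveCategoricalModel` and this parameterization are hypothetical. It needs N·K + N²·K² parameters instead of N·K^(N+1) table entries, at the price of only modelling additive parent effects.

```python
import torch


class AdditiveCategoricalModel(torch.nn.Module):
    """Hypothetical functional model:
    logits of node v = bias[v] + sum over parents u of weight[v, u, y[u]]."""

    def __init__(self, graph, num_classes):
        super().__init__()
        num_nodes = graph.size(0)
        self.register_buffer("graph", graph.long())  # graph[v, u] == 1 if v depends on u
        self.bias = torch.nn.Parameter(torch.randn(num_nodes, num_classes))
        # weight[v, u, a] is the length-K logit contribution to node v when parent u has value a.
        self.weight = torch.nn.Parameter(
            torch.randn(num_nodes, num_nodes, num_classes, num_classes))

    def forward(self, v, y):
        """Return a length-K vector of normalized log-probabilities for node v given y (long)."""
        parents = self.graph[v].nonzero(as_tuple=True)[0]    # indices u with graph[v, u] == 1
        contrib = self.weight[v, parents, y[parents]].sum(0)  # [K]; all zeros if v has no parents
        return torch.log_softmax(self.bias[v] + contrib, dim=-1)


torch.manual_seed(0)
N, K = 10, 5
graph = torch.tril(torch.ones(N, N), diagonal=-1)  # hypothetical: node v depends on nodes 0..v-1
model = AdditiveCategoricalModel(graph, K)

y = torch.randint(K, (N,))  # some assignment; only the parents of v are read
with torch.no_grad():
    log_p_y_cond = model(3, y)  # could stand in for the two table-lookup lines

num_params = sum(p.numel() for p in model.parameters())
print(num_params, N * K ** (N + 1))  # 2550 488281250
```

For N = 10 and K = 5 that is 2,550 parameters versus roughly 4.9 × 10^8 table entries.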

wouterkool · Apr 06 '21 16:04

Hi Wouter,

Thanks for the great exposition! By the way, for context: I asked the question because I am working on defining calibration for structured prediction problems. As you mention, these problems suffer from an exponentially sized probability simplex, which is why I am looking into different sampling approaches. Your approach seems a perfect fit, since it can be applied to arbitrary structures as long as they can be expressed as a probabilistic directed acyclic graph. :)

In the meantime, I also found this functional model form for generating and sampling unique sequences: sequence example.

I might still have more questions once I've made more progress.

Thank you very much ;)

Jordy

Jordy-VL · Apr 06 '21 17:04