
Desirable MDN features for SNPE-C / SNPE-A

Open · michaeldeistler opened this issue 3 years ago · 1 comment

After having implemented non-atomic SNPE-C, I am writing this issue to keep track of things that would be desirable to have in mdn.

Get mixture components

The only non-protected methods of mdn are log_prob() and sample(). It would be great to have a non-protected get_mixture_components(). Unlike the existing protected _get_mixture_components(), it should also call the embedding_net.
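A minimal sketch of the idea (TinyMDN and its layer names are illustrative stand-ins, not pyknos' actual class): the public method runs the embedding net and then delegates to the protected helper.

```python
import torch
import torch.nn as nn


class TinyMDN(nn.Module):
    """Illustrative stand-in for an MDN; only the method split matters here."""

    def __init__(self, context_dim, hidden_dim, n_mixtures, output_dim):
        super().__init__()
        self.embedding_net = nn.Linear(context_dim, hidden_dim)
        self.logits_layer = nn.Linear(hidden_dim, n_mixtures)
        self.means_layer = nn.Linear(hidden_dim, n_mixtures * output_dim)
        self.n_mixtures, self.output_dim = n_mixtures, output_dim

    def _get_mixture_components(self, embedded_context):
        # Protected helper: maps an already-embedded context to MoG parameters.
        logits = self.logits_layer(embedded_context)
        means = self.means_layer(embedded_context).view(
            -1, self.n_mixtures, self.output_dim
        )
        return logits, means

    def get_mixture_components(self, context):
        # Public entry point: applies the embedding net first, then delegates.
        return self._get_mixture_components(self.embedding_net(context))


net = TinyMDN(context_dim=4, hidden_dim=8, n_mixtures=3, output_dim=2)
logits, means = net.get_mixture_components(torch.randn(10, 4))
```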

Evaluating the log_prob

The following code should be moved into a separate static method, e.g. evaluate_mixture_log_prob():

```python
@staticmethod
def evaluate_mixture_log_prob(inputs, logits, means, precisions, sumlogdiag):
    batch_size, n_mixtures, output_dim = means.size()
    inputs = inputs.view(-1, 1, output_dim)

    # Split up evaluation into parts.
    a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)  # log mixture weights
    b = -(output_dim / 2.0) * np.log(2 * np.pi)  # Gaussian normalization constant
    c = sumlogdiag  # 0.5 * log-determinant of each precision matrix
    d1 = (inputs.expand_as(means) - means).view(
        batch_size, n_mixtures, output_dim, 1
    )
    d2 = torch.matmul(precisions, d1)  # Mahalanobis term, part 1
    d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
        batch_size, n_mixtures
    )
    # Final reduction (added here for completeness): log sum_k exp(a_k + b + c_k + d_k).
    return torch.logsumexp(a + b + c + d, dim=-1)
```

This would make it possible to evaluate the log_prob of a MoG without instantiating an mdn. Since snpe_c has to do this for every training data point at every iteration, it would be computationally cheaper.

Together with the get_mixture_components refactoring above, this fully separates the two main steps of calling log_prob() on an mdn.

Log-prob based on cov

Right now, we use sumlogdiag for log_prob. If one does not yet have the Cholesky factorization, it would be better to be able to evaluate log(det(cov)) directly.
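For reference, the two quantities are related: the sum of the log-diagonal of the Cholesky factor of the precision equals -0.5 * log(det(cov)), so either parameterization carries the same information. A quick numerical check (variable names are mine):

```python
import torch

torch.manual_seed(0)
dim = 3
A = torch.randn(dim, dim)
cov = A @ A.T + 0.1 * torch.eye(dim)  # random SPD covariance
precision = torch.linalg.inv(cov)

# sumlogdiag of the precision's Cholesky factor:
# det(P) = det(L)^2, so sum(log diag L) = 0.5 * log det(P) = -0.5 * log det(cov).
sumlogdiag = torch.linalg.cholesky(precision).diagonal().log().sum()
print(torch.allclose(sumlogdiag, -0.5 * torch.logdet(cov), atol=1e-4))
```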

Different init strategy for the means

Means are initialized close to 0. Initializing them at more spread-out, random locations might work better.
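One possible sketch of such an init (purely hypothetical, not pyknos code): zero the weights of the means layer and draw its biases uniformly, so each component's initial mean sits at a random, spread-out location.

```python
import torch.nn as nn

# Hypothetical means layer: hidden_dim=16, n_mixtures=3, output_dim=2.
means_layer = nn.Linear(16, 3 * 2)

# With zero weights, the layer's output equals its bias, so the biases
# are the initial component means; draw them uniformly in [-2, 2].
nn.init.zeros_(means_layer.weight)
nn.init.uniform_(means_layer.bias, -2.0, 2.0)
```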

Variable number of layers

This requires writing a forward() function that loops over a torch.nn.ModuleList.
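A minimal sketch of the pattern (class and argument names are illustrative): the hidden layers live in an nn.ModuleList whose length is a constructor argument, and forward() simply iterates over it.

```python
import torch
import torch.nn as nn


class VariableDepthNet(nn.Module):
    """Illustrative net whose number of hidden layers is a constructor argument."""

    def __init__(self, in_dim, hidden_dim, n_layers):
        super().__init__()
        dims = [in_dim] + [hidden_dim] * n_layers
        # nn.ModuleList registers each layer's parameters with the module.
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x


net = VariableDepthNet(in_dim=4, hidden_dim=8, n_layers=3)
out = net(torch.randn(5, 4))
```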

michaeldeistler · Aug 18 '20