pyknos
Desirable MDN features for SNPE-C / SNPE-A
After implementing non-atomic SNPE-C, I am opening this issue to keep track of features that would have been useful to have in mdn.
Get mixture components
The only non-protected methods of mdn are log_prob() and sample(). It would be great to have a non-protected get_mixture_components(). Unlike the already existing, protected _get_mixture_components(), it should also call the embedding_net.
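A minimal sketch of what such a public wrapper could look like. TinyMDN is a hypothetical stand-in, not the pyknos class: its layer names, constructor arguments, and the reduced return value (only logits and means) are assumptions made for illustration.

```python
import torch
from torch import nn


class TinyMDN(nn.Module):
    """Hypothetical stand-in for an MDN; not the actual pyknos implementation."""

    def __init__(self, features, context_features, hidden_features,
                 num_components, embedding_net=None):
        super().__init__()
        self._features = features
        self._num_components = num_components
        # context_features is the size of the *embedded* context.
        self._embedding_net = embedding_net if embedding_net is not None else nn.Identity()
        self._hidden_net = nn.Sequential(
            nn.Linear(context_features, hidden_features), nn.ReLU()
        )
        self._logits_layer = nn.Linear(hidden_features, num_components)
        self._means_layer = nn.Linear(hidden_features, num_components * features)

    def _get_mixture_components(self, embedded_context):
        # Protected variant: expects a context that has already been embedded.
        h = self._hidden_net(embedded_context)
        logits = self._logits_layer(h)
        means = self._means_layer(h).view(-1, self._num_components, self._features)
        return logits, means

    def get_mixture_components(self, context):
        # Public variant: applies the embedding net to the raw context first.
        return self._get_mixture_components(self._embedding_net(context))
```

With this split, a caller such as snpe_c would only ever pass the raw context and would not need to know whether an embedding net is present.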
Evaluating the log_prob
The following code should be put in a separate static method called evaluate_mixture_log_prob():
batch_size, n_mixtures, output_dim = means.size()
inputs = inputs.view(-1, 1, output_dim)

# Split up evaluation into parts.
a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
b = -(output_dim / 2.0) * np.log(2 * np.pi)
c = sumlogdiag
d1 = (inputs.expand_as(means) - means).view(
    batch_size, n_mixtures, output_dim, 1
)
d2 = torch.matmul(precisions, d1)
d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
    batch_size, n_mixtures
)
This would make it possible to evaluate the log_prob of a MoG without instantiating an mdn. Since snpe_c has to do this for every training data point at every iteration, it would be computationally cheaper.
Along with the above refactoring of get_mixture_components, this fully separates the two main steps of calling log_prob() in an mdn.
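As a sketch, the snippet above could be wrapped into a standalone function like the one below. Two things here are assumptions rather than part of the snippet: the final logsumexp over a + b + c + d, which is how the four parts would combine into a mixture log-probability, and the convention that sumlogdiag equals half the log-determinant of each precision matrix. math.log is used in place of np.log only to avoid the NumPy dependency.

```python
import math

import torch


def evaluate_mixture_log_prob(inputs, logits, means, precisions, sumlogdiag):
    """Log-probability of a mixture of Gaussians, given its components.

    Assumed shapes: inputs (B, D), logits (B, K), means (B, K, D),
    precisions (B, K, D, D), sumlogdiag (B, K).
    """
    batch_size, n_mixtures, output_dim = means.size()
    inputs = inputs.view(-1, 1, output_dim)

    # a: normalized log mixture weights.
    a = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    # b: Gaussian normalization constant.
    b = -(output_dim / 2.0) * math.log(2 * math.pi)
    # c: assumed to be 0.5 * log det(precision) per component.
    c = sumlogdiag
    # d: quadratic form -0.5 * (x - mu)^T P (x - mu).
    d1 = (inputs.expand_as(means) - means).view(
        batch_size, n_mixtures, output_dim, 1
    )
    d2 = torch.matmul(precisions, d1)
    d = -0.5 * torch.matmul(torch.transpose(d1, 2, 3), d2).view(
        batch_size, n_mixtures
    )

    # Assumption: the parts are combined by a logsumexp over the components.
    return torch.logsumexp(a + b + c + d, dim=-1)
```

Because the function only takes tensors, it can be called directly on the output of get_mixture_components() without constructing a second network.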
Log-prob based on cov
Right now, we use sumlogdiag for log_prob. If one does not yet have the Cholesky transformation, it would be better to use log(det(cov)).
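A small sketch of how the same term could be obtained from a covariance matrix. It assumes that sumlogdiag is the sum of the log-diagonal of an upper-triangular precision factor U with precision = U^T U, so that sumlogdiag = 0.5 * log det(precision) = -0.5 * log det(cov). The function name is hypothetical.

```python
import torch


def log_norm_from_cov(cov):
    """Return -0.5 * log det(cov) for a batch of covariance matrices.

    Under the assumption stated above, this equals sumlogdiag.
    """
    return -0.5 * torch.logdet(cov)


# Check the identity on random positive-definite matrices.
torch.manual_seed(0)
U = torch.triu(torch.randn(4, 3, 3, dtype=torch.float64), diagonal=1)
U = U + torch.diag_embed(torch.rand(4, 3, dtype=torch.float64) + 0.5)
precision = U.transpose(-2, -1) @ U
cov = torch.linalg.inv(precision)
sumlogdiag = torch.log(torch.diagonal(U, dim1=-2, dim2=-1)).sum(-1)
matches = torch.allclose(log_norm_from_cov(cov), sumlogdiag)
```

Note the sign: substituting log(det(cov)) directly for sumlogdiag would be off by a factor of -0.5 under this convention.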
Different init strategy for the means
Means are initialized close to 0. Maybe initializing at more random locations would be better.
Variable number of layers
This requires writing a forward function that loops over a torch.nn.ModuleList.
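A minimal sketch of such a forward function. HiddenStack, its argument names, and the choice of ReLU are assumptions for illustration, not the existing mdn code.

```python
import torch
from torch import nn


class HiddenStack(nn.Module):
    """Hypothetical hidden network with a configurable number of layers."""

    def __init__(self, in_features, hidden_features, num_layers):
        super().__init__()
        sizes = [in_features] + [hidden_features] * num_layers
        # ModuleList registers every layer so its parameters are tracked.
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, x):
        # Loop over the layers instead of hard-coding each one.
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x
```

Using a plain Python list instead of nn.ModuleList would hide the layers from parameters() and from the optimizer, which is why the container matters here.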