
NormalizingFlow class in core.py does not provide context in forward_kld

Open ybernaerts opened this issue 3 years ago • 3 comments

Thank you for a repo that's easy to handle with a normalizing flow of one's choice!

I would like to implement a normalizing flow that optimizes multiple target distributions at once, depending on the context I provide to it. However, as far as I can tell, no context can currently be provided to the .forward_kld method of the NormalizingFlow class.

Would be great if that's added!

Cheers,

Yves

ybernaerts, Apr 26 '22 12:04

I think the same could be said for sampling from a NormalizingFlow model. Would be great if one could provide context.

ybernaerts, Apr 26 '22 13:04

Seems like it's easily fixed with:

    def forward_kld(self, x, context=None):
        """
        Estimates forward KL divergence, see arXiv 1912.02762
        :param x: Batch sampled from target distribution
        :param context: Batch of conditions passed to every flow layer (optional)
        :return: Estimate of forward KL divergence averaged over batch
        """
        log_q = torch.zeros(len(x), device=x.device)
        z = x
        # Undo the flows in reverse order; every layer's inverse() must accept context
        for i in range(len(self.flows) - 1, -1, -1):
            z, log_det = self.flows[i].inverse(z, context=context)
            log_q += log_det
        log_q += self.q0.log_prob(z)
        return -torch.mean(log_q)

and:

    def sample(self, num_samples=1, context=None):
        """
        Samples from flow-based approximate distribution
        :param num_samples: Number of samples to draw
        :param context: Batch of conditions passed to every flow layer (optional)
        :return: Samples, log probability
        """
        # The base distribution q0 is not conditioned on context here
        z, log_q = self.q0(num_samples)
        for flow in self.flows:
            z, log_det = flow(z, context=context)
            log_q -= log_det
        return z, log_q
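To sanity-check the pattern in both patched methods without relying on normflows internals, here is a minimal pure-Python sketch. It assumes a hypothetical 1-D layer (`ConditionalAffine`) whose shift and log-scale are linear in a scalar context, and a hypothetical `ToyConditionalFlow` container; neither is part of the normflows API, and a single scalar context per batch stands in for the per-sample context tensors the real methods would take. It shows that threading the same context through the inverse pass (as in `forward_kld`) and the forward pass (as in `sample`) yields consistent log-densities.

```python
import math
import random


def std_normal_log_prob(z):
    """Log density of the standard normal base distribution N(0, 1)."""
    return -0.5 * (z * z + math.log(2 * math.pi))


class ConditionalAffine:
    """Hypothetical 1-D conditional layer: x = z * exp(s(c)) + m(c).

    The shift m(c) and log-scale s(c) are linear in a scalar context c,
    standing in for a layer whose parameters come from a context network.
    """

    def __init__(self, shift_weight, log_scale_weight):
        self.shift_weight = shift_weight
        self.log_scale_weight = log_scale_weight

    def _params(self, context):
        return self.shift_weight * context, self.log_scale_weight * context

    def forward(self, z, context):
        m, s = self._params(context)
        return z * math.exp(s) + m, s  # log|dx/dz| = s(c)

    def inverse(self, x, context):
        m, s = self._params(context)
        return (x - m) * math.exp(-s), -s  # log|dz/dx| = -s(c)


class ToyConditionalFlow:
    """Hypothetical stand-in for NormalizingFlow with context threaded through."""

    def __init__(self, flows):
        self.flows = flows

    def log_prob(self, x, context):
        # Mirrors the patched forward_kld: undo the layers in reverse order
        # and accumulate the inverse log-determinants.
        z, log_q = x, 0.0
        for flow in reversed(self.flows):
            z, log_det = flow.inverse(z, context)
            log_q += log_det
        return log_q + std_normal_log_prob(z)

    def forward_kld(self, xs, context):
        # Negative mean log-likelihood under q(x | c); equals the forward KL
        # divergence up to a constant that does not depend on the flow.
        return -sum(self.log_prob(x, context) for x in xs) / len(xs)

    def sample(self, num_samples, context, rng=random):
        # Mirrors the patched sample: push base draws through each layer with
        # the same context and subtract the forward log-determinants.
        draws = []
        for _ in range(num_samples):
            z = rng.gauss(0.0, 1.0)
            log_q = std_normal_log_prob(z)
            for flow in self.flows:
                z, log_det = flow.forward(z, context)
                log_q -= log_det
            draws.append((z, log_q))
        return draws


if __name__ == "__main__":
    flow = ToyConditionalFlow([ConditionalAffine(1.0, 0.5), ConditionalAffine(-0.3, 0.2)])
    rng = random.Random(0)
    for c in (0.0, 1.0, 2.0):
        # Each context selects a different distribution from the same model.
        xs = [x for x, _ in flow.sample(1000, c, rng)]
        print(f"context={c}: mean={sum(xs) / len(xs):.3f}, "
              f"forward KL (up to const)={flow.forward_kld(xs, c):.3f}")
```

The log-density returned by `sample` for each draw matches what the inverse pass recovers for the same point and context, which is the property both patched methods rely on.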

ybernaerts, Apr 26 '22 14:04

But there are other methods, probably including .inverse_kld(), that could benefit from this too.
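For a reverse-KL objective, the kind of estimate the .inverse_kld() mention points at, the context has to condition both the flow samples and the target density. The following is a hypothetical pure-Python sketch, not the normflows method: `sample_flow`, `reverse_kld`, and the Gaussian `target_log_prob` are made-up names, and the flow is a single 1-D affine layer with context-dependent shift and log-scale.

```python
import math
import random


def normal_log_prob(x, mean, log_std):
    """Log density of N(mean, exp(log_std)**2) at x."""
    return (-0.5 * ((x - mean) * math.exp(-log_std)) ** 2
            - log_std - 0.5 * math.log(2 * math.pi))


def sample_flow(num_samples, context, shift_fn, log_scale_fn, rng=random):
    """Draw (x, log q(x | c)) pairs from a hypothetical 1-D conditional
    affine flow x = z * exp(s(c)) + m(c), with z ~ N(0, 1)."""
    m, s = shift_fn(context), log_scale_fn(context)
    draws = []
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)
        # Change of variables: log q(x | c) = log N(z; 0, 1) - log|dx/dz|.
        draws.append((z * math.exp(s) + m, normal_log_prob(z, 0.0, 0.0) - s))
    return draws


def reverse_kld(num_samples, context, shift_fn, log_scale_fn, target_log_prob, rng=random):
    """Monte Carlo estimate of KL(q(. | c) || p(. | c)) = E_q[log q - log p].

    The same context conditions both the flow samples and the target density.
    """
    draws = sample_flow(num_samples, context, shift_fn, log_scale_fn, rng)
    return sum(log_q - target_log_prob(x, context) for x, log_q in draws) / num_samples


def target_log_prob(x, c):
    """Hypothetical conditional target: p(x | c) = N(x; 2c, exp(0.5c)**2)."""
    return normal_log_prob(x, 2.0 * c, 0.5 * c)


if __name__ == "__main__":
    rng = random.Random(0)
    for c in (0.0, 1.0):
        matched = reverse_kld(2000, c, lambda ctx: 2.0 * ctx, lambda ctx: 0.5 * ctx,
                              target_log_prob, rng)
        unshifted = reverse_kld(2000, c, lambda ctx: 0.0, lambda ctx: 0.5 * ctx,
                                target_log_prob, rng)
        print(f"context={c}: matched flow {matched:.4f}, unshifted flow {unshifted:.4f}")
```

When the flow's shift and scale track the target for every context, each per-sample term is zero. A flow whose shift ignores the context gives a clearly positive estimate at nonzero contexts, which is what a conditional reverse-KL objective would push back toward zero.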

ybernaerts, Apr 26 '22 14:04

This is already available in nflows, their previous codebase.

mtsatsev, May 25 '23 15:05

Hi @ybernaerts and @mtsatsev,

I added support for conditional normalizing flows, as shown in this example. More details are discussed in the related Issue #41.

Best regards, Vincent

VincentStimper, Jul 14 '23 09:07

The additions are part of the new version of the package on PyPI. Hence, I'll close this issue. Feel free to create a new issue if you're still missing features.

VincentStimper, Jul 23 '23 09:07