score_sde_pytorch

Likelihood estimation

mtailanian opened this issue 1 year ago • 1 comment

Hello Dr. @yang-song, thank you very much for this work.

I'm trying to estimate the likelihood of a given sample. I understand this is very similar to what you do to compute the bpd (bits per dimension), here.

As I understand it, following eq. (39) in the paper, to obtain $\log p_0(x_0)$ I have to "correct" $\log p_T(x_T)$ with the time integral of the divergence of the drift function, $\int_0^T \nabla \cdot \bar{f}_\theta(x(t), t)\,dt$.
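For reference, here are the two relations involved, written in this issue's notation. The first is the probability-flow drift, where $f$ and $g$ are the forward SDE's drift and diffusion coefficients and the score model $s_\theta(x, t) \approx \nabla_x \log p_t(x)$. The second is the change-of-variables formula of eq. (39):

$$\bar{f}_\theta(x, t) = f(x, t) - \tfrac{1}{2}\, g(t)^2\, s_\theta(x, t)$$

$$\log p_0(x(0)) = \log p_T(x(T)) + \int_0^T \nabla \cdot \bar{f}_\theta(x(t), t)\, dt$$

Here $x(t)$ solves $dx = \bar{f}_\theta(x, t)\,dt$ starting from the data point $x(0) = x_0$. The divergence therefore has to be evaluated along that trajectory, not at arbitrary points.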

To obtain a more accurate likelihood estimate with the Skilling-Hutchinson trace estimator, I take the states $x$ and time points $t$ returned by the (probability-flow) ODE solver, like this:

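```python
t = solution.t                 # time points returned by the solver
x = solution.y[:-shape[0], :]  # flattened x(t) at each time point; drops the per-sample log-prob rows
```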

I then plug these values into $\epsilon^\top \nabla \bar{f}_\theta(x, t)\, \epsilon$, sample many $\epsilon$, and average the results to obtain an estimate of the divergence, `div_f`.
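The averaging step described above is the Skilling-Hutchinson estimator: for noise with $\mathbb{E}[\epsilon\epsilon^\top] = I$, $\mathbb{E}[\epsilon^\top J \epsilon] = \operatorname{tr} J = \nabla \cdot \bar{f}_\theta$. Below is a minimal NumPy sketch that checks the estimate against the exact trace. It assumes a toy drift (`drift` and `W` are hypothetical stand-ins for $\bar{f}_\theta$), Rademacher noise (Gaussian noise also satisfies the identity), and a finite-difference Jacobian-vector product in place of autograd.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4
W = rng.normal(size=(dim, dim))  # hypothetical weights for a toy drift


def drift(x, t):
    # Toy nonlinear drift standing in for f_bar_theta(x, t); x has shape (..., dim)
    return (1.0 + t) * np.tanh(x @ W.T)


def exact_divergence(x, t):
    # Jacobian is (1 + t) * diag(1 - tanh(Wx)^2) @ W, so its trace only uses diag(W)
    return (1.0 + t) * np.sum((1.0 - np.tanh(x @ W.T) ** 2) * np.diag(W), axis=-1)


def hutchinson_divergence(f, x, t, n_samples=50_000, h=1e-4, seed=1):
    """Estimate tr(J) per sample as the mean of eps^T J eps over Rademacher eps."""
    eps = np.random.default_rng(seed).choice([-1.0, 1.0], size=(n_samples,) + x.shape)
    # J @ eps via a central finite difference; with autograd one would instead
    # compute eps^T J as a vector-Jacobian product
    jvp = (f(x + h * eps, t) - f(x - h * eps, t)) / (2.0 * h)
    return np.mean(np.sum(eps * jvp, axis=-1), axis=0)  # one estimate per batch element


x = rng.normal(size=(3, dim))  # hypothetical batch of 3 points on the trajectory
t = 0.5
estimate = hutchinson_divergence(drift, x, t)
exact = exact_divergence(x, t)
```

The estimator is unbiased but noisy. Its error shrinks roughly as one over the square root of `n_samples`, so averaging many $\epsilon$ per time point, as described above, is a reasonable way to reduce variance.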

Finally, I compute the integral over time, like this:

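```python
t_tensor = torch.as_tensor(t, dtype=div_f.dtype, device=div_f.device)  # solution.t is a NumPy array
div_f_integral = torch.trapz(div_f, t_tensor, dim=-1)  # trapezoidal rule over time, per sample
```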
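One way to test the post-hoc integration described above is to compare it with an integral accumulated inside the ODE itself. The sketch below does this with SciPy's `solve_ivp` on a hypothetical toy drift whose divergence is known exactly. It assumes the state layout implied by `solution.y[:-shape[0], :]`: the flattened $x$ followed by one log-density accumulator per sample. If `t_eval` is not passed, `solution.t` holds only the adaptive solver's accepted steps, which can be sparse, so a trapezoidal rule on them may be coarse. The sketch requests a dense grid instead.

```python
import numpy as np
from scipy.integrate import solve_ivp

batch, dim = 2, 3  # hypothetical shape = (batch, dim)
W = np.array([[0.5, -1.0, 0.2],
              [0.3, 0.1, -0.4],
              [-0.2, 0.6, 0.3]])  # hypothetical weights for a toy drift


def drift(x, t):
    # Toy nonlinear drift standing in for f_bar_theta(x, t); x has shape (batch, dim)
    return np.tanh(x @ W.T)


def divergence(x, t):
    # Exact trace of the Jacobian diag(1 - tanh(Wx)^2) @ W, per sample
    return np.sum((1.0 - np.tanh(x @ W.T) ** 2) * np.diag(W), axis=-1)


def ode_func(t, state):
    # Assumed layout: flattened x, then one divergence accumulator per sample
    x = state[:-batch].reshape(batch, dim)
    return np.concatenate([drift(x, t).ravel(), divergence(x, t)])


x0 = np.array([[1.0, -0.5, 0.2], [0.3, 0.8, -1.0]])
init = np.concatenate([x0.ravel(), np.zeros(batch)])
t_grid = np.linspace(0.0, 1.0, 201)  # dense output grid for the trapezoidal rule
solution = solve_ivp(ode_func, (0.0, 1.0), init, t_eval=t_grid, rtol=1e-8, atol=1e-8)

t = solution.t
x = solution.y[:-batch, :].reshape(batch, dim, -1)  # (batch, dim, n_times)
div_f = np.stack([divergence(x[..., k], t[k]) for k in range(len(t))], axis=-1)

# Trapezoidal rule along time, per sample (same role as torch.trapz(div_f, t, dim=-1))
div_f_integral = np.sum(0.5 * (div_f[:, 1:] + div_f[:, :-1]) * np.diff(t), axis=-1)
accumulated = solution.y[-batch:, -1]  # integral accumulated inside the ODE
```

If the two results agree on a toy problem but not on the real model, the discrepancy more likely comes from the time grid or the estimator variance than from the integration formula itself.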

What do you think, is this correct?

The problem is that the results are not what I expect. Computing $\log p_T(x_T) + \int_0^T \nabla \cdot \bar{f}_\theta(x(t), t)\,dt$ should give $\log p_0(x_0)$, but I get values that look like nonsense, such as log-probabilities greater than 0...
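One point worth checking about that last symptom: $p_0$ is a density over continuous $x$, not a probability, so $\log p_0(x_0)$ can legitimately exceed 0. A positive value is not by itself evidence of a bug. The minimal NumPy illustration below assumes a hypothetical isotropic Gaussian with $\sigma = 0.01$ in $D = 32 \cdot 32 \cdot 3$ dimensions. The helper name is also hypothetical.

```python
import numpy as np


def isotropic_gaussian_logpdf(x, mean, sigma):
    # log N(x; mean, sigma^2 I), summed over the last axis
    d = x.shape[-1]
    sq = np.sum((x - mean) ** 2, axis=-1)
    return -0.5 * d * np.log(2.0 * np.pi * sigma ** 2) - 0.5 * sq / sigma ** 2


D = 32 * 32 * 3  # hypothetical data dimension (a 32x32 RGB image)
x = np.zeros(D)
logp = isotropic_gaussian_logpdf(x, mean=np.zeros(D), sigma=0.01)
# Per-dimension value in bits, before any dataset-specific offset
bits_per_dim = -logp / (D * np.log(2.0))
```

Here the log-density is large and positive because the distribution is narrow, and the bits-per-dimension value is correspondingly negative.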

In summary, what can I do to obtain a more accurate likelihood estimation?

Any help or hints would be much appreciated. Many thanks in advance!

mtailanian • May 17 '24 22:05

Hey @mtailanian, were you able to set up the project and train it? If so, please share your frozen pip/env specs.

bhupender-kaushal • Jan 03 '25 10:01