
Unconditional sampling

Open · zqOuO opened this issue 1 year ago • 6 comments

Hi Alexia! Thanks for sharing your code. I am trying to use it for unconditional score-network sampling, but all the provided checkpoints were trained conditionally, so I tried setting grad = score_fn(x, t)/sigmas[timestep]. It doesn't work, and I tried many SNR settings. Also, loading the checkpoint dict raises an error if I simply change model.conditional to False and run the code. How can I run your code with an unconditional score network? Should I change model.conditional to False and train another one? Thanks.

zqOuO · Oct 23 '23 02:10

Hi @ziqwnudt,

Most people these days use conditional networks, so there is no unconditional checkpoint.

You would indeed need to retrain using model.conditional=False.

From looking at the code, it seems that only ncsnpp.py correctly removes the conditioning when conditional=False (see https://github.com/AlexiaJM/score_sde_fast_sampling/blob/5da8f3fe103ee5ac3c3a336f16cc06c9541f0ed9/models/ncsnpp.py#L87). I'm not sure why this is the case, but ncsnpp is the best network, so hopefully it will be fine for your use case. Otherwise, you will need to modify https://github.com/AlexiaJM/score_sde_fast_sampling/blob/5da8f3fe103ee5ac3c3a336f16cc06c9541f0ed9/models/ddpm.py#L64 to allow for an unconditional DDPM.
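For concreteness, the relevant fields look roughly like this (a minimal sketch assuming the usual ml_collections-style configs from the score_sde family; the exact config layout may differ in your setup):

```python
# Minimal sketch, not the repo's exact config: the key fields are
# model.name = 'ncsnpp' (the only architecture that correctly removes
# the conditioning) and model.conditional = False.
import ml_collections

def get_unconditional_model_config():
    config = ml_collections.ConfigDict()
    config.model = model = ml_collections.ConfigDict()
    model.name = 'ncsnpp'       # only ncsnpp handles conditional=False
    model.conditional = False   # drop the noise/time conditioning
    return config
```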

Alexia

AlexiaJM · Oct 23 '23 16:10

Thank you Alexia! I think I should train another score network too. Besides, I found your previous work at https://github.com/AlexiaJM/AdversarialConsistentScoreMatching/blob/9575592d3255a4c492728c794fa526dc242e70bc/models/init.py#L51, where there is an additional modification to the grad that makes grad = grad_n / sigma; you name the targets 'gaussian' and 'dae'. What is the meaning of these two targets? The 'gaussian' one looks very close to an unconditional score network, which uses s(x) = s(x, t)/sigma_t.

zqOuO · Oct 24 '23 00:10

Hi @ziqwnudt,

This is because you can set the network to estimate various quantities, since you can always recover the score function from them. The simplest parametrization is for s(x) to estimate the score (-z/std) directly. You can also estimate the noise (z), or the real data before the noise was added (x0).

Since x(t) = mu(t)*x0 + std(t)*z (where mu(t) and std(t) depend on the forward process used), the score is -z/std(t) = (mu(t)*x0 - x(t))/std(t)^2, so you can get the score from either z, x0, or an estimated score. In practice, people estimate the score directly or z directly, because both tend to work better than estimating x0.

https://github.com/AlexiaJM/AdversarialConsistentScoreMatching/blob/9575592d3255a4c492728c794fa526dc242e70bc/losses/dsm.py#L21
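To make the equivalences concrete, here is a small numerical sketch (my own notation, not code from the repo) of recovering the score from either a z-prediction or an x0-prediction:

```python
import torch

# Forward perturbation: x_t = mu(t)*x0 + std(t)*z, with z ~ N(0, I).
x0 = torch.randn(4, 3, 32, 32)            # clean data
z = torch.randn_like(x0)                  # Gaussian noise
mu, std = 0.9, 0.5                        # illustrative values at some t
x_t = mu * x0 + std * z                   # perturbed sample

# The same score can be recovered from either parametrization:
score_from_z = -z / std                   # network output = z
score_from_x0 = (mu * x0 - x_t) / std**2  # network output = x0

assert torch.allclose(score_from_z, score_from_x0, atol=1e-5)
```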

AlexiaJM · Oct 24 '23 13:10

Thank you for your reply! I think I got the point. Am I right to understand that if we set target = 'gaussian', s(x) will learn -z, which is (mu(t)*x0 - xt)/std(t), while the score is (mu(t)*x0 - xt)/std(t)^2, so we can recover the score by s(x)/sigma, like the code grad = grad_n / sigma? That is your default target, so all the provided checkpoints with target = 'gaussian' recover the gradient with grad = grad_n / sigma. Otherwise, if I use another score network like NCSN, s(x) is trained to estimate (mu(t)*x0 - xt)/std(t)^2 directly, so I don't need to set grad = grad_n / sigma; grad = grad_n instead of grad_n/sigma is enough to generate new images.
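In code, my understanding is roughly the following sketch (my own paraphrase of the logic, not the repo's exact implementation):

```python
def get_score(grad_n, sigma, target='gaussian'):
    # target='gaussian': the network predicts -z, so dividing by sigma
    # recovers the score (mu(t)*x0 - x_t) / std(t)**2.
    # A network that predicts the score directly (e.g. NCSN-style)
    # is used as-is, without the rescaling.
    if target == 'gaussian':
        return grad_n / sigma   # grad = grad_n / sigma
    return grad_n               # grad = grad_n
```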

zqOuO · Oct 24 '23 14:10

Hi @ziqwnudt,

Yes exactly! You got it right.

Alexia

AlexiaJM · Oct 24 '23 14:10

Thank you!

zqOuO · Oct 24 '23 14:10