score_sde_pytorch
PC sampler mismatched?
Hello, thanks for your interesting work!
I have a question about your implementation of the PC sampler:
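```python
import torch


def pc_sampler(model):
  # sde, shape, eps, device, denoise, n_steps, inverse_scaler and the two
  # *_update_fn callables are captured from the enclosing sampler factory.
  with torch.no_grad():
    # Initial sample
    x = sde.prior_sampling(shape).to(device)
    timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
    for i in range(sde.N):
      t = timesteps[i]
      vec_t = torch.ones(shape[0], device=t.device) * t
      # Corrector first, then predictor, both evaluated at time t
      x, x_mean = corrector_update_fn(x, vec_t, model=model)
      x, x_mean = predictor_update_fn(x, vec_t, model=model)
    return inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
```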
Why do you start with the corrector rather than the predictor, which Algorithm 1 of your original paper applies first? Is there a particular reason? Thank you very much!
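To compare the two orderings concretely, here is a small sketch that records only the order of update calls, not the math. It assumes N noise levels and M corrector steps per level (the `n_steps` in the quoted return value) and labels each step by the noise level reached so far: `C@k` is a corrector step applied after k predictor steps (level 0 is the prior sample), and `P->k` is the predictor step that produces level k. The function names are hypothetical.

```python
from typing import List


def quoted_loop_order(n: int, m: int) -> List[str]:
    """Update sequence of the quoted pc_sampler loop (corrector first)."""
    ops: List[str] = []
    level = 0  # level 0 is the sample drawn by sde.prior_sampling
    for _ in range(n):
        ops += [f"C@{level}"] * m  # corrector_update_fn (m corrector steps)
        level += 1
        ops.append(f"P->{level}")  # predictor_update_fn moves to the next level
    return ops


def algorithm1_order(n: int, m: int) -> List[str]:
    """Update sequence of Algorithm 1 in the paper (predictor first)."""
    ops: List[str] = []
    level = 0
    for _ in range(n):
        level += 1
        ops.append(f"P->{level}")  # x_i <- Predictor(x_{i+1})
        ops += [f"C@{level}"] * m  # M corrector steps at the new level
    return ops


print(quoted_loop_order(3, 1))  # ['C@0', 'P->1', 'C@1', 'P->2', 'C@2', 'P->3']
print(algorithm1_order(3, 1))   # ['P->1', 'C@1', 'P->2', 'C@2', 'P->3', 'C@3']
```

Dropping the first M entries of the first list and the last M entries of the second leaves identical sequences, and both contain N·(M+1) steps, matching `sde.N * (n_steps + 1)`. So the two loops differ only at the ends: the quoted loop applies an extra corrector to the prior sample and finishes on a predictor step, while Algorithm 1 finishes on a corrector step. This compares call order only; it does not show which ordering gives better samples.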
I have the same question after reviewing these lines of code. @NguyenHai7120, have you figured it out?
Is it because after the prior sampling, you need to run a corrector first?
Same question here. I guess it makes some sense to apply the corrector to the prior sample. But even then, isn't the corrector never applied to the final sample?
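A toy numerical check of both orderings can be sketched under loud assumptions: 1-D Gaussian data N(0, data_std²), whose noised marginals N(0, data_std² + σ²) give an exact score; a geometric noise schedule that ends at σ = 0; a reverse-diffusion predictor; and a Langevin corrector with a simplified step size of 2·snr²·(data_std² + σ²). None of this is the repository's code, and `pc_sample_gaussian` and all of its parameters are hypothetical. `order="corrector_first"` follows the quoted loop and `order="predictor_first"` follows Algorithm 1.

```python
import numpy as np


def pc_sample_gaussian(order: str, n_levels: int = 500, corrector_steps: int = 1,
                       snr: float = 0.16, data_std: float = 1.0,
                       sigma_min: float = 0.01, sigma_max: float = 50.0,
                       n_samples: int = 20000, seed: int = 0) -> np.ndarray:
    """Toy PC sampler for 1-D N(0, data_std**2) data with an exact score.

    order="corrector_first" mimics the quoted loop; order="predictor_first"
    mimics Algorithm 1. All modelling choices here are illustrative assumptions.
    """
    if order not in ("corrector_first", "predictor_first"):
        raise ValueError(f"unknown order: {order}")
    rng = np.random.default_rng(seed)
    # Noise levels from largest to smallest, plus a final level 0 (clean data).
    sigmas = np.append(np.geomspace(sigma_max, sigma_min, n_levels), 0.0)

    def variance(sigma):
        # Marginal variance of the noised Gaussian data at noise level sigma.
        return data_std**2 + sigma**2

    def score(x, sigma):
        # Exact score of N(0, variance(sigma)).
        return -x / variance(sigma)

    def predictor(x, sigma, next_sigma):
        # Reverse-diffusion step from noise level sigma down to next_sigma.
        delta = sigma**2 - next_sigma**2
        noise = rng.standard_normal(x.shape)
        return x + delta * score(x, sigma) + np.sqrt(delta) * noise

    def corrector(x, sigma):
        # Langevin steps with an assumed step size of 2 * snr**2 * variance.
        step = 2.0 * snr**2 * variance(sigma)
        for _ in range(corrector_steps):
            noise = rng.standard_normal(x.shape)
            x = x + step * score(x, sigma) + np.sqrt(2.0 * step) * noise
        return x

    # Prior sample at the largest noise level.
    x = rng.standard_normal(n_samples) * np.sqrt(variance(sigmas[0]))
    for i in range(n_levels):
        if order == "corrector_first":
            x = corrector(x, sigmas[i])                 # quoted loop: C then P
            x = predictor(x, sigmas[i], sigmas[i + 1])
        else:
            x = predictor(x, sigmas[i], sigmas[i + 1])  # Algorithm 1: P then C
            x = corrector(x, sigmas[i + 1])
    return x


for order in ("corrector_first", "predictor_first"):
    samples = pc_sample_gaussian(order)
    print(order, round(float(samples.mean()), 3), round(float(samples.var()), 3))
```

With an exact score, both orderings should return a sample variance within a few percent of `data_std**2` in this setting. The toy says nothing about how the orderings compare with a learned score on real data.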