score_sde_pytorch
Bug report on ReverseDiffusionPredictor
Hello,
In ./sampling.py, line 190, you have a class ReverseDiffusionPredictor:
```python
@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)

  def update_fn(self, x, t):
    f, G = self.rsde.discretize(x, t)
    z = torch.randn_like(x)
    x_mean = x - f
    x = x_mean + G[:, None, None, None] * z
    return x, x_mean
```
The update_fn method looks incorrect to me: the score term seems to be missing. Shouldn't it be something like x = x_mean + G^2 * score + G * z?
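Written out, the reverse-diffusion step the question has in mind looks like the following. The notation is mine, sketched from the score SDE paper's reverse diffusion sampler: here \(f_{i+1}\) and \(G_{i+1}\) are the *forward* SDE's discretized drift and diffusion, \(s_\theta\) is the score model, and with one scalar \(G\) per sample, \(G G^\top = G^2\).

$$
x_i = x_{i+1} - f_{i+1}(x_{i+1}) + G_{i+1} G_{i+1}^\top \, s_\theta(x_{i+1}, i+1) + G_{i+1} z_{i+1}, \qquad z_{i+1} \sim \mathcal{N}(0, I)
$$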
I am looking forward to hearing from you.
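The score is already part of the drift f. self.rsde.discretize returns the reverse-time drift, which is the forward drift minus G^2 * score, so x_mean = x - f already contains the + G^2 * score term.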
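A minimal scalar sketch of why the two forms agree, assuming a toy forward discretization (f = -0.5 * beta * x, G = sqrt(beta), with a hypothetical beta = 0.1) and a hypothetical score function s(x) = -x. The function reverse_discretize mirrors how the reverse SDE folds the score into its drift (rev_f = f - G^2 * score); all names here are illustrative rather than the repository's API, and plain floats stand in for tensors.

```python
# Hypothetical toy forward discretization (VP-like); not the repo's actual schedule.
BETA = 0.1


def forward_discretize(x, t):
    """Return the forward SDE's discretized drift f and diffusion G."""
    f = -0.5 * BETA * x
    G = BETA ** 0.5
    return f, G


def score_fn(x, t):
    """Hypothetical score: the score of a standard normal, -x."""
    return -x


def reverse_discretize(x, t, probability_flow=False):
    """Reverse-time drift and diffusion; the score lives inside rev_f."""
    f, G = forward_discretize(x, t)
    # rev_f = f - G^2 * score (halved for the probability flow ODE).
    rev_f = f - G ** 2 * score_fn(x, t) * (0.5 if probability_flow else 1.0)
    rev_G = 0.0 if probability_flow else G
    return rev_f, rev_G


def update_fn(x, t, z):
    """Same structure as ReverseDiffusionPredictor.update_fn, with noise z passed in."""
    f, G = reverse_discretize(x, t)
    x_mean = x - f
    return x_mean + G * z, x_mean


def explicit_update(x, t, z):
    """The same step with the score term written out: x - f + G^2 * s + G * z."""
    f, G = forward_discretize(x, t)
    return x - f + G ** 2 * score_fn(x, t) + G * z


if __name__ == "__main__":
    x, t, z = 1.5, 0.5, 0.3
    new_x, _ = update_fn(x, t, z)
    print(new_x, explicit_update(x, t, z))  # equal up to floating-point rounding
```

Because x - rev_f expands to x - f + G^2 * score, the predictor never needs a separate score term; it would only be missing if discretize returned the forward drift.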