score_sde_pytorch

Bug report on ReverseDiffusionPredictor

Open: dongli96 opened this issue 1 year ago • 1 comment

Hello,

In ./sampling.py, line 190, you have a class ReverseDiffusionPredictor:

@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
  def __init__(self, sde, score_fn, probability_flow=False):
    super().__init__(sde, score_fn, probability_flow)

  def update_fn(self, x, t):
    f, G = self.rsde.discretize(x, t)
    z = torch.randn_like(x)
    x_mean = x - f
    x = x_mean + G[:, None, None, None] * z
    return x, x_mean

The update_fn method looks incorrect to me: the score term seems to be missing. Should it be something like x = x_mean + G^2*score + G*z?

I am looking forward to hearing from you.

dongli96, Oct 29 '24 17:10

The score is already part of the drift f returned by self.rsde.discretize(x, t), so x_mean = x - f includes the G^2*score term.
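
To make this concrete, here is a minimal scalar sketch, not the repository's code. It assumes a VP-SDE-style forward discretization and a standard-normal score; the helper names forward_discretize, reverse_discretize, and explicit_update are hypothetical and exist only for illustration. It shows that subtracting the reverse drift gives the same result as adding G^2*score explicitly.

```python
import math

def forward_discretize(x, beta):
    # Assumed toy forward rule: x_next = sqrt(1 - beta) * x + sqrt(beta) * z
    f = math.sqrt(1.0 - beta) * x - x  # forward drift increment
    G = math.sqrt(beta)                # diffusion coefficient
    return f, G

def reverse_discretize(x, beta, score_fn):
    # Reverse-time rule: the score term is folded into the drift here.
    f, G = forward_discretize(x, beta)
    rev_f = f - G ** 2 * score_fn(x)
    return rev_f, G

def update_fn(x, beta, score_fn, z):
    # Same structure as the update_fn quoted in the issue:
    # the score appears only implicitly, through the reverse drift.
    f, G = reverse_discretize(x, beta, score_fn)
    x_mean = x - f
    return x_mean + G * z, x_mean

def explicit_update(x, beta, score_fn, z):
    # The form proposed in the issue, written with the forward drift.
    f, G = forward_discretize(x, beta)
    x_mean = x - f + G ** 2 * score_fn(x)
    return x_mean + G * z, x_mean

if __name__ == "__main__":
    score = lambda v: -v  # score of a standard normal, used as a stand-in
    print(update_fn(0.7, 0.02, score, 0.3))
    print(explicit_update(0.7, 0.02, score, 0.3))
```

Both calls give the same values up to floating-point rounding, so adding G^2*score again inside update_fn would count the score term twice.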

sascha-holl, May 01 '25 20:05