Incorrect objective in Jax Denoising Score Matching

Open • aniquetahir opened this issue 1 year ago • 0 comments

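import jax

def denoising_score_matching(scorenet, samples, key, sigma=0.01):
    # Perturb the data with isotropic Gaussian noise of standard deviation sigma
    noise = jax.random.normal(key, samples.shape)
    perturbed_samples = samples + noise * sigma
    # Regression target for the network output
    target = -noise / sigma
    scores = scorenet(perturbed_samples)
    # Squared error summed over feature dimensions, averaged over the batch
    loss = 1 / 2. * ((scores - target) ** 2).sum(axis=-1).mean(axis=0)
    return loss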

denoising_score_matching(model, data[:10], jax.random.PRNGKey(0))

Here, target = -noise / sigma should be replaced with target = -noise.
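Which target is appropriate depends on what the network output is taken to represent, so it may help to check the two choices numerically. The sketch below is not from the repository: the toy data, the random stand-in for the scorenet output, the helper names log_gaussian_kernel and dsm_loss, and sigma = 0.1 are all assumptions made for illustration. It uses autodiff to compute the score of the Gaussian perturbation kernel and compares it with -noise / sigma, then compares the loss with target -noise / sigma against the loss with target -noise applied to sigma-scaled outputs.

```python
import jax
import jax.numpy as jnp

def log_gaussian_kernel(x_tilde, x, sigma):
    # log q(x_tilde | x) for N(x, sigma^2 I), up to an additive constant
    return -jnp.sum((x_tilde - x) ** 2) / (2.0 * sigma ** 2)

def dsm_loss(scores, target):
    # Same reduction as the snippet in this issue:
    # sum over features, mean over the batch
    return 0.5 * ((scores - target) ** 2).sum(axis=-1).mean(axis=0)

sigma = 0.1  # toy value chosen for illustration
key_x, key_n, key_s = jax.random.split(jax.random.PRNGKey(0), 3)
samples = jax.random.normal(key_x, (10, 2))      # hypothetical toy data
noise = jax.random.normal(key_n, samples.shape)
perturbed = samples + noise * sigma

# Score of the perturbation kernel w.r.t. x_tilde, one gradient per sample
kernel_score = jax.vmap(jax.grad(log_gaussian_kernel), in_axes=(0, 0, None))(
    perturbed, samples, sigma
)

# Random stand-in for scorenet(perturbed); not a trained model
scores = jax.random.normal(key_s, samples.shape)

# Target -noise / sigma on the raw output
loss_score_target = dsm_loss(scores, -noise / sigma)
# Target -noise on the sigma-scaled output
loss_noise_target = dsm_loss(sigma * scores, -noise)

print(jnp.max(jnp.abs(kernel_score - (-noise / sigma))))
print(loss_noise_target, sigma ** 2 * loss_score_target)
```

Under these assumptions, the kernel score computed by autodiff matches -noise / sigma, and the loss with target -noise on sigma * scores equals sigma ** 2 times the loss with target -noise / sigma on scores. In other words, target = -noise fits a network whose output is read as the sigma-scaled score (equivalently, a sigma ** 2 weighted loss), while target = -noise / sigma fits a network whose output is read as the score itself.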

aniquetahir, Sep 03 '22 17:09