Incorrect objective in JAX Denoising Score Matching
def denoising_score_matching(scorenet, samples, key, sigma=0.01):
    noise = jax.random.normal(key, samples.shape)
    perturbed_samples = samples + noise * sigma
    target = -noise / sigma
    scores = scorenet(perturbed_samples)
    loss = 1 / 2. * ((scores - target) ** 2).sum(axis=-1).mean(axis=0)
    return loss

denoising_score_matching(model, data[:10], jax.random.PRNGKey(0))
Here, target = -noise / sigma should be replaced with target = -noise.
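To make the two candidate targets concrete, here is a minimal sketch, assuming the Gaussian perturbation perturbed = samples + sigma * noise used in the function above. It uses jax.grad to evaluate the gradient of the log-density of the perturbation kernel N(perturbed; samples, sigma^2 I) with respect to the perturbed sample, so that either target can be compared against it. The names log_kernel and kernel_score are hypothetical helpers, not part of the repository, and sigma = 0.1 is chosen only to keep float32 rounding small.

```python
import jax
import jax.numpy as jnp


def log_kernel(x_tilde, x, sigma):
    # Log-density of N(x_tilde; x, sigma^2 I), up to an additive constant
    # that does not depend on x_tilde (hypothetical helper for illustration).
    return -jnp.sum((x_tilde - x) ** 2) / (2.0 * sigma ** 2)


# Gradient of the log-density with respect to the perturbed sample.
kernel_score = jax.grad(log_kernel, argnums=0)

key = jax.random.PRNGKey(0)
sigma = 0.1  # assumption: larger than the default 0.01 to limit rounding error
x = jnp.array([0.5, -1.0, 2.0])
noise = jax.random.normal(key, x.shape)
x_tilde = x + sigma * noise

score = kernel_score(x_tilde, x, sigma)
target_a = -noise / sigma  # target currently used in the function
target_b = -noise          # target proposed in this issue

# Largest absolute deviation of each candidate from the kernel gradient.
print(jnp.max(jnp.abs(score - target_a)))
print(jnp.max(jnp.abs(sigma * score - target_b)))
```

Under these assumptions the kernel gradient agrees with -noise / sigma, and -noise equals sigma times that gradient. Which target is appropriate therefore depends on whether scorenet is meant to output the score itself or a sigma-scaled version of it.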