
GenCast: intermediate denoiser output loss increases during sampling


I trained a 128x64 model (with sparse=False). If I record the ~39 intermediate denoiser outputs over the course of a single sampling loop and compute their MSE against the target (unweighted by lambda_sigma, since that weighting doesn't make sense at inference time), the loss trajectory looks something like this:

[figure: MSE of the intermediate denoiser outputs plotted against sampling step; the loss increases over the course of sampling]
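For reference, here is roughly how I'm recording these losses. This is a minimal sketch assuming a Karras/EDM-style Euler sampling loop, not the exact graph_weather sampler; `denoiser` and `sigma_schedule` are stand-in names for the trained model D(x, σ) and the ~39-step noise schedule:

```python
import torch

@torch.no_grad()
def sample_with_diagnostics(denoiser, sigma_schedule, x, target):
    """Euler-style sampling loop that records the unweighted MSE between
    each intermediate denoiser output and the target field."""
    losses = []
    for i, sigma in enumerate(sigma_schedule[:-1]):
        # D(x_sigma, sigma) is (approximately) the posterior mean E[x0 | x_sigma].
        denoised = denoiser(x, sigma)
        losses.append(torch.mean((denoised - target) ** 2).item())
        # Euler step from sigma down to the next (smaller) noise level.
        d = (x - denoised) / sigma
        x = x + d * (sigma_schedule[i + 1] - sigma)
    return x, losses
```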

If I visualize the denoiser output at the very first inference step, the denoiser output at the very last inference step, and the final sampling result, then the step-0 denoiser output indeed looks much closer to the target:

[figure: denoiser outputs at the first and last inference steps alongside the final sample and the target]

I would expect the mse loss of the intermediate denoiser outputs to decrease during sampling.
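My expectation comes from reading the denoiser as a posterior-mean estimator in the EDM framing (my interpretation of the setup, so take the notation loosely):

$$D_\theta(\mathbf{x}_\sigma, \sigma) \approx \mathbb{E}\!\left[\mathbf{x}_0 \mid \mathbf{x}_\sigma\right],$$

and as $\sigma \to 0$ the conditional distribution concentrates on $\mathbf{x}_0$, so the MSE of the denoiser output against the target should shrink over the sampling loop rather than grow.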

charlescharles · Aug 29 '24 16:08