[Feat]: Same timestep used for all validation steps
What happened?
In this PR (https://github.com/Nerogar/OneTrainer/pull/660), the validation timestep was fixed to 499. This was to make sure every validation uses the same timesteps, so that validations are comparable to each other, and it made validation loss meaningful.
I don't think this is the right solution though. The timesteps should be the same for each validation, but not for each validation step within the same validation.
Validating only at 499 might be somewhat meaningful, but especially in models with timestep shifting, 499 is not a particularly important timestep.
What did you expect would happen?
Validation timesteps should be deterministic, but only between validations, not between validation steps. Within the same validation, timesteps should be randomly distributed.
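A minimal sketch of the requested behaviour, assuming timesteps are integers in [0, 1000) and that the validation loop can draw one timestep per sample from its own generator. The function name `validation_timesteps` and the default seed are hypothetical, not OneTrainer's API. It also assumes the validation samples arrive in a fixed order, which (as noted later in this thread) requires the dataloader to stop shuffling.

```python
import random

NUM_TIMESTEPS = 1000  # assumed length of the training schedule


def validation_timesteps(num_samples: int, seed: int = 0) -> list[int]:
    """Return one timestep per validation sample.

    A fresh generator is created from the same seed on every call, so every
    validation run gets the same sequence, while timesteps within a run vary.
    """
    rng = random.Random(seed)
    return [rng.randrange(NUM_TIMESTEPS) for _ in range(num_samples)]


print(validation_timesteps(8))  # same list on every validation, different values per sample
```

Re-seeding at the start of each validation, rather than fixing a single timestep, is what keeps runs comparable without collapsing every step onto one point of the schedule.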
Sorry I missed that, this was already noted in the original PR:
> This PR is just a very basic version, using the existing deterministic timestep option to lock timestep to 500, and overriding the seed to 0, but a better implementation would just seed once at the start of validation, and use different timesteps and noise for each image sampled, to be more robust. Unfortunately I couldn't find a way to stop the validation dataloader from shuffling, but this version is already more useful than the previous fully random validation.
So I'll change this to an enhancement.
I agree that it should be multiple deterministic timesteps instead of a single fixed timestep, but the issue was that the dataloader didn't have an option to disable shuffling, which would make the validation nondeterministic. To add this functionality to OneTrainer, that option would first need to be added to https://github.com/Nerogar/mgds.
Adding it to the dataloader is pretty simple. But then you still need to find some way to select useful timesteps. And I think that can be pretty complicated.
- Spreading them out uniformly (100, 200, ... if you have 10 samples for example) won't be useful if you sample the timesteps non uniformly during inference. And all the steps will change if you add a single new image to the data set.
- "Uniform" sampling from the training distribution could solve the first issue, but not the second one
- Random (but deterministic) sampling could solve the second issue, but not the first one since it can add a lot of bias for small validation sets
- Some algorithm that samples from non uniformly distributed buckets might work better, but it still has some of the issues mentioned above.
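The first three options can be compared with a short sketch. Everything here is an assumption for illustration: integer timesteps in [0, 1000), the function names, and the use of the SD3-style shift `s*u / (1 + (s-1)*u)` to stand in for "the training distribution". Because that shift is monotonic, pushing evenly spaced quantiles through it is the inverse-CDF way of spreading samples "uniformly" over the shifted distribution.

```python
import random

NUM_TIMESTEPS = 1000  # assumed schedule length


def spread_uniformly(n: int) -> list[int]:
    # Midpoints of n equal-width bins: 50, 150, ..., 950 for n = 10.
    return [int((i + 0.5) * NUM_TIMESTEPS / n) for i in range(n)]


def shift_timestep(u: float, shift: float) -> float:
    # SD3-style shift (assumption); maps [0, 1] to [0, 1], shift = 1 is the identity.
    return shift * u / (1 + (shift - 1) * u)


def spread_over_training_distribution(n: int, shift: float) -> list[int]:
    # Evenly spaced quantiles mapped through the monotonic shift (inverse CDF).
    return [int(shift_timestep((i + 0.5) / n, shift) * NUM_TIMESTEPS) for i in range(n)]


def seeded_random(n: int, seed: int = 0) -> list[int]:
    # Deterministic random draws: adding a sample only appends one new timestep.
    rng = random.Random(seed)
    return [rng.randrange(NUM_TIMESTEPS) for _ in range(n)]


print(spread_uniformly(10))
print(spread_uniformly(11))  # every value moves when one image is added
print(seeded_random(10) == seeded_random(11)[:10])  # True: existing timesteps are kept
```

The two spread variants reproduce the second issue (every timestep changes when the set grows by one), while the seeded draws avoid it at the cost of possible bias for small validation sets.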
I think all of those options are valid, because the goal is just to measure relative change. Sampling timesteps mostly from the middle of the schedule will show the largest change, because if you look at loss at every timestep before and after training, the whole curve will shift, but with both ends mostly constrained, like this for SD1.5:
(Edit, in case it wasn't clear: that's not a training run. It's a post-hoc evaluation using a fixed sample of images and random noise, sweeping through every timestep.)
For RF models the shape of the curves will be different, but the same idea still holds: most of the learning happens in the middle timesteps, not at either end. That's why SD3 used logit-normal timestep sampling, although I don't think it was actually ideal for training. It could be a good option for validation though, since it samples timesteps with a bias toward the middle.
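A minimal sketch of deterministic logit-normal timestep sampling for validation, assuming integer timesteps in [0, 1000) and the usual definition (sigmoid of a normal sample). The function name and the default `mean=0.0`, `std=1.0` are illustrative choices, not values taken from SD3 or OneTrainer.

```python
import math
import random

NUM_TIMESTEPS = 1000  # assumed schedule length


def logit_normal_timesteps(n: int, mean: float = 0.0, std: float = 1.0, seed: int = 0) -> list[int]:
    """Deterministic logit-normal timesteps, concentrated around sigmoid(mean)."""
    rng = random.Random(seed)
    timesteps = []
    for _ in range(n):
        # The sigmoid of a normal sample lies in (0, 1) and clusters near the middle.
        t = 1.0 / (1.0 + math.exp(-rng.gauss(mean, std)))
        # Clamp in case the sigmoid rounds to exactly 1.0 for extreme samples.
        timesteps.append(min(int(t * NUM_TIMESTEPS), NUM_TIMESTEPS - 1))
    return timesteps


print(logit_normal_timesteps(10))
```

With `mean=0` and `std=1`, roughly 73% of samples fall in [250, 750), compared with 50% for uniform sampling.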
In theory you could get different movement at different timesteps, so it could be useful to log them separately instead of averaging; AFAIK that's what diffusion pipe does. In practice, though, they're very strongly correlated, so you're not missing much by just averaging them, or even by sampling a single fixed timestep.
Thanks for your replies! It would be interesting to see this correlation graph for a timestep shifted model, because there we train mostly higher timesteps but still validate at 500:
And even if you don't shift during training, shifting is still used during inference. The correlation will still be there, but I wonder how well it correlates with human preference, because whatever the model learns at timestep 500 is not very relevant to the overall generated image.
Regarding @Nerogar's options above, these two make the most sense to me:
- for validation with many steps: randomly sampled timesteps, drawn from the training timestep distribution
- for validation with very few steps, such as 1: you can get unlucky and sample a timestep at the very end of the distribution. For these cases it's probably best to keep a fixed timestep in the middle, but shifted to the middle of the distribution rather than a fixed 500.
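For the second option, a sketch of where "the middle of the distribution" lands, assuming training samples `u` uniformly and applies the SD3-style shift `s*u / (1 + (s-1)*u)`, with higher timesteps being noisier. Because the shift is monotonic, the median of the shifted distribution is simply the shifted median of the uniform one, `s / (1 + s)`. The helper names are hypothetical.

```python
NUM_TIMESTEPS = 1000  # assumed schedule length


def shift_timestep(u: float, shift: float) -> float:
    # SD3-style shift (assumption); shift = 1 is the identity.
    return shift * u / (1 + (shift - 1) * u)


def median_validation_timestep(shift: float) -> int:
    # The median of shifted-uniform sampling is shift(0.5) = shift / (1 + shift).
    return round(shift_timestep(0.5, shift) * NUM_TIMESTEPS)


print(median_validation_timestep(1.0))  # 500
print(median_validation_timestep(3.0))  # 750
```

So with a shift of 3, the single fixed validation timestep would move from 500 to 750 under these assumptions.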
> It would be interesting to see this correlation graph for a timestep shifted model, because there we train mostly higher timesteps but still validate at 500:
The loss/timestep correlation curve will have a different shape for RF models, but it's because of the noise schedule (lerp) and objective (flow velocity), not shift. Shift shouldn't change the signal to noise ratio at a given timestep, it only changes what timesteps are selected.
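Written out, assuming the rectified-flow convention where $t = 1$ is pure noise (as in SD3) and unit-variance data and noise: the noisy latent and its signal-to-noise ratio depend only on $t$, while the shift $s$ only reparameterizes how $t$ is drawn from a uniform $u$.

```latex
x_t = (1 - t)\,x_0 + t\,\epsilon, \qquad
\mathrm{SNR}(t) = \frac{(1 - t)^2}{t^2}, \qquad
t = \frac{s\,u}{1 + (s - 1)\,u}, \quad u \sim \mathcal{U}(0, 1)
```

Changing $s$ moves probability mass along $t$ but leaves $\mathrm{SNR}(t)$ at any fixed $t$ unchanged; for example, $\mathrm{SNR}(0.5) = 1$ for every $s$.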
I agree with your reasoning about which options would be useful, although it might need to be a manual selection, either sample timesteps from the training distribution or input a manual timestep.
> The loss/timestep correlation curve will have a different shape for RF models, but it's because of the noise schedule (lerp) and objective (flow velocity), not shift. Shift shouldn't change the signal to noise ratio at a given timestep, it only changes what timesteps are selected.
Shift changes which timesteps matter most for the resulting image. If most of my image is generated in the [800:1000] range, validation at 500 might not be very meaningful.
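A rough way to see this, assuming a 20-step inference schedule that starts evenly spaced from t = 1 (pure noise) toward 0 and is then passed through the SD3-style shift. The step count, helper names, and threshold are illustrative.

```python
NUM_STEPS = 20  # assumed number of inference steps


def shift_timestep(t: float, shift: float) -> float:
    # SD3-style shift (assumption); shift = 1 leaves the schedule unchanged.
    return shift * t / (1 + (shift - 1) * t)


def inference_timesteps(num_steps: int, shift: float) -> list[float]:
    # Evenly spaced t from 1 (pure noise) down toward 0, then shifted.
    return [shift_timestep(1 - i / num_steps, shift) for i in range(num_steps)]


def steps_at_or_above(timesteps: list[float], threshold: float) -> int:
    return sum(t >= threshold for t in timesteps)


for shift in (1.0, 3.0):
    schedule = inference_timesteps(NUM_STEPS, shift)
    print(shift, steps_at_or_above(schedule, 0.8))
```

With a shift of 3, the number of steps spent at t >= 0.8 (the [800:1000] range) roughly doubles compared with the unshifted schedule.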
In one of my recent training runs, the samples seemed to keep improving while validation loss was already getting worse. I suspected this issue, so I validated not only at timestep 500 but at list(range(0, 1000, 50)) + [999], all on the same validation sample:
On the left, validation at timestep 500; on the right, the average of the validations mentioned above.
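A sketch of that kind of per-timestep validation, assuming a hypothetical hook `loss_fn(timestep, rng)` that noises the validation sample using `rng` and returns the model's loss. Seeding the generator from the timestep keeps the noise identical from one validation to the next. This is an illustration, not the code behind the plots above.

```python
import random
from typing import Callable, Sequence

# The timesteps from the experiment above: list(range(0, 1000, 50)) + [999].
VALIDATION_TIMESTEPS = tuple(range(0, 1000, 50)) + (999,)


def validate_per_timestep(
    loss_fn: Callable[[int, random.Random], float],
    timesteps: Sequence[int] = VALIDATION_TIMESTEPS,
    base_seed: int = 0,
) -> dict[int, float]:
    """Evaluate one validation sample at every timestep, with fixed noise per timestep."""
    return {t: loss_fn(t, random.Random(base_seed + t)) for t in timesteps}
```

Returning a dict keyed by timestep makes it possible to log each timestep as its own curve as well as any average over them.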
So you could say: it's clear, validation at 500 is not enough, we need to sample at multiple timesteps as discussed above. But not so fast:
Why are many of these only slightly improving or flat? Compare the scale of timestep 50 and timestep 999:
If we took the average of validation losses at different timesteps as discussed above, we would strongly overweight the higher timesteps.
You'd have to take the average of their relative change or take timestep SNR into account somehow.
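A sketch of the first idea: average the relative change per timestep against a baseline (for example the first validation), so every timestep counts equally regardless of its absolute loss scale. The helper names and the numbers in the usage lines are made up purely for illustration.

```python
def mean_loss(losses: dict[int, float]) -> float:
    return sum(losses.values()) / len(losses)


def mean_relative_change(baseline: dict[int, float], current: dict[int, float]) -> float:
    """Average per-timestep relative change; dividing by the baseline removes the scale."""
    changes = [(current[t] - baseline[t]) / baseline[t] for t in baseline]
    return sum(changes) / len(changes)


# Illustrative numbers only: the low timestep halves its loss, the high one barely moves.
baseline = {50: 0.02, 999: 0.40}
current = {50: 0.01, 999: 0.41}
print(mean_loss(baseline), mean_loss(current))  # both about 0.21
print(mean_relative_change(baseline, current))  # about -0.24
```

In this example the raw average hides a large improvement at the low timestep behind a small regression at the high one, while the relative average shows it.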
@dxqbYD If you want to get into picking individual timesteps or a weighted average, you would also want to consider the content you're training on. Style is mostly high frequency, so you would probably expect to see more movement on the high SNR timesteps, while structure is lower frequency, so when learning that more movement should happen on the noisy/low SNR timesteps. It's always going to be learning some combination of both though, so depending on what's easier or harder it might converge faster on some SNRs than others.
In my own training scripts I sample validation timesteps using the same distribution as training timesteps, using uniform sampling with shift. If you trust that whatever distribution you're using for training is good for placing the correct weight per timestep to converge at an optimal speed, then it's probably safe to assume that using the same timestep distribution for validation is also good.
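A minimal sketch of that approach, assuming training draws `u` uniformly and applies the SD3-style shift, with integer timesteps in [0, 1000). The function name and seed are hypothetical. Re-seeding on every call keeps validations comparable.

```python
import random

NUM_TIMESTEPS = 1000  # assumed schedule length


def shifted_uniform_timesteps(n: int, shift: float, seed: int = 0) -> list[int]:
    """Draw validation timesteps from the same shifted-uniform distribution as training."""
    rng = random.Random(seed)
    timesteps = []
    for _ in range(n):
        u = rng.random()  # uniform in [0, 1)
        t = shift * u / (1 + (shift - 1) * u)  # SD3-style shift (assumption)
        timesteps.append(min(int(t * NUM_TIMESTEPS), NUM_TIMESTEPS - 1))
    return timesteps


print(shifted_uniform_timesteps(10, shift=3.0))
```

With a shift of 3, about 75% of the draws land at or above 500, since every `u >= 0.25` maps to `t >= 0.5`.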
I think you should also pay attention to dataset and validation split size though. All your validation curves are very noisy, which tells me you're probably working with a small dataset. In that case some amount of overfitting can improve sample image quality, like you observed. Further training past the "optimal" point is trading off flexibility/variety for quality up to a point. You don't hit the true memorization/collapse phase until much later.
It's different for larger datasets though. If you have enough data (not just number of images, but also the amount of information/variety within them), it's hard to overfit at all when training at a small scale. Instead you'll just see both training and validation loss decreasing exponentially towards some unreachable threshold. Here's an example with Wan 1.3b and a 50k video dataset:
You are completely right. It was a small test dataset that is easy to overfit.
I tend to think there is no good way to summarize the validation loss over different timesteps in a single metric. Unfortunately, TensorBoard is not very good at displaying multiple values in one graph.
@spacepxl what do you think of this? https://github.com/Nerogar/OneTrainer/pull/821
I think choosing multiple timesteps is fine, and reporting them separately is fine, but you should absolutely still report the average as well. For larger datasets I still think it makes more sense to sample timesteps from the training distribution and average them, you'll get much cleaner validation results that way.
Looking at the example validation curves on the PR though, are you sure you didn't break determinism? That's a lot of noise and no visible drop.
> I think choosing multiple timesteps is fine, and reporting them separately is fine, but you should absolutely still report the average as well. For larger datasets I still think it makes more sense to sample timesteps from the training distribution and average them, you'll get much cleaner validation results that way.
The total average is still reported, just not in the screenshots. What's currently not reported is the per-concept average over all timesteps. I don't think that's very useful (see earlier in this discussion), but it can be added if others want it.
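A sketch of how the three kinds of scalars discussed here (per timestep, per concept over all timesteps, and the overall average) could be computed from one validation pass. It assumes the losses are collected as `(concept, timestep, loss)` records; the record format and the tag names are hypothetical, not OneTrainer's logging keys.

```python
from collections import defaultdict


def summarize_validation(records: list[tuple[str, int, float]]) -> dict[str, float]:
    """Aggregate (concept, timestep, loss) records into scalars for logging."""
    by_timestep: dict[int, list[float]] = defaultdict(list)
    by_concept: dict[str, list[float]] = defaultdict(list)
    for concept, timestep, loss in records:
        by_timestep[timestep].append(loss)
        by_concept[concept].append(loss)

    summary = {f"loss/timestep_{t}": sum(v) / len(v) for t, v in sorted(by_timestep.items())}
    summary.update({f"loss/concept_{c}": sum(v) / len(v) for c, v in sorted(by_concept.items())})
    # Overall mean over every record, i.e. all concepts and all timesteps.
    summary["loss/average"] = sum(loss for _, _, loss in records) / len(records)
    return summary
```

Each entry can then be logged as its own scalar, so the per-timestep curves and the overall average appear side by side.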
> Looking at the example validation curves on the PR though, are you sure you didn't break determinism? That's a lot of noise and no visible drop.
Yes. This was a hundred steps to make a screenshot, not any coherent training session.