Better training loop for diffusion models and flow matching
🚀 Feature Request
Diffusion models, in essence, leverage "data augmentation" (each sample is seen at many noise levels), which allows them to be trained for longer on smaller datasets without overfitting. They could further benefit from better default settings, such as a cosine one-cycle learning rate schedule, an exponential moving average (EMA) of the weights, gradient accumulation, etc.
However, this extra randomness also makes the loss more stochastic, which makes it harder to detect convergence automatically and efficiently. In the default loop we should target good performance on the default solve schedule, hence related to #1437 .
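For concreteness, here is a minimal sketch of what such defaults could look like in a plain PyTorch loop: a cosine one-cycle schedule, gradient accumulation, and an EMA of the weights kept for sampling. `net`, `loader`, and `loss_fn` are placeholders for a denoiser, data loader, and score-/flow-matching loss; nothing here reflects the library's existing API.

```python
import copy

import torch


def train(net, loader, loss_fn, n_epochs=100, max_lr=1e-3, accumulate=4, ema_decay=0.999):
    opt = torch.optim.AdamW(net.parameters(), lr=max_lr)
    # Cosine one-cycle schedule over the whole run, stepped once per optimizer update.
    total_updates = max(1, (n_epochs * len(loader)) // accumulate)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=max_lr, total_steps=total_updates, anneal_strategy="cos"
    )
    # EMA copy of the network, updated after every optimizer step and used for sampling.
    ema = copy.deepcopy(net).eval()
    for p in ema.parameters():
        p.requires_grad_(False)

    step = 0
    for _ in range(n_epochs):
        for batch in loader:
            # Scale the loss so the accumulated gradient matches a larger effective batch.
            loss = loss_fn(net, batch) / accumulate
            loss.backward()
            step += 1
            if step % accumulate == 0:
                opt.step()
                opt.zero_grad()
                if sched.last_epoch < total_updates:
                    sched.step()
                # EMA update: ema <- ema_decay * ema + (1 - ema_decay) * net.
                with torch.no_grad():
                    for p_ema, p in zip(ema.parameters(), net.parameters()):
                        p_ema.lerp_(p, 1.0 - ema_decay)
    return ema
```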
Describe the solution you'd like
This request is somewhat open-ended, but the primary goals would be:
- Achieving strong performance with the default settings.
- Implementing convergence checks that stop training within a reasonable time frame, with minimal impact on performance and without overfitting (a rough sketch of such a check follows below).
This task is exploratory by nature. Feel free to propose or try out different solutions.
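Regarding the convergence check, one option (purely a sketch, with illustrative names and thresholds) is to smooth the noisy validation loss before applying a standard patience-based early-stopping rule:

```python
class SmoothedEarlyStopping:
    """Early stopping on an exponentially smoothed validation loss."""

    def __init__(self, patience=20, smoothing=0.9, min_delta=1e-3):
        self.patience = patience
        self.smoothing = smoothing
        self.min_delta = min_delta
        self.smoothed = None        # running EMA of the validation loss
        self.best = float("inf")    # best smoothed loss seen so far
        self.stale = 0              # epochs since the smoothed loss last improved

    def step(self, val_loss: float) -> bool:
        """Record a new validation loss; return True if training should stop."""
        if self.smoothed is None:
            self.smoothed = val_loss
        else:
            self.smoothed = self.smoothing * self.smoothed + (1 - self.smoothing) * val_loss
        if self.smoothed < self.best - self.min_delta:
            self.best = self.smoothed
            self.stale = 0
        else:
            self.stale += 1
        return self.stale >= self.patience
```

Called once per epoch with the current validation loss, this avoids stopping on a single lucky or unlucky batch of noise levels.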
I have assigned myself here for now; happy for anyone else to jump in at any point.
I am working with people from the DeepInverse library who have similar needs for training diffusion models. They have extensive experience in training state-of-the-art denoisers for imaging and MRI problems, and they are currently improving their trainers for score matching and flow matching.
I think they might have some interesting input, and perhaps some ideas can be shared between the two libraries.
cc @matthieutrs, @Andrewwango
It would also be worth checking out https://github.com/probabilists/azula
Thanks @tomMoral @janfb! DeepInverse and azula (cc @francois-rozet) are both super relevant! Would be grateful for input from the devs.