ddpo-pytorch
Reproduction of DDPO paper (RLHF for diffusion)
RLHF for Diffusion Models
This is an implementation of Training Diffusion Models with Reinforcement Learning (DDPO). It is meant as an educational codebase, with lots of comments explaining the code and only basic features. It currently implements only the LAION aesthetic classifier as a reward function, but more examples will be added soon.
Tutorial blog post coming soon
This codebase is for educational purposes only; another codebase for scalable training is being developed here.
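For readers new to the method, here is a rough sketch of the kind of objective DDPO optimizes: each denoising step is treated as an action, the reward (for example, an aesthetic score) is assigned to the whole sample, and the model is updated with a PPO-style clipped importance-sampling loss. This is not code from this repository. The function names, the batch-wide reward normalization, and the `clip_range` default are assumptions made for illustration, and plain Python lists stand in for tensors.

```python
import math
from statistics import mean, pstdev


def normalize_rewards(rewards, eps=1e-8):
    """Convert raw rewards into zero-mean, unit-variance advantages.

    Assumption: statistics are taken over the whole batch for simplicity.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_pg_loss(new_log_probs, old_log_probs, advantages, clip_range=0.2):
    """PPO-style clipped surrogate loss, averaged over samples and timesteps.

    new_log_probs / old_log_probs: one list per sample, holding the
    log-probability of each denoising step under the current model and
    under the model that generated the sample.
    advantages: one scalar per sample, shared by all of its timesteps.
    """
    losses = []
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        for n, o in zip(new_lp, old_lp):
            ratio = math.exp(n - o)  # importance weight for this step
            clipped = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
            # Keep the more pessimistic (larger) of the two negated terms.
            losses.append(max(-adv * ratio, -adv * clipped))
    return mean(losses)


# Hypothetical batch: three samples, two denoising steps each.
rewards = [5.0, 6.0, 7.0]
old_lp = [[-1.0, -1.2], [-0.9, -1.1], [-1.3, -0.8]]
new_lp = [[-1.0, -1.1], [-0.9, -1.1], [-1.1, -0.8]]
loss = clipped_pg_loss(new_lp, old_lp, normalize_rewards(rewards))
print(loss)
```

In a real training loop the log-probabilities would come from the diffusion model's per-step Gaussian transition, and the loss would be backpropagated through the current model only.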
Installation
git clone https://github.com/tmabraham/ddpo-pytorch.git
cd ddpo-pytorch
pip install -r requirements.txt
Usage
It's as simple as running:
python main.py
To save memory (you'll likely need it), use the arguments --enable_attention_slicing, --enable_xformers_memory_efficient_attention, and --enable_grad_checkpointing.
Results
Original samples:
After training for 50 epochs: