stylegan2-flax-tpu
🖼 Training StyleGAN2 on TPUs in JAX
StyleGAN2 Flax TPU
This implementation is adapted from the stylegan2 codebase by Matthias Wright.
Specifically, the features we've added allow for better scaling of StyleGAN2 training on TPUs:
- 🏭 Enable data-parallel training on TPU pods (tested on TPU v2 to v4 generations)
- 💾 Google Cloud Storage (GCS) integration/dataset sharding between workers
- 🏖 Quality-of-life improvements (e.g. improved W&B logging)
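In data-parallel training, every TPU core holds a full copy of the model and processes a different slice of the batch. The per-core gradients are then averaged before the shared parameter update. The plain-Python sketch below illustrates only that averaging step. It uses a toy one-parameter model with loss `mean((w*x - y)**2)` and is not code from this repository. In JAX, this averaging is usually expressed with `jax.lax.pmean` inside a `jax.pmap`-ed training step.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one shard."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def data_parallel_grad(w, shards):
    """Each simulated 'device' computes a gradient on its own shard, then the
    results are averaged (the all-reduce mean that jax.lax.pmean performs)."""
    per_device = [grad_mse(w, xs, ys) for xs, ys in shards]
    return sum(per_device) / len(per_device)


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
    ys = [2.0 * x + 1.0 for x in xs]
    w = 0.5
    # Split the global batch of 8 into 4 equal shards of 2 examples each.
    shards = [(xs[i:i + 2], ys[i:i + 2]) for i in range(0, 8, 2)]
    full = grad_mse(w, xs, ys)
    parallel = data_parallel_grad(w, shards)
    print(full, parallel)  # prints: -85.5 -85.5
```

The two results match because the shards have equal size. With unequal shards, a plain mean of per-device gradients would weight some examples more than others.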
This food does not exist! 🍪🍰🍣🍹
🧑‍🔧 Install
- Clone the repository:
git clone https://github.com/nyx-ai/stylegan2-flax-tpu.git
- Go into the directory:
cd stylegan2-flax-tpu
- Install JAX according to your platform (CPU, GPU or TPU).
- Install requirements:
pip install -r requirements.txt
🖼 Generate Images
We released four 256x256 pretrained models: cookie, cheesecake, sushi and cocktail. Download them from the latest release.
python generate_images.py \
--checkpoint checkpoints/cookie-256.pkl \
--seeds 0 42 420 666 \
--truncation_psi 0.7 \
--out_path generated_images
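`--truncation_psi` controls the truncation trick from the StyleGAN papers. Each intermediate latent `w` is pulled toward the average latent `w_avg` via `w' = w_avg + psi * (w - w_avg)`. With `psi = 1` samples are unchanged, smaller values trade diversity for fidelity, and `psi = 0` always yields the average image. The sketch below shows only that formula together with a reproducible per-seed latent draw. The latent size of 8, the use of Python's `random` module, and the zero average latent are illustrative assumptions, not what `generate_images.py` does internally.

```python
import random


def latent_from_seed(seed, dim=8):
    """Draw a reproducible standard-normal vector for a given seed.
    (Illustrative: the real script uses its own RNG and latent size.)"""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]


def truncate(w, w_avg, psi):
    """Truncation trick: interpolate w toward the average latent w_avg.
    psi=1 keeps w unchanged, psi=0 collapses it to w_avg."""
    return [a + psi * (x - a) for x, a in zip(w, w_avg)]


if __name__ == "__main__":
    w_avg = [0.0] * 8  # stand-in for the generator's tracked average latent
    for seed in (0, 42, 420, 666):
        # Simplification: the random vector is treated as w directly,
        # skipping the mapping network that turns z into w.
        w = latent_from_seed(seed)
        w_trunc = truncate(w, w_avg, psi=0.7)
        print(seed, [round(v, 3) for v in w_trunc[:3]])
```

In the actual generator, the seed determines a latent `z`, the mapping network turns it into `w`, and truncation is applied to `w`. The sketch skips the mapping network for brevity.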
Check the Colab notebook for more examples.
⚙️ Train Custom Models
Add your images to a folder /path/to/image_dir:
/path/to/image_dir/
0.jpg
1.jpg
2.jpg
4.jpg
...
and create a TFRecord dataset:
python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
For more detailed instructions, please refer to this README.
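On a multi-host TPU pod, each worker (host) should read a disjoint part of the dataset. A common way to achieve this with TFRecords is to split the list of record files by worker index, following the same element-selection rule as `tf.data.Dataset.shard(num_shards, index)`. The plain-Python sketch below shows that round-robin assignment. The `gs://my-bucket/...` file names are hypothetical, and this is not necessarily how `main.py` assigns shards to workers.

```python
def shard_files(files, num_workers, worker_index):
    """Round-robin assignment: worker i gets files i, i + n, i + 2n, ...
    Sorting first gives every worker the same ordering of the file list."""
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must be in [0, num_workers)")
    return sorted(files)[worker_index::num_workers]


if __name__ == "__main__":
    # Hypothetical TFRecord paths; on a real pod, num_workers and worker_index
    # would typically come from jax.process_count() and jax.process_index().
    files = [f"gs://my-bucket/tfrecords/part-{i:03d}.tfrecord" for i in range(10)]
    for worker in range(4):
        print(worker, shard_files(files, num_workers=4, worker_index=worker))
```

Sorting matters because object listings are not guaranteed to come back in the same order on every host. Without a shared ordering, two workers could pick the same file and others could be skipped.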
The following command trains a model at 128x128 resolution with a batch size of 8:
python main.py --data_dir /path/to/tfrecord
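With data-parallel training, the batch is divided across TPU cores, so the batch size needs to be divisible by the number of local devices. This section does not state whether `main.py` treats the batch size of 8 as a global or a per-device value. The sketch below assumes a global batch and shows the layout that `jax.pmap` expects: a leading axis equal to the device count, e.g. `jax.local_device_count()`. Nested Python lists stand in for arrays.

```python
def split_batch(batch, num_devices):
    """Reshape a flat global batch into [num_devices][per_device_batch],
    the leading-axis layout that jax.pmap expects for its inputs."""
    if len(batch) % num_devices != 0:
        raise ValueError(
            f"batch size {len(batch)} is not divisible by {num_devices} devices"
        )
    per_device = len(batch) // num_devices
    return [batch[i * per_device:(i + 1) * per_device] for i in range(num_devices)]


if __name__ == "__main__":
    batch = [f"img_{i}" for i in range(8)]  # global batch of 8 (assumed)
    print(split_batch(batch, num_devices=8))  # e.g. TPU v3-8 (8 cores): 1 image per core
    print(split_batch(batch, num_devices=4))  # 2 images per device
```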
Read more about suitable training parameters here.
🙏 Acknowledgements
- This work is based on Matthias Wright's stylegan2 implementation.
- The project received generous support from Google's TPU Research Cloud (TRC).
- The image datasets were built using the LAION-5B index.
- We are grateful to Weights & Biases for preserving our sanity.