JAX Integration
This issue will be used as a tracker for integrating Stable Diffusion in JAX natively into diffusers. This will enable many cool use cases, notably running Stable Diffusion on Google Colab.
General design:
We will loosen the forced PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following "base" classes to be JAX compatible:
- ModelMixin: https://github.com/patil-suraj/stable-diffusion-jax/pull/10 (we should add a FlaxModelMixin class here)
- DiffusionPipeline: https://github.com/huggingface/diffusers/blob/25a51b63ca75e1351069bee87a0fb3df5abb89c3/src/diffusers/pipeline_utils.py#L76 (we should add a FlaxDiffusionPipeline here)
Note: ModelMixin should be made stateless by default, e.g. weights will not be stored on the model itself. Also, contrary to transformers, should we maybe only work with flax.linen.Module classes here, @patil-suraj? I don't really think we need the UNetConditionModel and UNetConditionModule design here. We could just go for class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?
TODO:
- [ ] 1. Make diffusers framework independent. This will require some general changes to setup.py and our automation tools
- [x] 2. Add FlaxModelMixin: https://github.com/huggingface/diffusers/pull/493 Here we can take a lot from https://github.com/patil-suraj/stable-diffusion-jax/pull/10/files but I'm not sure we should follow the transformers design here 1-to-1. Will also ask some Google folks here
- [ ] 3. Add all the modeling code under unet_2d_condition_flax.py...
- [ ] #478
- [x] 5. Add PNDM scheduler under scheduling_pndm_flax.py
- [ ] 6. Tests
- [ ] 7. Create pipeline and also FlaxDiffusionPipeline
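For item 1, one possible way to make the package framework independent is to detect the installed backends at import time rather than importing torch unconditionally. The sketch below uses only the standard library; the helper names (is_available, available_backends, require_any_backend) and the "pt"/"flax" labels are hypothetical and not the actual diffusers utilities.

```python
import importlib.util


def is_available(package: str) -> bool:
    """Return True if a top-level `package` is importable, without importing it."""
    return importlib.util.find_spec(package) is not None


def available_backends() -> list:
    """List the backends that are installed (hypothetical labels)."""
    backends = []
    if is_available("torch"):
        backends.append("pt")
    if is_available("jax") and is_available("flax"):
        backends.append("flax")
    return backends


def require_any_backend() -> list:
    """Raise a helpful error if neither PyTorch nor JAX/Flax is installed."""
    backends = available_backends()
    if not backends:
        raise ImportError("Install either PyTorch or JAX/Flax to use this library.")
    return backends


print(available_backends())
```

Framework-specific classes could then be exported only when their backend shows up in this list, so a JAX-only install never touches torch.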
Happy to take over 1. and finish today and then look into 4. once 3. is done.
@mishig25 do you want to do 2.? (happy to guide you here a bit if you have questions. Also we need to discuss the design here a bit offline maybe)
3. & 5. @pcuenca do you want to take this? (I think 3. is more important here)
The other parts we can see tomorrow maybe :-)
> ModelMixin should be made stateless by default, e.g. weights will not be stored on the model itself. Also, contrary to transformers, should we maybe only work with flax.linen.Module classes here, @patil-suraj? I don't really think we need the UNetConditionModel and UNetConditionModule design here. We could just go for class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?
Yes, let's only work with flax.linen.Module, adding the wrapper indeed creates some issues when using these models as submodules in other models.
For 4. we could adapt this https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py#L94
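For context, most of the work in such a conversion is usually renaming keys and changing weight layouts. The sketch below shows only the layout part, assuming the usual conventions: PyTorch conv weights are (out, in, kh, kw) while Flax nn.Conv kernels are (kh, kw, in, out), and PyTorch linear weights are (out, in) while Flax nn.Dense kernels are (in, out). The function names are hypothetical and not taken from the linked script.

```python
import numpy as np


def conv_weight_pt_to_flax(weight: np.ndarray) -> np.ndarray:
    """Reorder a conv weight from (out, in, kh, kw) to (kh, kw, in, out)."""
    return weight.transpose(2, 3, 1, 0)


def linear_weight_pt_to_flax(weight: np.ndarray) -> np.ndarray:
    """Reorder a linear weight from (out, in) to (in, out)."""
    return weight.T


# Dummy arrays standing in for tensors from a PyTorch state dict.
conv_pt = np.arange(64 * 4 * 3 * 3, dtype=np.float32).reshape(64, 4, 3, 3)
dense_pt = np.arange(10 * 20, dtype=np.float32).reshape(10, 20)

conv_flax = conv_weight_pt_to_flax(conv_pt)
dense_flax = linear_weight_pt_to_flax(dense_pt)

print(conv_flax.shape)
print(dense_flax.shape)
```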
Yes, I'd like to take point 2 and happy to discuss the details 👍
We also need to create a Flax notebook after all is done. I wrote a draft that uses @patil-suraj's repo.
Sure, I'll get started with 3.
I am gonna do 5
Asked Flax team about design here: https://github.com/google/flax/discussions/2454
I made some changes in @patil-suraj's repo to make it easier to jit the safety checker; they could be useful when we migrate it to diffusers: https://github.com/patil-suraj/stable-diffusion-jax/pull/11
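To give a feel for what "jit-friendly" means here, the sketch below is a toy check written purely with array operations (no Python branching on array values), so jax.jit can trace it. The function flag_unsafe, its cosine-similarity logic, and the threshold are illustrative assumptions and not the code from the linked PR.

```python
import jax
import jax.numpy as jnp


def flag_unsafe(image_embeds, concept_embeds, threshold):
    """Toy check: flag images whose cosine similarity to any concept exceeds `threshold`."""
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
    concept_embeds = concept_embeds / jnp.linalg.norm(concept_embeds, axis=-1, keepdims=True)
    scores = image_embeds @ concept_embeds.T  # (num_images, num_concepts)
    # Reduce with jnp.any instead of a Python `if`, so the function stays traceable.
    return jnp.any(scores > threshold, axis=-1)


flag_unsafe_jit = jax.jit(flag_unsafe)

images = jnp.array([[1.0, 0.0], [0.0, 1.0]])
concepts = jnp.array([[1.0, 0.1]])
flags = flag_unsafe_jit(images, concepts, 0.9)
print(flags)
```

Returning a boolean mask and applying it outside the jitted function avoids data-dependent control flow inside the traced code.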
@patrickvonplaten are there any open items for Jax/Flax integration? I'd love to contribute
I think this should be complete now, please @patil-suraj @patrickvonplaten feel free to reopen otherwise.