
JAX Integration

patrickvonplaten opened this issue 3 years ago • 9 comments


This issue will be used as a tracker for integrating Stable Diffusion natively into diffusers with JAX. This will enable many cool use cases, notably running Stable Diffusion on Google Colab.

General design:

We will loosen the hard PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following "base" classes to be JAX-compatible:

  • ModelMixin: https://github.com/patil-suraj/stable-diffusion-jax/pull/10. We should add a FlaxModelMixin class here.
  • DiffusionPipeline: https://github.com/huggingface/diffusers/blob/25a51b63ca75e1351069bee87a0fb3df5abb89c3/src/diffusers/pipeline_utils.py#L76. We should add a FlaxDiffusionPipeline class here.

Note: ModelMixin should be stateless by default, e.g. the weights will not be stored in the model object. Also, unlike in transformers, should we maybe only work with flax.linen.Module classes here, @patil-suraj? I don't really think we need the UNetConditionModel / UNetConditionModule design here; we could just go for class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?

TODO:

  • [ ] 1. Make diffusers framework-independent. This will require some general changes to setup.py and our automation tools.
  • [x] 2. Add FlaxModelMixin: https://github.com/huggingface/diffusers/pull/493. Here we can take a lot from https://github.com/patil-suraj/stable-diffusion-jax/pull/10/files, but I'm not sure we should follow the transformers design 1-to-1 here. I'll also ask some Google folks about this.
  • [ ] 3. Add all the modeling code under unet_2d_condition_flax.py ...
  • [ ] 4. #478
  • [x] 5. Add PNDM scheduler under scheduling_pndm_flax.py
  • [ ] 6. Tests
  • [ ] 7. Create the Flax pipeline, as well as FlaxDiffusionPipeline
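A scheduler can follow the same stateless pattern as the models: its precomputed schedule lives in an explicit state object, and the step is a pure function. PNDM itself is more involved (it keeps a history of past model outputs), so the sketch below is not a PNDM implementation; it uses a simpler deterministic DDIM step (eta = 0) purely to show the pattern. SchedulerState, create_state, and ddim_step are hypothetical names.

```python
from typing import NamedTuple

import jax.numpy as jnp


class SchedulerState(NamedTuple):
    """Hypothetical scheduler state; NamedTuples are JAX pytrees, so it can flow through jit."""
    alphas_cumprod: jnp.ndarray


def create_state(num_train_timesteps: int = 1000,
                 beta_start: float = 1e-4, beta_end: float = 0.02) -> SchedulerState:
    # Linear beta schedule: alphas_cumprod[t] = prod_{s <= t} (1 - beta_s).
    betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
    return SchedulerState(alphas_cumprod=jnp.cumprod(1.0 - betas))


def ddim_step(state, model_output, timestep, prev_timestep, sample):
    """Deterministic DDIM update (eta = 0), written as a pure function of its inputs."""
    a_t = state.alphas_cumprod[timestep]
    a_prev = state.alphas_cumprod[prev_timestep]
    # Estimate x_0 from the predicted noise, then re-noise it to the previous timestep.
    pred_x0 = (sample - jnp.sqrt(1.0 - a_t) * model_output) / jnp.sqrt(a_t)
    return jnp.sqrt(a_prev) * pred_x0 + jnp.sqrt(1.0 - a_prev) * model_output
```

A scheduler that does need to update its state (such as PNDM's history) would return a new state object alongside the sample rather than mutating anything.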

Happy to take over 1. and finish it today, and then look into 4. once 3. is done.

@mishig25, do you want to do 2.? (Happy to guide you a bit if you have questions. We should probably also discuss the design a bit offline.)

3. & 5.: @pcuenca, do you want to take these? (I think 3. is more important here.)

We can look at the other parts tomorrow maybe :-)

patrickvonplaten, Sep 12 '22 10:09


Yes, let's only work with flax.linen.Module; adding the wrapper indeed creates some issues when these models are used as submodules of other models.

patil-suraj, Sep 12 '22 10:09

For 4., we could adapt this: https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py#L94
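The core of any PyTorch-to-Flax weight conversion is a layout change: PyTorch Conv2d kernels are (out, in, kH, kW) while flax.linen.Conv expects (kH, kW, in, out), and PyTorch Linear computes x @ W.T while flax.linen.Dense computes x @ kernel. The sketch below only illustrates that mapping; it is not the linked script, and convert_param with its "."-separated key handling is a hypothetical simplification (normalization layers, whose "weight" maps to Flax's "scale", are not handled).

```python
import numpy as np


def torch_conv_to_flax(weight: np.ndarray) -> np.ndarray:
    """PyTorch Conv2d kernel (out, in, kH, kW) -> Flax nn.Conv kernel (kH, kW, in, out)."""
    return np.transpose(weight, (2, 3, 1, 0))


def torch_linear_to_flax(weight: np.ndarray) -> np.ndarray:
    """PyTorch Linear weight (out, in) -> Flax nn.Dense kernel (in, out)."""
    return weight.T


def convert_param(name: str, value: np.ndarray):
    """Hypothetical mapping of one PyTorch state_dict entry to a (flax_path, array) pair.

    Assumes '.'-separated PyTorch names; norm-layer weights are left unchanged here.
    """
    path = name.split(".")
    if path[-1] == "weight" and value.ndim == 4:
        return tuple(path[:-1]) + ("kernel",), torch_conv_to_flax(value)
    if path[-1] == "weight" and value.ndim == 2:
        return tuple(path[:-1]) + ("kernel",), torch_linear_to_flax(value)
    # Biases keep their name ("bias") and layout in Flax.
    return tuple(path), value
```

Flattened (path, array) pairs like these can then be regrouped into the nested dict that model.apply expects.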

patil-suraj, Sep 12 '22 10:09

Yes, I'd like to take point 2, and I'm happy to discuss the details 👍

mishig25, Sep 12 '22 10:09

We also need to create a Flax notebook once everything is done. I wrote a draft that uses @patil-suraj's repo.

pcuenca, Sep 12 '22 10:09

Sure, I'll get started with 3.

pcuenca, Sep 12 '22 10:09

I am gonna do 5

kashif, Sep 12 '22 10:09

Asked the Flax team about the design here: https://github.com/google/flax/discussions/2454

patrickvonplaten, Sep 12 '22 11:09

I made some changes in @patil-suraj's repo to make it easier to jit the safety checker; they could be useful when we migrate it to diffusers: https://github.com/patil-suraj/stable-diffusion-jax/pull/11
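The changes in the linked PR are not reproduced here. As a general illustration only: one common obstacle when jitting post-processing such as a safety checker is data-dependent Python control flow, because a Python if on a traced boolean fails under jax.jit. The hypothetical filter_flagged_images below replaces flagged images with jnp.where instead.

```python
import jax
import jax.numpy as jnp


@jax.jit
def filter_flagged_images(images, has_nsfw_concept):
    """Hypothetical post-processing step: black out flagged images inside jit.

    A Python `if has_nsfw_concept[i]:` would fail under jit because the flag is a
    traced value; jnp.where selects per image without data-dependent control flow.
    """
    # Broadcast the (batch,) flags over the (batch, H, W, C) image axes.
    mask = has_nsfw_concept[:, None, None, None]
    return jnp.where(mask, jnp.zeros_like(images), images)
```

Both branches are computed and the mask picks between them, which keeps the output shape static and the whole function traceable.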

pcuenca, Sep 16 '22 13:09

@patrickvonplaten, are there any open items for the JAX/Flax integration? I'd love to contribute.

dhruvrnaik, Oct 03 '22 21:10

I think this should be complete now. @patil-suraj @patrickvonplaten, please feel free to reopen otherwise.

pcuenca, Oct 27 '22 09:10