diffusers
                                
Convert Flax trained model to PyTorch
Is your feature request related to a problem? Please describe: I'm using the new Flax pipeline for DreamBooth training and want to be able to use the trained model on systems that don't run JAX.
Describe the solution you'd like: Add a from_flax parameter to the from_pretrained() functions of the PyTorch models, mirroring the from_pt parameter that the Flax models already accept in from_pretrained(), to enable cross-compatibility. Probably here or here.
Describe alternatives you've considered: A script that converts Flax msgpack to Diffusers format, and Flax msgpack to an SD checkpoint (e.g. in the /scripts directory).
Additional context: See here for the implementation of this feature in transformers.
Good idea 🤩 maybe like this: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L242-L352
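For readers who don't want to follow the link, here is a rough sketch of the kind of renaming and transposing such a Flax-to-PyTorch loader typically performs. It is not the transformers or diffusers code: the function name `flax_to_pytorch_state_dict` and the example keys are hypothetical, NumPy arrays stand in for real tensors, and diffusers-specific details (block naming, dtype handling) are left out.

```python
import numpy as np

def flax_to_pytorch_state_dict(flat_flax_params):
    """Map a flattened Flax param dict {tuple_of_names: array} to
    PyTorch-style dotted names and weight layouts (illustrative only)."""
    state_dict = {}
    for key, value in flat_flax_params.items():
        value = np.asarray(value)
        *prefix, leaf = key
        if leaf == "kernel" and value.ndim == 4:
            # Flax conv kernel (H, W, in, out) -> PyTorch conv weight (out, in, H, W)
            leaf, value = "weight", value.transpose(3, 2, 0, 1)
        elif leaf == "kernel":
            # Flax dense kernel (in, out) -> PyTorch linear weight (out, in)
            leaf, value = "weight", value.T
        elif leaf in ("scale", "embedding"):
            # Norm scales and embedding tables are called "weight" in PyTorch
            leaf = "weight"
        state_dict[".".join([*prefix, leaf])] = value
    return state_dict

# Hypothetical parameter names and shapes, chosen only for illustration
flax_params = {
    ("conv_in", "kernel"): np.zeros((3, 3, 4, 320)),
    ("conv_in", "bias"): np.zeros((320,)),
    ("proj", "kernel"): np.zeros((320, 1280)),
    ("norm", "scale"): np.ones((320,)),
}
converted = flax_to_pytorch_state_dict(flax_params)
```

Biases pass through unchanged; only kernels need their axes reordered, because Flax and PyTorch store them in different layouts.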
Thanks, I'll give it a try
Let me know how it goes @jorahn - curious to hear if it works :-)
+1 curious
Ok, so I have a local version that works for this:
from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp
pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
pipe.save_pretrained('output-flax', params)
vae = AutoencoderKL.from_pretrained('output-flax/vae', from_flax=True)
unet = UNet2DConditionModel.from_pretrained('output-flax/unet', from_flax=True)
The pipeline doesn't work yet:
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True)
This leads to `OSError: Error no file named diffusion_pytorch_model.bin found in directory output-flax/unet.` I'll try to put this into a PR so you can take a look at it and maybe help figure out what's missing.
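The error message suggests that the pipeline loader was still looking for a PyTorch weight file in each component folder, presumably because `from_flax` was not yet passed down to the individual models. A small stdlib-only sketch for checking which weight formats a saved component folder actually contains; the helper name `weight_formats` is hypothetical, and the Flax file name is my understanding of the diffusers convention rather than something stated in this thread.

```python
from pathlib import Path

# File names: the PyTorch one appears in the error above; the Flax one is
# assumed from the diffusers naming convention.
PYTORCH_WEIGHTS = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS = "diffusion_flax_model.msgpack"

def weight_formats(component_dir):
    """Return the set of weight formats found in a component folder."""
    directory = Path(component_dir)
    found = set()
    if (directory / PYTORCH_WEIGHTS).is_file():
        found.add("pytorch")
    if (directory / FLAX_WEIGHTS).is_file():
        found.add("flax")
    return found
```

Calling `weight_formats("output-flax/unet")` on a folder saved by the Flax pipeline would be expected to report only the Flax format, which matches the missing-file error.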
Will try to help with the PR!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
~~Actually superseded by #1241.~~ Sorry, I thought this was a PR.
 https://github.com/camenduru/diffusers/blob/from_flax_v2/src/diffusers/modeling_pytorch_flax_utils.py
https://huggingface.co/camenduru/plushies-pt
Closing this issue, as one can now convert the model simply by loading it into PyTorch:
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("./flax_stable_diffusion", from_flax=True)
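To cover the original use case (running the model on systems without JAX), the loaded pipeline can be saved once in PyTorch format and then shared. A minimal sketch, assuming the machine doing the conversion still has jax and flax installed so that `from_flax=True` can read the msgpack weights; the function name and the directory paths are hypothetical.

```python
def convert_flax_pipeline_to_pytorch(flax_dir, output_dir):
    """Load a Flax-saved Stable Diffusion pipeline and re-save it with PyTorch weights."""
    # Imported lazily so that defining this function does not require diffusers
    from diffusers import StableDiffusionPipeline

    # Reads the Flax weights and converts them to PyTorch modules in memory
    pipe = StableDiffusionPipeline.from_pretrained(flax_dir, from_flax=True)
    # Writes PyTorch weight files that can be loaded without from_flax
    pipe.save_pretrained(output_dir)
    return output_dir

# Example call (paths are placeholders):
# convert_flax_pipeline_to_pytorch("./flax_stable_diffusion", "./pt_stable_diffusion")
```

The output folder can then be loaded elsewhere with a plain `StableDiffusionPipeline.from_pretrained("./pt_stable_diffusion")`.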