Convert Flax trained model to PyTorch

jorahn opened this issue 2 years ago • 8 comments

Is your feature request related to a problem? Please describe.
I'm using the new Flax pipeline for DreamBooth training and want to be able to use the trained model on systems that don't run JAX.

Describe the solution you'd like
Add a from_flax parameter to the PyTorch models' from_pretrained() methods, mirroring the from_pt parameter that Flax models already accept in .from_pretrained(), to enable cross-compatibility. Probably here or here.

Describe alternatives you've considered
A script (e.g. in the /scripts directory) to convert a Flax msgpack checkpoint to the Diffusers format, and one to convert a Flax msgpack checkpoint to an SD checkpoint.

Additional context
See here for the implementation of this feature in transformers.

Posted by jorahn on Nov 06 '22

Good idea 🤩 Maybe something like this: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L242-L352

Posted by camenduru on Nov 06 '22
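The linked transformers helper essentially renames Flax parameter leaves to PyTorch names and transposes kernels into PyTorch's layout. Below is a minimal, numpy-only sketch of those core rules, assuming Flax conv kernels are HWIO and dense kernels are (in, out), while PyTorch expects OIHW and (out, in). `flatten_params` and `flax_to_pt_state_dict` are hypothetical names, not library functions. A real loader also matches keys against the target model's state dict, upcasts bf16 weights to float32, converts arrays to torch tensors, and may need extra key renaming for nested module names.

```python
from collections.abc import Mapping

import numpy as np


def flatten_params(tree, prefix=()):
    """Flatten a nested Flax-style params mapping into {key_tuple: array}."""
    flat = {}
    for name, value in tree.items():
        key = prefix + (name,)
        if isinstance(value, Mapping):  # also covers Flax FrozenDict
            flat.update(flatten_params(value, key))
        else:
            flat[key] = np.asarray(value)
    return flat


def flax_to_pt_state_dict(flax_params):
    """Rename and reshape Flax parameters into a PyTorch-style state dict.

    Simplified rules (an assumption modeled on the transformers helper):
      * 4-D "kernel" (conv, HWIO)   -> "weight", transposed to OIHW
      * 2-D "kernel" (dense, in x out) -> "weight", transposed to out x in
      * "scale" (norm) / "embedding" -> "weight", values unchanged
      * anything else (e.g. "bias") keeps its name and values
    """
    pt_state = {}
    for key, tensor in flatten_params(flax_params).items():
        *path, leaf = key
        if leaf == "kernel" and tensor.ndim == 4:
            leaf, tensor = "weight", np.transpose(tensor, (3, 2, 0, 1))
        elif leaf == "kernel":
            leaf, tensor = "weight", tensor.T
        elif leaf in ("scale", "embedding"):
            leaf = "weight"
        pt_state[".".join(path + [leaf])] = np.ascontiguousarray(tensor)
    return pt_state


# Toy parameters shaped like a conv layer, a dense layer and a norm layer
rng = np.random.default_rng(0)
flax_params = {
    "conv_in": {"kernel": rng.normal(size=(3, 3, 4, 8)), "bias": np.zeros(8)},
    "proj": {"kernel": rng.normal(size=(16, 32)), "bias": np.zeros(32)},
    "norm": {"scale": np.ones(8), "bias": np.zeros(8)},
}
pt_state = flax_to_pt_state_dict(flax_params)

# Flax computes x @ kernel; PyTorch's Linear computes x @ weight.T, so the
# transposed weight reproduces the same outputs.
x = rng.normal(size=(2, 16))
dense_matches = np.allclose(x @ flax_params["proj"]["kernel"], x @ pt_state["proj.weight"].T)
```

The transpose-on-load design keeps the saved Flax checkpoint untouched; only the in-memory layout changes to what PyTorch modules expect.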

Thanks, I'll give it a try

Posted by jorahn on Nov 07 '22

Let me know how it goes @jorahn - curious to hear if it works :-)

Posted by patrickvonplaten on Nov 09 '22

+1 curious

Posted by camenduru on Nov 10 '22

OK, so I have a local version that works for this:

```python
from diffusers import FlaxStableDiffusionPipeline, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel
import jax.numpy as jnp

# Download the bf16 Flax weights and save them locally in the Diffusers folder layout
pipe, params = FlaxStableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', revision='bf16', dtype=jnp.bfloat16)
pipe.save_pretrained('output-flax', params)

# Load individual components into PyTorch from the saved Flax weights
vae = AutoencoderKL.from_pretrained('output-flax/vae', from_flax=True)
unet = UNet2DConditionModel.from_pretrained('output-flax/unet', from_flax=True)
```

The pipeline doesn't work yet:

```python
pipe2 = StableDiffusionPipeline.from_pretrained('output-flax', from_flax=True)
```

This leads to `OSError: Error no file named diffusion_pytorch_model.bin found in directory output-flax/unet.` I'll try to put this into a PR so you can take a look and maybe help figure out what's missing.

Posted by jorahn on Nov 10 '22
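One plausible reading of the error above is that the pipeline-level from_pretrained did not pass from_flax=True down to each component, so the UNet loader still looked for the PyTorch weights file. The sketch below models just that file-selection step. `resolve_weights_file` is a hypothetical helper, not a diffusers function, and the Flax filename `diffusion_flax_model.msgpack` is assumed to be the name diffusers' Flax models save their weights under.

```python
import os
import tempfile

# Per-model weight file names (the PyTorch one appears in the error above;
# the Flax one is assumed from diffusers' naming convention).
PT_WEIGHTS_NAME = "diffusion_pytorch_model.bin"
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"


def resolve_weights_file(model_dir, from_flax=False):
    """Hypothetical helper: pick the weights file a PyTorch loader should read."""
    name = FLAX_WEIGHTS_NAME if from_flax else PT_WEIGHTS_NAME
    path = os.path.join(model_dir, name)
    if not os.path.isfile(path):
        raise OSError(f"Error no file named {name} found in directory {model_dir}.")
    return path


# Simulate a component folder that only contains Flax weights
with tempfile.TemporaryDirectory() as root:
    unet_dir = os.path.join(root, "unet")
    os.makedirs(unet_dir)
    open(os.path.join(unet_dir, FLAX_WEIGHTS_NAME), "wb").close()

    # Forwarding from_flax=True finds the msgpack file...
    found = resolve_weights_file(unet_dir, from_flax=True)

    # ...while dropping the flag reproduces the OSError reported above.
    try:
        resolve_weights_file(unet_dir)
        raised = False
    except OSError:
        raised = True
```

In this reading, the pipeline fix amounts to forwarding the flag to every sub-model's loader, which fits the later comment that StableDiffusionPipeline.from_pretrained(..., from_flax=True) works.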

Will try to help with the PR!

Posted by patrickvonplaten on Nov 15 '22

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Posted by github-actions[bot] on Dec 10 '22

~~Actually superseded by #1241.~~ Sorry, I thought this was a PR.

Posted by pcuenca on Dec 19 '22

Download https://github.com/camenduru/diffusers/blob/from_flax_v2/src/diffusers/modeling_pytorch_flax_utils.py (from the from_flax_v2 branch of camenduru's diffusers fork). See also https://huggingface.co/camenduru/plushies-pt

Posted by camenduru on Dec 24 '22

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Posted by github-actions[bot] on Jan 17 '23

Closing this issue, as the model can now simply be converted by loading it into PyTorch:

```python
from diffusers import StableDiffusionPipeline

# Load a locally saved Flax pipeline directly into PyTorch
pipe = StableDiffusionPipeline.from_pretrained("./flax_stable_diffusion", from_flax=True)
```

Posted by patrickvonplaten on Jan 22 '23