diffusers
Add Flax Stable Diffusion img2img pipeline
Adds a Flax Stable Diffusion (SD) pipeline for image-to-image generation.
TODO:
- [x] Update README
- [x] Fix generation bug
The pipeline is working, but the result is not as expected.
Input image:
Prompt = "A fantasy landscape, trending on artstation"
Outputs:
Params: strength=0.75, num_inference_steps=50, guidance_scale=7.5
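For reference, `strength` decides how much of the noise schedule actually runs on the input image. Below is a minimal sketch, assuming the usual img2img convention; the helper name `denoising_start` is hypothetical and not necessarily what this PR implements. With the params above, the earliest steps are skipped and only the remaining ones denoise the noised input.

```python
def denoising_start(num_inference_steps: int, strength: float) -> int:
    """Index of the first scheduler step to run for img2img (hypothetical helper).

    strength=1.0 runs every step (the input image is fully noised);
    strength=0.0 runs no denoising steps at all.
    """
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"strength must be in [0, 1], got {strength}")
    # Number of steps that will actually denoise the image.
    init_steps = min(int(num_inference_steps * strength), num_inference_steps)
    # Skip the earliest (noisiest) steps of the schedule.
    return max(num_inference_steps - init_steps, 0)


start = denoising_start(50, 0.75)
print(start, 50 - start)  # 13 37
```

So with strength=0.75 and 50 inference steps, roughly three quarters of the schedule (37 steps) is run.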
How to run this:
import requests
from io import BytesIO

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from diffusers import FlaxStableDiffusionImg2ImgPipeline

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))  # PIL sizes are (width, height)
prompts = "A fantasy landscape, trending on artstation"
dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    dtype=dtype,
)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

# One PRNG key per device
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

# Repeat the prompt and image once per device, then tokenize / preprocess them
prompt_ids, imgs = pipeline.prepare_inputs(
    prompt=[prompts] * jax.device_count(),
    image=[init_img] * jax.device_count(),
)

# Copy the weights to every device and split the inputs across devices
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
imgs = shard(imgs)

output = pipeline(
    prompt_ids=prompt_ids,
    image=imgs,
    params=p_params,
    prng_seed=rng,
    strength=0.75,
    num_inference_steps=50,
    jit=True,
    guidance_scale=7.5,
    height=512,
    width=768,
).images
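On a single host, `shard` from `flax.training.common_utils` reshapes each input's leading axis from `(N, ...)` to `(local_device_count, N // local_device_count, ...)`, which is why the prompt and image are repeated `jax.device_count()` times, while `replicate` copies the params to every device. The returned images keep that per-device layout. The pure-Python sketch below mimics both reshapes on lists so the shapes are easy to follow; `shard_batch` and `unshard_batch` are hypothetical helpers, not Flax APIs.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_batch(batch: Sequence[T], num_devices: int) -> List[List[T]]:
    """Split N items into num_devices groups of N // num_devices (hypothetical helper),
    mirroring a reshape of the leading axis (N, ...) -> (num_devices, -1, ...)."""
    if len(batch) % num_devices:
        raise ValueError(f"batch size {len(batch)} is not divisible by {num_devices}")
    per_device = len(batch) // num_devices
    return [list(batch[d * per_device:(d + 1) * per_device]) for d in range(num_devices)]


def unshard_batch(sharded: Sequence[Sequence[T]]) -> List[T]:
    """Merge the device axis back into the batch axis (hypothetical helper):
    (num_devices, B, ...) -> (num_devices * B, ...)."""
    return [item for device_items in sharded for item in device_items]


# e.g. 8 prompts on a TPU v3-8 (8 devices) -> 8 groups of 1 prompt each
prompts = [f"prompt {i}" for i in range(8)]
sharded = shard_batch(prompts, 8)
print(len(sharded), len(sharded[0]))  # 8 1
assert unshard_batch(sharded) == prompts
```

With the real arrays, the equivalent of `unshard_batch` on the pipeline output is `output.reshape((-1,) + output.shape[2:])`, which turns `(devices, per_device, H, W, 3)` into `(N, H, W, 3)`.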
cc: @patil-suraj, it'd be great if you or someone else could review this. I'm still trying to fix the image result, but a review could help catch any other issues.
Fixed it: I had forgotten to update the for loop to use the right timesteps.
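In img2img the denoising loop should only visit the timesteps left after skipping the first ones, rather than the full schedule. A minimal sketch, assuming evenly spaced DDIM-style timesteps over 1000 training steps (real schedulers such as PNDM space and offset them differently); `remaining_timesteps` is a hypothetical helper and `start` is the index of the first step to run.

```python
from typing import List


def remaining_timesteps(num_inference_steps: int, start: int,
                        num_train_timesteps: int = 1000) -> List[int]:
    """Timesteps the denoising loop should visit, noisiest first (hypothetical helper).

    Assumes evenly spaced timesteps with no offset.
    """
    step_ratio = num_train_timesteps // num_inference_steps
    # e.g. [980, 960, ..., 20, 0] for 50 steps over 1000 training timesteps
    timesteps = [i * step_ratio for i in reversed(range(num_inference_steps))]
    # Looping over all of `timesteps` would ignore `strength`; only the tail is used.
    return timesteps[start:]


ts = remaining_timesteps(50, start=13)
print(len(ts), ts[0], ts[-1])  # 37 720 0
```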
Cool! @dhruvrnaik, do you need help finishing this PR?
@patrickvonplaten I am mostly done with this. I just need to update the docs, which I will finish tonight. I have tested this on a TPU v3-8 pod and it works as expected. Otherwise, it's ready for review. I will also fix the black/isort check failures.
Excellent! I'll check it out too. Please let me know when you are done with your changes :) Thanks a lot!
Hey @dhruvrnaik,
Thanks a lot for the PR! Could you maybe run make fix-copies once and fix the following error from the PR Documentation check, so we can get this merged? :-)
❱ 45 class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
  46     r"""
  47     Pipeline for image-to-image generation using Stable Diffusion.
  48
/usr/local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py:270 in FlaxStableDiffusionImg2ImgPipeline
  267         prompt_ids: jnp.array,
  268         image: jnp.array,
  269         params: Union[Dict, FrozenDict],
❱ 270         prng_seed: Union[jax.random.KeyArray, jax.Array],
  271         num_inference_steps: int = 50,
  272         height: int = 512,
  273         width: int = 512,
AttributeError: module 'jax' has no attribute 'Array'
Should be good to go, @patrickvonplaten.
@dhruvrnaik I'll test it today.
@pcuenca, feel free to merge whenever :-)
Thanks a lot @dhruvrnaik!