diffusers

Add Flax stable diffusion img2img pipeline

dhruvrnaik opened this issue 2 years ago • 9 comments

Added a Flax Stable Diffusion (SD) pipeline for image-to-image generation.

TODO:

  • [x] Update README
  • [x] Fix generation bug

The pipeline is working, but the result is not as expected.

Input image: [image]
Prompt: "A fantasy landscape, trending on artstation"
Outputs: [image]

Params: strength=0.75, num_inference_steps=50, guidance_scale=7.5
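For reference, guidance_scale is the classifier-free guidance weight: the UNet produces an unconditional and a text-conditioned noise prediction, and the two are combined as below. This is a NumPy sketch of the standard formula with hypothetical names (apply_guidance, noise_uncond, noise_text), not the pipeline's own code:

```python
import numpy as np


def apply_guidance(noise_uncond: np.ndarray, noise_text: np.ndarray, guidance_scale: float) -> np.ndarray:
    """Classifier-free guidance: move the prediction from unconditional toward the prompt."""
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


# Toy predictions: the guided result is 0 + 7.5 * (1 - 0) = 7.5 everywhere.
noise_uncond = np.zeros(4, dtype=np.float32)
noise_text = np.ones(4, dtype=np.float32)
guided = apply_guidance(noise_uncond, noise_text, 7.5)
```

With guidance_scale=1.0 the result equals the text-conditioned prediction; larger values push the sample further toward the prompt.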

How to run this:

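import jax
import jax.numpy as jnp
import requests
from io import BytesIO
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionImg2ImgPipeline

# Download the sketch used as the init image and resize it to 768x512 (width x height).
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))
prompts = "A fantasy landscape, trending on artstation"

# Load the Flax weights in bfloat16.
dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    dtype=dtype,
)

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

# One PRNG key per device.
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

# One prompt/image pair per device: tokenize the prompts and preprocess the images.
prompt_ids, imgs = pipeline.prepare_inputs(
    prompt=[prompts] * jax.device_count(),
    init_image=[init_img] * jax.device_count(),
)

# Replicate the params on every device and split the inputs along a leading device axis.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
imgs = shard(imgs)

# Keyword names reflect the PR at this point; the traceback later in this thread
# shows the image argument named `image`. height/width match the resized init image.
output = pipeline(
    prompt_ids=prompt_ids,
    init_images=imgs,
    params=p_params,
    prng_seed=rng,
    strength=0.75,
    num_inference_steps=50,
    jit=True,
    guidance_scale=7.5,
    height=512,
    width=768,
).images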
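The replicate/shard step gives every device its own slice of the batch: flax's shard reshapes each array from (n, ...) to (num_devices, n // num_devices, ...), and the images that come back keep that leading device axis. Below is a NumPy-only sketch of that layout round trip and of converting float output to 8-bit pixels. shard_like, unshard and to_uint8 are hypothetical helpers, and the output is assumed to be floats in [0, 1] with shape (devices, per_device_batch, height, width, 3).

```python
import numpy as np


def shard_like(x: np.ndarray, num_devices: int) -> np.ndarray:
    """Mimic flax's shard for one array: (n, ...) -> (num_devices, n // num_devices, ...)."""
    return x.reshape((num_devices, -1) + x.shape[1:])


def unshard(x: np.ndarray) -> np.ndarray:
    """Merge the device axis back into the batch axis: (d, b, H, W, C) -> (d * b, H, W, C)."""
    return x.reshape((-1,) + x.shape[2:])


def to_uint8(images: np.ndarray) -> np.ndarray:
    """Convert float images in [0, 1] to uint8 pixel values in [0, 255]."""
    return (np.clip(images, 0.0, 1.0) * 255).round().astype(np.uint8)


# Toy example: 8 devices, one small 4x6 RGB image each.
num_devices = 8
batch = np.random.rand(num_devices, 4, 6, 3).astype(np.float32)
sharded = shard_like(batch, num_devices)  # shape (8, 1, 4, 6, 3)
pixels = to_uint8(unshard(sharded))  # shape (8, 4, 6, 3), dtype uint8
```

Flax diffusers pipelines also provide a numpy_to_pil helper for the final conversion to PIL images once the device axis has been merged.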

dhruvrnaik · Nov 21 '22

The documentation is not available anymore as the PR was closed or merged.

cc: @patil-suraj, it'd be great if you or someone else could review this. I'm still trying to fix the image result, but a review could help catch any other issues.

dhruvrnaik · Nov 22 '22

Fixed it. I had forgotten to update the for loop to use the right timesteps.
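In img2img, strength decides how much of the schedule is re-run: the init latents are noised to an intermediate timestep and only the remaining steps are denoised, so the loop must start partway through the timesteps. A minimal sketch of the convention commonly used by Stable Diffusion img2img pipelines; get_timestep_start and the evenly spaced toy schedule are illustrative assumptions, not necessarily the code in this PR:

```python
def get_timestep_start(num_inference_steps: int, strength: float) -> int:
    """Index into the timestep schedule at which img2img denoising starts."""
    # strength=1.0 re-runs the full schedule; strength=0.0 runs no steps at all.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    return max(num_inference_steps - init_timestep, 0)


num_inference_steps = 50
strength = 0.75
# Toy, evenly spaced schedule from noisiest to cleanest: 980, 960, ..., 0 (50 entries).
timesteps = list(range(980, -1, -20))

t_start = get_timestep_start(num_inference_steps, strength)  # 13
# The init latents would be noised to timesteps[t_start]; the denoising loop then
# visits only the remaining entries instead of the whole schedule.
steps_run = [timesteps[i] for i in range(t_start, num_inference_steps)]
```

With strength=0.75 and 50 steps, 37 denoising steps run, starting at index 13 of the schedule.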

[output image]

dhruvrnaik · Nov 25 '22

Cool! @dhruvrnaik, do you need help finishing this PR?

patrickvonplaten · Nov 30 '22

@patrickvonplaten I am mostly done with this; I just need to update the docs, which I will finish tonight. I have tested it on a TPU v3-8 pod and it works as expected. Otherwise, it is ready for review. I will also fix the black/isort check failures.

dhruvrnaik · Dec 02 '22

Excellent! I'll check it out too. Please let me know when you are done with your changes :) Thanks a lot!

pcuenca · Dec 02 '22

Hey @dhruvrnaik,

Thanks a lot for the PR! Could you maybe run make fix-copies once and fix the following error from the PR Documentation check, so we can get this merged? :-)

 ❱  45 class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
    46     r"""
    47     Pipeline for image-to-image generation using Stable Diffusion.
    48

 /usr/local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py:270 in FlaxStableDiffusionImg2ImgPipeline

   267         prompt_ids: jnp.array,
   268         image: jnp.array,
   269         params: Union[Dict, FrozenDict],
 ❱ 270         prng_seed: Union[jax.random.KeyArray, jax.Array],
   271         num_inference_steps: int = 50,
   272         height: int = 512,
   273         width: int = 512,

 AttributeError: module 'jax' has no attribute 'Array'
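The AttributeError is raised by the annotation itself: without from __future__ import annotations, Python evaluates parameter annotations when the def statement runs, so merely importing the module fails when the installed JAX (here, the one in the documentation build) has no jax.Array. A minimal sketch of one workaround, assuming older JAX versions should keep working: resolve the type with getattr and fall back to numpy's ndarray. old_jax and new_jax below are hypothetical stand-in namespaces, not the real jax package, so the snippet runs without JAX installed.

```python
import types

# Hypothetical stand-ins: an "old" module without Array, and a "new" one with it.
old_jax = types.SimpleNamespace(numpy=types.SimpleNamespace(ndarray=list))
new_jax = types.SimpleNamespace(Array=tuple, numpy=types.SimpleNamespace(ndarray=list))


def define_with_direct_annotation(mod) -> str:
    """Try to define a function annotated with mod.Array (evaluated at def time)."""
    try:
        def fn(x: mod.Array) -> None:
            pass
        return "ok"
    except AttributeError as err:
        return f"AttributeError: {err}"


def array_type(mod):
    """Hypothetical fallback: mod.Array when it exists, else mod.numpy.ndarray."""
    return getattr(mod, "Array", mod.numpy.ndarray)


def define_with_fallback(mod) -> str:
    array_t = array_type(mod)

    def fn(x: array_t) -> None:
        pass
    return "ok"
```

Pinning a JAX release that provides jax.Array in the documentation environment would be another way to avoid the error.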

patrickvonplaten · Dec 15 '22

Should be good to go, @patrickvonplaten.

dhruvrnaik · Dec 19 '22

@dhruvrnaik I'll test it today.

pcuenca · Dec 19 '22

@pcuenca, feel free to merge whenever :-)

patrickvonplaten · Dec 20 '22

Thanks a lot, @dhruvrnaik!

pcuenca · Dec 20 '22