stable-diffusion-jax

TODOs:

  • [x] Finish implementing the UNet2D model in modeling_unet2d.py. Port the weights of an existing LDM UNet from diffusers and verify equivalence. I've added the skeleton of the modules that we need to implement in the file.
  • [x] Adapt the PNDMScheduler from diffusers for JAX: Use jnp arrays and make it stateless.
  • [x] Add the KL module from [here](https://github.dev/CompVis/stable-diffusion) to the modeling_vae.py file. We don't really need it for inference, but it would be nice to have for completeness. Port the weights of an existing KL VAE and verify equivalence.
  • [x] Add an inference loop in pipeline_stable_diffusion. We should be able to jit/pmap the loop to deploy on TPUs.
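
As a rough illustration of the "stateless scheduler" and "jit-able inference loop" items above, here is a minimal sketch. It is not the code in this repo: the names `SchedulerState`, `create_scheduler_state`, `scheduler_step`, `make_sampler` and `toy_unet` are hypothetical, the update rule is a plain deterministic DDIM-style step used as a simpler stand-in for PNDM, and the linear beta schedule is an assumption chosen for brevity. The point is only that all scheduler data lives in an immutable pytree of jnp arrays, so the whole loop can be traced by `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SchedulerState(NamedTuple):
    """Hypothetical immutable scheduler state (a pytree of jnp arrays)."""
    timesteps: jnp.ndarray       # (num_inference_steps,), descending
    alphas_cumprod: jnp.ndarray  # (num_train_timesteps,)


def create_scheduler_state(num_train_timesteps=1000, num_inference_steps=50,
                           beta_start=1e-4, beta_end=2e-2):
    # Assumption: a simple linear beta schedule, for illustration only.
    betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
    alphas_cumprod = jnp.cumprod(1.0 - betas)
    stride = num_train_timesteps // num_inference_steps
    timesteps = (jnp.arange(num_inference_steps) * stride)[::-1]
    return SchedulerState(timesteps=timesteps, alphas_cumprod=alphas_cumprod)


def scheduler_step(state, sample, noise_pred, i):
    """Pure function: deterministic DDIM-style update (stand-in for PNDM)."""
    num_steps = state.timesteps.shape[0]  # static under jit
    t = state.timesteps[i]
    t_prev = state.timesteps[jnp.minimum(i + 1, num_steps - 1)]
    alpha_t = state.alphas_cumprod[t]
    # On the final step there is no earlier timestep, so use alpha = 1.
    alpha_prev = jnp.where(i == num_steps - 1, 1.0, state.alphas_cumprod[t_prev])
    pred_x0 = (sample - jnp.sqrt(1.0 - alpha_t) * noise_pred) / jnp.sqrt(alpha_t)
    return jnp.sqrt(alpha_prev) * pred_x0 + jnp.sqrt(1.0 - alpha_prev) * noise_pred


def make_sampler(unet_apply):
    """Build a jitted denoising loop around unet_apply(params, sample, t)."""
    @jax.jit
    def sample_loop(params, state, latents):
        def body(i, x):
            noise_pred = unet_apply(params, x, state.timesteps[i])
            return scheduler_step(state, x, noise_pred, i)
        return jax.lax.fori_loop(0, state.timesteps.shape[0], body, latents)
    return sample_loop


def toy_unet(params, sample, t):
    # Hypothetical placeholder for the real UNet2D forward pass.
    return params["scale"] * sample


state = create_scheduler_state(num_inference_steps=10)
sampler = make_sampler(toy_unet)
latents = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8, 8))
denoised = sampler({"scale": jnp.asarray(0.1)}, state, latents)
```

Because `sample_loop` takes only pytrees of arrays and returns an array, the same function could in principle be wrapped with `jax.pmap` (with a leading device axis on the inputs) to run across TPU cores.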