PixArt-Sigma: callbacks (interrupt, latents, pos/neg embeds) and cfg_rescale
What does this PR do?
This PR refactors the pipeline to mirror other common pipelines by adding `callback_on_step_end` and `callback_on_step_end_tensor_inputs`, along with cfg rescaling. In the new callback you can interrupt the run and retrieve the latents and/or the pos/neg embeds. The older callback method will continue to provide only `step_idx`, `t`, and `latents`.
I've added deprecation warnings for anyone still using the legacy `callback` and `callback_steps` arguments instead of the newer `callback_on_step_end` and `callback_on_step_end_tensor_inputs`, as well as an error if you try to use both at the same time.
To sum it up, this adds:
- [x] `callback_on_step_end` and `callback_on_step_end_tensor_inputs`, which allow you to obtain latents and pos/neg embeds
- [x] deprecation warnings for the older `callback` and `callback_steps` arguments (see the sketch after this list)
- [x] the ability to set `self._interrupt = True` in `callback_on_step_end`
- [x] cfg rescaling
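As a rough illustration of the deprecation warning and the "both at once" error described above (a minimal sketch with illustrative names, not the exact code in this PR):

```python
from diffusers.utils import deprecate


def check_callback_args(callback=None, callback_steps=None, callback_on_step_end=None):
    # Warn anyone still passing the legacy arguments.
    if callback is not None or callback_steps is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback`/`callback_steps` is deprecated, consider using `callback_on_step_end`.",
        )
    # Refuse to mix the legacy and the new callback APIs in the same call.
    if (callback is not None or callback_steps is not None) and callback_on_step_end is not None:
        raise ValueError(
            "`callback`/`callback_steps` cannot be used together with `callback_on_step_end`."
        )
```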
Some snippets of the code from my app that I used to verify that it works:
    import torch  # the snippet also assumes `pipe` and `queue_latch` exist in my app's scope

    def interrupt_callback(self, i, t, callback_kwargs):
        # `self` is the pipeline passed in by callback_on_step_end;
        # `queue_latch` is a latching variable toggled by an onkeypress event
        if not queue_latch:
            self._interrupt = True

        # decode the current latents into a preview image and save it
        latents = callback_kwargs["latents"]
        with torch.no_grad():
            image = pipe.vae.decode(latents / 0.13025, return_dict=False)[0]  # 0.13025 = VAE scaling factor
            image = pipe.image_processor.postprocess(image, output_type="pil")
            image[0].save(f"{i}.png")
        return callback_kwargs
and
    latents = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_embeds,
        prompt_attention_mask=prompt_attention_mask,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
        num_images_per_prompt=1,
        height=height,
        width=width,
        num_inference_steps=steps,
        guidance_scale=cfg,
        generator=seedgen,
        callback_on_step_end=interrupt_callback,         # new in this PR
        callback_on_step_end_tensor_inputs=["latents"],  # new in this PR
        output_type="latent",
    ).images
A test image showing the latent callbacks working at each step (this can also be used to generate realtime previews in apps, as shown in my update 4 comment):
Example of the interrupt callback working while using my app:
Example of cfg rescale working (gif compression degrades the quality a lot):
Before submitting
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@yiyixuxu I think it's about done for real this time, let me know if there's anything else. Check my original post for more details on the changes.
Added cfg rescaling (it defaults to zero if not used), based on section 3.4 of "Common Diffusion Noise Schedules and Sample Steps are Flawed". This adds feature parity with a lot of other common pipelines.
With cfg rescale factor set to 0.0
With cfg rescale factor set to 0.7
Open these in new tabs and click back and forth to see the minor changes. They are subtle in these two, but in some scenes, they are a lot more obvious.
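For reference, other diffusers pipelines implement this rescaling with a `rescale_noise_cfg` helper; a sketch of that approach (the code in this PR follows the same idea, even if the exact helper differs):

```python
import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Rescale the cfg output toward the std of the text-conditioned prediction,
    # per section 3.4 of "Common Diffusion Noise Schedules and Sample Steps are Flawed".
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend with the original cfg result so images don't end up looking too plain.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```

With `guidance_rescale=0.0` the function is a no-op, which is why the default leaves existing behavior unchanged.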
thanks! I think this PR should focus on adding the new callback and cfg_rescale; we should not include the changes introduced for `encode_prompt` in this PR.
~~I addressed this in the code comments~~ I reverted the negative_prompt changes, will open a new PR eventually after this one is finished.
additionally, can you test out the dynamic classifier-free guidance on pixart using the new callback API? https://huggingface.co/docs/diffusers/using-diffusers/callback#dynamic-classifier-free-guidance
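For anyone following along, the dynamic CFG callback in that guide looks roughly like this (the guide's example targets SDXL; `callback_dynamic_cfg` and the 40% cutoff come from the guide, and whether PixArt's prompt attention masks need the same chunking is exactly what the requested test would confirm):

```python
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    # Switch off classifier-free guidance after 40% of the steps.
    if step_index == int(pipe.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        # With CFG the embeds are batched as [negative, positive]; keep only the positive half.
        prompt_embeds = prompt_embeds.chunk(2)[-1]
        pipe._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs


# passed in via:
# pipe(..., callback_on_step_end=callback_dynamic_cfg,
#      callback_on_step_end_tensor_inputs=["prompt_embeds"])
```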
I'll try to take a look at it as well, but only after I figure out how to handle the old `callback`/`callback_steps` deprecation.
@yiyixuxu Alright, I worked the legacy `callback`/`callback_steps` handling back in. ~~The legacy callback will require (self, step_idx, t, latents), but has 1:1 parity with the newer callback_on_step_end method~~. I also included a deprecation warning and an error if both are used at the same time.
I tested both in my app and was able to get latent callbacks for previews and ~~still interrupt the process~~ (only with the newer method). ~~If needed, I can add some kind of warning if the callback(self, step_idx, t, latents) call near the very end of the code detects only three inputs instead of four, as a hint that people just need to add a self (or something similar) to the def somefunction(self, i, t, latents): function they use for their callbacks.~~
For the struck-out parts, read my comment below. I reverted the old callback back to callback(step_idx, t, latents).
@yiyixuxu @sayakpaul I reverted the negative_prompt changes and fully updated my original post to be more clear about the changes.
Since the legacy implementation of callback doesn't appear to have the ability to interrupt, I'm just going to roll this back. If people want to interrupt, they should be using the newer method anyway, since the older callback method is being deprecated.
The legacy callback will still function the same as before: callback(step_idx, t, latents).
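To make the difference concrete, a minimal sketch of the two callback shapes (the function names and the step-10 condition are just illustrative):

```python
# Legacy style (deprecated): only receives the step index, timestep, and latents,
# and cannot interrupt the run.
def legacy_callback(step_idx, t, latents):
    print(f"step {step_idx}, t={t}")


# New style: receives the pipeline plus a dict of the tensors requested via
# callback_on_step_end_tensor_inputs, and can set pipe._interrupt = True.
def new_callback(pipe, step_idx, t, callback_kwargs):
    if step_idx >= 10:
        pipe._interrupt = True
    return callback_kwargs


# legacy: pipe(..., callback=legacy_callback, callback_steps=1)
# new:    pipe(..., callback_on_step_end=new_callback,
#               callback_on_step_end_tensor_inputs=["latents"])
```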
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@RandomGitUser321 Apologies for the delay on our end, but we would love to come back to this PR. What is currently blocking it? How can we help?