diffusers
Flux2: Tensor tuples can cause issues for checkpointing
Addresses https://github.com/huggingface/diffusers/issues/12776
What does this PR do?
This PR keeps the tuples but moves the step that splits tensors into tuples of tensors inside the transformer blocks, to avoid issues with gradient checkpointing. When a tensor is passed directly, torch.utils.checkpoint.checkpoint() recognizes it as an input and saves it accordingly, instead of running a backward pass through it multiple times.
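A minimal sketch of the pattern, assuming the reentrant variant of torch.utils.checkpoint.checkpoint (use_reentrant=True). The PyTorch documentation states that this variant does not treat tensors inside nested structures such as tuples as participating in autograd. HypotheticalBlock, HypotheticalModel and the (shift, scale, gate) layout are illustrative assumptions, not the actual Flux2 classes in diffusers. The shared modulation is computed once as a single tensor, passed unchanged to every checkpointed block, and split inside the block:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class HypotheticalBlock(nn.Module):
    """Toy block (hypothetical) that receives its modulation as ONE tensor and splits it itself."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        # The split happens inside the (possibly checkpointed) block, so
        # checkpoint() only ever sees top-level tensor arguments.
        shift, scale, gate = mod.chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale) + shift
        return x + gate * self.proj(h)


class HypotheticalModel(nn.Module):
    def __init__(self, dim: int, num_blocks: int = 2):
        super().__init__()
        # One modulation projection shared by every block (assumed layout).
        self.mod_proj = nn.Linear(dim, 3 * dim)
        self.blocks = nn.ModuleList([HypotheticalBlock(dim) for _ in range(num_blocks)])

    def forward(self, x: torch.Tensor, temb: torch.Tensor, use_checkpoint: bool = True) -> torch.Tensor:
        mod = self.mod_proj(temb)  # a single tensor of shape (..., 3 * dim)
        for block in self.blocks:
            if use_checkpoint:
                # mod is a plain tensor argument: the reentrant checkpoint detaches it
                # during recomputation and returns its gradient through the outer graph.
                x = checkpoint(block, x, mod, use_reentrant=True)
            else:
                x = block(x, mod)
        return x


torch.manual_seed(0)
model = HypotheticalModel(dim=8)
x = torch.randn(2, 4, 8, requires_grad=True)
temb = torch.randn(2, 1, 8)

# Backward with checkpointing enabled.
model(x, temb).sum().backward()
grads_ckpt = [p.grad.clone() for p in model.parameters()]

# Reference backward without checkpointing.
model.zero_grad()
model(x, temb, use_checkpoint=False).sum().backward()
grads_ref = [p.grad.clone() for p in model.parameters()]
```

The last lines run the same model with and without checkpointing; with the modulation passed as a single tensor, both runs should yield matching parameter gradients, including for the shared mod_proj that feeds every block. Passing a pre-split tuple of tensors into the checkpointed blocks is the case this PR is meant to avoid.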
This is a draft. If you agree with the approach, I can clean it up. Among other things:
- the type hints are incorrect
- splitting into tuples might no longer be necessary, because the resulting tensors are used immediately afterwards
Who can review?
@yiyixuxu and @asomoza