
Flux2: Tensor tuples can cause issues for checkpointing

Open dxqb opened this issue 1 month ago • 0 comments

Addresses https://github.com/huggingface/diffusers/issues/12776

What does this PR do?

This PR keeps the tuples, but moves the step that splits tensors into tuples of tensors into the transformer blocks, to avoid issues with checkpointing. When a tensor is passed directly, torch.utils.checkpoint() identifies the tensor and saves it accordingly, without running a backward pass through it multiple times.
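A minimal sketch of the idea, not the actual Flux2 code: `ToyBlock`, `mod_proj`, `run`, and all tensor shapes are hypothetical and only illustrate the calling pattern. The caller hands one packed tensor to the checkpoint call, and the block splits it internally. `use_reentrant=False` is assumed here because it is the mode the PyTorch documentation recommends; the check at the end only confirms that gradients match a non-checkpointed run.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block (not the Flux2 code)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
        # The split happens inside the checkpointed region, so the caller
        # only ever passes a plain tensor to checkpoint().
        shift, scale, gate = mod.chunk(3, dim=-1)
        return hidden + gate * self.proj(hidden * (1 + scale) + shift)


torch.manual_seed(0)
dim = 8
block = ToyBlock(dim)
mod_proj = nn.Linear(dim, 3 * dim)  # produces the packed modulation tensor
hidden = torch.randn(2, 4, dim)
temb = torch.randn(2, 1, dim)


def run(use_checkpoint: bool) -> torch.Tensor:
    """Run one forward/backward pass and return the gradient of mod_proj.weight."""
    block.zero_grad()
    mod_proj.zero_grad()
    mod = mod_proj(temb)  # one tensor, not a tuple of tensors
    if use_checkpoint:
        out = checkpoint(block, hidden, mod, use_reentrant=False)
    else:
        out = block(hidden, mod)
    out.sum().backward()
    return mod_proj.weight.grad.clone()


grad_ref = run(use_checkpoint=False)
grad_ckpt = run(use_checkpoint=True)
print(torch.allclose(grad_ref, grad_ckpt))
```

Passing the packed tensor keeps every tensor input visible as a top-level argument of the checkpoint call, which is the property this PR relies on.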

This is a draft. If you agree with this change I can make it nicer. Among other things:

  • the type hints are incorrect
  • the splitting might not be necessary anymore, because the resulting tensors are used immediately afterwards

Who can review?

@yiyixuxu and @asomoza

dxqb, Dec 02 '25 17:12