Flux2: Tensor tuples can cause issues for checkpointing
Describe the bug
The modulations calculated here... https://github.com/huggingface/diffusers/blob/edf36f5128abf3e6ecf92b5145115514363c58e6/src/diffusers/models/transformers/transformer_flux2.py#L716
...return tuples of Tensors: https://github.com/huggingface/diffusers/blob/edf36f5128abf3e6ecf92b5145115514363c58e6/src/diffusers/models/transformers/transformer_flux2.py#L628
These tuples are passed from outside the transformer blocks into the checkpointed transformer blocks. If the tensors inside the tuples require gradients, this can break the backward pass:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
torch's checkpointing doesn't look inside the tuples for tensors; only bare Tensor arguments are identified: https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/utils/checkpoint.py#L252
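For intuition, here is a minimal toy sketch of the failure mode. It is not the diffusers model: the module names are made up, and it uses the reentrant checkpoint path (use_reentrant=True) to keep the example self-contained. A tuple of grad-requiring tensors is computed once outside the blocks and shared across two checkpointed blocks; the second block's backward then re-traverses the already-freed upstream graph.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-ins (hypothetical, not the diffusers modules).
block_lin = nn.Linear(8, 8)   # plays the role of a transformer block
mod_proj = nn.Linear(8, 16)   # plays the role of the modulation projection

x = torch.randn(4, 8, requires_grad=True)
t = torch.randn(4, 8)

# "Modulation" computed once, outside the blocks; shift/scale stay attached
# to mod_proj's autograd graph.
shift, scale = mod_proj(t).chunk(2, dim=-1)
mods = (shift, scale)  # tuple of tensors -> not treated as tensor inputs by checkpoint

def block(h, mods):
    shift, scale = mods
    return block_lin(h * (1 + scale) + shift)

h = x
for _ in range(2):  # two checkpointed blocks sharing the same tuple
    h = checkpoint(block, h, mods, use_reentrant=True)

# Each block's backward recomputes the forward with the *non-detached* tuple
# members and backprops into mod_proj's graph; the second traversal hits the
# already-freed buffers and raises:
# RuntimeError: Trying to backward through the graph a second time ...
h.sum().backward()
```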
Reproduction
Isolated reproduction code is difficult because of the size of the model, but I'll post a draft PR in a minute.
Logs
System Info
torch 2.8, diffusers HEAD
Who can help?
@DN6 @yiyixuxu @sayakpaul
Thanks for the investigation! I wonder why the gradient checkpointing tests didn't catch it. I just ran pytest tests/models/transformers/test_models_transformer_flux2.py -k "gradient_checkpointing" and all three tests passed successfully.
Not sure. Does this test actually run a backward pass? Asking because that would take 128 GB of VRAM at bf16 for parameters and gradients.
It does: https://github.com/huggingface/diffusers/blob/759ea587082aa0e77449952d8f3523f28ddc61f3/tests/models/test_modeling_common.py#L1039
We are using a smaller variant of the model (but representative enough), so the memory requirements aren't that high.
@linoytsaban and I also tried gradient checkpointing in our Flux.2 training script and it didn't cause any issues.
I see a few potential workarounds (the first is sketched after this list):
- Flatten the tuple before passing it to the checkpointed blocks and reconstruct it inside
- Create a custom tensor-like wrapper class that checkpoint can recognize
- Modify the modulation to return a concatenated tensor instead of a tuple
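To make the first option concrete, here is a minimal sketch on the same kind of toy setup as the sketch in the issue description (hypothetical names, not the diffusers code): the tuple is unpacked into bare tensor arguments at the checkpoint call, so checkpoint sees each tensor individually, and the tuple is rebuilt inside the checkpointed function.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Same toy setup as the earlier sketch (hypothetical names).
block_lin = nn.Linear(8, 8)
mod_proj = nn.Linear(8, 16)

x = torch.randn(4, 8, requires_grad=True)
t = torch.randn(4, 8)

shift, scale = mod_proj(t).chunk(2, dim=-1)
mods = (shift, scale)

def block(h, mods):
    shift, scale = mods
    return block_lin(h * (1 + scale) + shift)

def run_block(h, *mod_tensors):
    # Rebuild the tuple inside the checkpointed region; the individual tensors
    # were passed as bare arguments, so checkpoint sees and handles each one.
    return block(h, tuple(mod_tensors))

h = x
for _ in range(2):
    # Flatten the tuple at the call site instead of passing it as one object.
    h = checkpoint(run_block, h, *mods, use_reentrant=True)

h.sum().backward()  # a single backward traverses mod_proj's graph once; no error
```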
Here is a simple draft that works; it needs to be cleaned up, and I'll do that if you agree with the approach: https://github.com/huggingface/diffusers/pull/12777
Nice solution! Moving the tensor splitting inside the checkpointed blocks is elegant - it keeps the tuple structure while avoiding the checkpoint detection issue.
I noticed you mentioned the splitting might not be necessary anymore since they're used immediately. Would flattening the modulation output entirely be simpler, or does keeping the tuple structure have other benefits?
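For comparison, here is roughly what the "keep the modulation as one tensor and split it inside the checkpointed block" direction could look like on the same toy setup as above. This is a hypothetical sketch to illustrate the question, not the actual PR diff.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Same toy setup as in the earlier sketches (hypothetical names).
block_lin = nn.Linear(8, 8)
mod_proj = nn.Linear(8, 16)

x = torch.randn(4, 8, requires_grad=True)
t = torch.randn(4, 8)

mod = mod_proj(t)  # keep the modulation as a single tensor outside the blocks

def block(h, mod):
    # Split only inside the checkpointed region, where `mod` arrived as a bare
    # tensor argument that checkpoint can track.
    shift, scale = mod.chunk(2, dim=-1)
    return block_lin(h * (1 + scale) + shift)

h = x
for _ in range(2):
    h = checkpoint(block, h, mod, use_reentrant=True)

h.sum().backward()  # backward succeeds; mod is handled like any other tensor input
```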
Thanks for the PR! Can you provide a small script to reproduce it? Asking because @sayakpaul doesn't seem to be able to.
Unfortunately, I can't provide an isolated script because of the size of this model: you need offloading and a fused backward pass (to avoid the gradient buffers) to even run a backward, or a LoRA adapter in training that also runs gradients through these layers.
I think this PR is theoretically sound enough that it can be accepted without reproduction code. Or, you could point me to the smaller model variant that @sayakpaul was talking about above, and I'll try reproduction with that.
Here is the init dict we use for our Flux2 tests: https://github.com/huggingface/diffusers/blob/8d415a6f481ff1b26168c046267628419650f930/tests/models/transformers/test_models_transformer_flux2.py#L85
Does that work?