
Flux2: Tensor tuples can cause issues for checkpointing

dxqb opened this issue 1 month ago · 9 comments

Describe the bug

The modulations calculated here... https://github.com/huggingface/diffusers/blob/edf36f5128abf3e6ecf92b5145115514363c58e6/src/diffusers/models/transformers/transformer_flux2.py#L716

...return tuples of Tensors: https://github.com/huggingface/diffusers/blob/edf36f5128abf3e6ecf92b5145115514363c58e6/src/diffusers/models/transformers/transformer_flux2.py#L628

These tuples are passed from outside the transformer blocks into the checkpointed transformer blocks. If the tensors inside the tuples require gradients, this can cause issues for the backward pass:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

torch checkpointing doesn't identify the tensors inside the tuples; only top-level tensor arguments are identified: https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/utils/checkpoint.py#L252
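
For intuition, here is a schematic toy of that failure mode. It is not the Flux2 model or the actual diffusers call path: the module names are made up, and it assumes reentrant checkpointing (use_reentrant=True).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Stand-in for a transformer block that receives a (shift, scale) tuple."""

    def __init__(self, dim=8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, mods):
        shift, scale = mods  # tensors computed outside the checkpointed region
        return self.linear(x) * (1 + scale) + shift


mod_proj = nn.Linear(8, 16)  # produces the modulation outside the blocks
blocks = nn.ModuleList([ToyBlock() for _ in range(2)])

t = torch.randn(1, 8)
x = torch.randn(1, 8, requires_grad=True)

# Tuple of tensors that carry a grad history through mod_proj.
mods = tuple(mod_proj(t).chunk(2, dim=-1))

for block in blocks:
    # Reentrant checkpointing only detaches top-level tensor arguments before
    # recomputation; tensors nested inside the tuple keep their original graph.
    x = checkpoint(block, x, mods, use_reentrant=True)

# During backward, the last block's recomputation backpropagates through
# mod_proj's graph and frees its saved tensors; the earlier block's
# recomputation then traverses the same (now freed) graph, which is expected
# to raise the "backward through the graph a second time" error quoted above.
x.sum().backward()
```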

Reproduction

Isolated reproduction code is difficult because of the size of the model, but I'll post a draft PR in a minute.

Logs


System Info

torch 2.8, diffusers HEAD

Who can help?

@DN6 @yiyixuxu @sayakpaul

dxqb · Dec 02 '25 17:12

Thanks for the investigation! I wonder why the gradient checkpointing tests didn't catch it. I just ran pytest tests/models/transformers/test_models_transformer_flux2.py -k "gradient_checkpointing" and all three tests passed.

sayakpaul · Dec 03 '25 07:12

Not sure. Does this test actually run a backward pass? Asking because that would take 128 GB of VRAM at bf16 for parameters and gradients.

dxqb · Dec 03 '25 09:12

It does: https://github.com/huggingface/diffusers/blob/759ea587082aa0e77449952d8f3523f28ddc61f3/tests/models/test_modeling_common.py#L1039

> Does this test actually run a backward pass? Asking because that would take 128 GB of VRAM at bf16 for parameters and gradients.

We are using a smaller variant of the model (but representative enough), so the memory requirements aren't that high.

@linoytsaban and I also tried gradient checkpointing in our Flux.2 training script and it didn't cause any issues.

sayakpaul · Dec 03 '25 09:12

I see a few potential workarounds:

  1. Flatten the tuple before passing it to the checkpointed blocks and reconstruct it inside (see the sketch after this list)
  2. Create a custom tensor-like wrapper class that checkpoint can recognize
  3. Modify the modulation to return a concatenated tensor instead of a tuple
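
A minimal sketch of option 1, with made-up names rather than code from the repository: unpack the tuple into top-level tensor arguments at the checkpoint call so that checkpointing's tensor handling sees them, and rebuild the tuple inside.

```python
import torch
from torch.utils.checkpoint import checkpoint


# Hypothetical block forward that takes the modulation tensors as separate
# top-level arguments and rebuilds the tuple inside the checkpointed region.
def block_forward(x, mod_shift, mod_scale):
    mods = (mod_shift, mod_scale)
    shift, scale = mods
    return x * (1 + scale) + shift


x = torch.randn(1, 8, requires_grad=True)
mods = tuple(torch.randn(1, 16, requires_grad=True).chunk(2, dim=-1))

# The tuple is flattened at the call site; each tensor crosses the checkpoint
# boundary as a plain argument instead of being hidden inside a container.
out = checkpoint(block_forward, x, *mods, use_reentrant=True)
out.sum().backward()
```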

Aznix07 · Dec 04 '25 12:12

> I see a few potential workarounds:
>
> 1. Flatten the tuple before passing it to the checkpointed blocks and reconstruct it inside
> 2. Create a custom tensor-like wrapper class that checkpoint can recognize
> 3. Modify the modulation to return a concatenated tensor instead of a tuple

Here is a simple draft that works; it needs to be cleaned up, and I'll do that if you agree with the approach: https://github.com/huggingface/diffusers/pull/12777

dxqb · Dec 04 '25 13:12

Nice solution! Moving the tensor splitting inside the checkpointed blocks is elegant: it keeps the tuple structure while avoiding the checkpoint detection issue.

I noticed you mentioned the splitting might not be necessary anymore since the split tensors are used immediately. Would flattening the modulation output entirely be simpler, or does keeping the tuple structure have other benefits?
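
For comparison, passing the unsplit modulation tensor across the checkpoint boundary and chunking it inside the block would look roughly like this (illustrative names only, not the actual PR code):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Illustrative block that receives the unsplit modulation tensor."""

    def __init__(self, dim=8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, mod):
        # The tuple never crosses the checkpoint boundary; the split happens here.
        shift, scale = mod.chunk(2, dim=-1)
        return self.linear(x) * (1 + scale) + shift


block = ToyBlock()
x = torch.randn(1, 8, requires_grad=True)
mod = torch.randn(1, 16, requires_grad=True)  # a single top-level tensor argument

out = checkpoint(block, x, mod, use_reentrant=True)
out.sum().backward()
```

Either way, the tensors that require gradients cross the checkpoint boundary as plain tensor arguments rather than inside a Python container, which sidesteps the detection gap described above.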

Aznix07 · Dec 04 '25 17:12

Thanks for the PR! Can you provide a small script to reproduce it? Asking because @sayakpaul doesn't seem to be able to.

yiyixuxu · Dec 04 '25 20:12

> Thanks for the PR! Can you provide a small script to reproduce it? Asking because @sayakpaul doesn't seem to be able to.

Unfortunately I can't provide an isolated script because of the size of this model. You need offloading and a fused backward pass (to avoid allocating gradient buffers) to even run a backward, or a LoRA adapter in training that also runs gradients through these layers.

I think this PR is theoretically sound enough that it can be accepted without reproduction code. Or you could point me to the smaller model variant that @sayakpaul was talking about above, and I'll try to reproduce with that.

dxqb · Dec 04 '25 21:12

Here is the init dict we use for our Flux2 tests: https://github.com/huggingface/diffusers/blob/8d415a6f481ff1b26168c046267628419650f930/tests/models/transformers/test_models_transformer_flux2.py#L85

Does that work?

sayakpaul · Dec 04 '25 22:12