
Complex numbers in Qwen-Image/Z-Image pipelines incompatible with torch.compile

dxqb opened this issue 1 month ago • 10 comments

Describe the bug

Note: This might be something for the MVP program https://github.com/huggingface/diffusers/issues/12635 if there's anyone who already has a deep understanding of rotary embeddings and complex numbers. I don't.

The Qwen image pipeline calls apply_rotary_emb_qwen (https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/src/diffusers/models/transformers/transformer_qwenimage.py#L96) with use_real=False.

The function therefore operates on complex numbers. If compiled, torch.compile warns about this:

venv/lib/python3.12/site-packages/torch/_inductor/lowering.py:1890: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
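
For context, a minimal sketch of what this complex code path does (simplified and illustrative, not the exact diffusers implementation):

```python
import torch

def apply_rotary_complex(x, freqs_cis):
    # View the last dimension of x as (real, imag) pairs and rotate each pair
    # by multiplying with the complex frequencies e^(i*theta). Inductor has no
    # codegen for the complex multiply and falls back to eager execution.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    out = torch.view_as_real(x_c * freqs_cis)  # complex multiply = 2D rotation
    return out.flatten(-2).type_as(x)
```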

Performance being worse than eager isn't a big deal; this is not a performance-critical part of the model. However, due to a subtle torch.compile bug, it leads to random compile failures: https://github.com/pytorch/pytorch/issues/163876

Can the code path with real numbers be used instead?

Reproduction

I cannot provide reproduction code because the failure is random: it shows up mostly when a kernel is recompiled, but not consistently. Multiple users are affected, though. It can be worked around by putting a torch.compiler.disable decorator around the function, but I don't like this solution because then you cannot compile with fullgraph=True anymore.
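
For reference, one way that workaround could look, as a monkeypatch variant of the decorator approach (a sketch under the assumption that patching the module-level name is sufficient; it reintroduces a graph break):

```python
import torch
from diffusers.models.transformers import transformer_qwenimage as tq

# Skip Dynamo/Inductor for the complex-number RoPE helper. This avoids the
# inductor complex-op path but inserts a graph break, so fullgraph=True breaks.
tq.apply_rotary_emb_qwen = torch.compiler.disable(tq.apply_rotary_emb_qwen)
```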

Logs

packed_predicted_flow = model.transformer(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/
modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/src/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 629, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/__init__.py", line 2380, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
    return aot_autograd(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1199, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1140, in load
    compiled_fn = dispatch_and_compile()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 1604, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
    return self.compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2250, in fw_compiler_base
    return inner_compile(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 745, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 896, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1456, in codegen_and_compile
    compiled_module = graph.compile_to_module()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2293, in compile_to_module
    return self._compile_to_module()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2299, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2238, in codegen
    self.scheduler.codegen()
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4598, in codegen
    else self._codegen(self.nodes)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4750, in _codegen
    self.get_backend(device).codegen_node(node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 107, in codegen_node
    return self._triton_scheduling.codegen_node(node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1363, in codegen_node
    coalesce_analysis = analyze_memory_coalescing(node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 650, in analyze_memory_coalescing
    norm_read_writes = extract_normalized_read_writes(fused_node)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 482, in extract_normalized_read_writes
    if any(
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 483, in <genexpr>
    (isinstance(var, sympy.Expr) and not var.is_constant())
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 724, in is_constant
    b = expr._random(None, -1, 0, 1, 0)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 562, in _random
    nmag = abs(self.evalf(2, subs=reps))
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1654, in evalf
    result = evalf(self, prec + 4, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
    r = rf(x, prec, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in evalf_add
    terms = [evalf(arg, prec + 10, options) for arg in v.args]
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in <listcomp>
    terms = [evalf(arg, prec + 10, options) for arg in v.args]
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
    r = rf(x, prec, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 650, in evalf_mul
    result = evalf(arg, prec, options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1493, in evalf
    x = x.subs(evalf_subs(prec, options['subs']))
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1171, in subs
    rv = rv._subs(old, new, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1285, in _subs
    rv = fallback(self, old, new)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1262, in fallback
    rv = self.func(*args)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 450, in __new__
    return cls._new_(*args, **options)  # type: ignore
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 472, in _new_
    result = super().__new__(cls, *args, **options)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
    retval = cfunc(*args, **kwargs)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 309, in __new__
    evaluated = cls.eval(*args)
  File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/utils/_sympy/functions.py", line 488, in eval
    assert p >= 0, p
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: -1470286036225387/1000000000000000

System Info

Various systems, with torch 2.8.

Who can help?

@DN6 @yiyixuxu @sayakpaul

dxqb · Nov 15 '25 17:11

Yes, this is a bit of a known issue. We tried a few things in the past, and for this model it so happens that regional compilation provides performance gains similar to full model compilation (even with fullgraph=True). Regional compilation avoids compiling the RoPE block, so we sidestep this issue completely without compromising the speed benefits at all.

Docs: https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0#regional-compilation

Ccing @anijain2305 and @StrongerXI as well.

sayakpaul · Nov 15 '25 17:11

I should have mentioned: this is with only the transformer blocks compiled (manually, via transformer.transformer_blocks[i].compile(fullgraph=True)).

I agree that fullgraph=False gives basically the same performance, but only if your code is already free of unintended graph breaks. With fullgraph=False, you can unintentionally introduce graph breaks during development without noticing.
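
For what it's worth, one way to surface accidental breaks during development without hard-failing via fullgraph=True, sketched with a toy function:

```python
import torch

def fn(x):
    y = x * 2
    torch._dynamo.graph_break()  # stand-in for an accidental graph break
    return y + 1

# explain() reports how many graphs Dynamo produced and why it broke.
explanation = torch._dynamo.explain(fn)(torch.randn(8))
print(explanation.graph_break_count)
print(explanation.break_reasons)
```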

dxqb · Nov 15 '25 17:11

I am saying that if you do pipe.transformer.compile_repeated_blocks(fullgraph=True), it should work as expected, providing speedup benefits similar to full model compilation.

This is how we test it: https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/tests/models/test_modeling_common.py#L2076

https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/tests/models/transformers/test_models_transformer_qwenimage.py#L95

What am I missing?

sayakpaul · Nov 15 '25 17:11

> I am saying that if you do pipe.transformer.compile_repeated_blocks(fullgraph=True), it should work as expected, providing speedup benefits similar to full model compilation.

Had a closer look now. compile_repeated_blocks compiles only the transformer blocks (which I was also doing), but the complex numbers are used from inside the transformer block.

The embedding is created outside the transformer block: https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/src/diffusers/models/transformers/transformer_qwenimage.py#L633

but is used inside it: https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/src/diffusers/models/transformers/transformer_qwenimage.py#L320

So I don't think that's a workaround for this issue.
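
For concreteness, a minimal self-contained sketch of that structure (names, shapes, and frequencies are made up):

```python
import torch

# The complex rotary embedding is computed once, outside the blocks...
freqs_cis = torch.polar(torch.ones(16), torch.linspace(0, 3.14, 16))  # complex64

@torch.compile(fullgraph=True)
def block(x, freqs_cis):
    # ...but consumed inside each compiled block, so the complex tensor still
    # enters the compiled graph even when only the blocks are compiled.
    x_c = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2)

out = block(torch.randn(2, 32), freqs_cis)
```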

dxqb · Nov 16 '25 10:11

Yeah you're right. I did:

```python
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.compile_repeated_blocks(fullgraph=True)

with (
    torch._inductor.utils.fresh_inductor_cache(),
    torch._dynamo.config.patch(error_on_recompile=True),
):
    _ = pipe("a dog", negative_prompt=" ", true_cfg_scale=5.0, num_inference_steps=5)
```

> Can the code path with real numbers be used instead?

What did you have in mind when you suggested this?

sayakpaul · Nov 16 '25 10:11

> Can the code path with real numbers be used instead?
>
> What did you have in mind when you suggested this?

I don't have a good theoretical background on this, but I suspect the same operation can be implemented without complex numbers: convert the embedding to real numbers and then call apply_rotary_emb_qwen with use_real=True?
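
For illustration, the complex multiply expands into real arithmetic like this (a sketch only; whether this pair layout matches what use_real=True expects in diffusers is exactly the open question):

```python
import torch

def apply_rotary_real(x, cos, sin):
    # (a + ib) * (cos + i*sin) = (a*cos - b*sin) + i*(a*sin + b*cos),
    # i.e. the same per-pair rotation as the complex path, in real tensors.
    x_ = x.float().reshape(*x.shape[:-1], -1, 2)
    a, b = x_[..., 0], x_[..., 1]
    out = torch.stack((a * cos - b * sin, a * sin + b * cos), dim=-1)
    return out.flatten(-2).type_as(x)

# cos/sin could be precomputed once from the complex frequencies, e.g.:
# cos, sin = freqs_cis.real, freqs_cis.imag
```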

dxqb · Nov 16 '25 10:11

I don't think that's possible if the training didn't do something similar, but maybe I haven't fully thought it through.

sayakpaul · Nov 16 '25 11:11

It's not doable without retraining for a bit; that's out of distribution.

bghira · Nov 18 '25 13:11

Using @torch.compiler.disable isn't a workaround either, even with fullgraph=False. The issue still happens, I guess simply because complex numbers are passed through a compiled region, even if all calculations on them happen outside of the compiled region.
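
A minimal sketch of that pattern (placeholder names and shapes):

```python
import torch

@torch.compiler.disable
def rotate(x, freqs_cis):
    # All complex arithmetic is excluded from compilation...
    x_c = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2)

@torch.compile
def block(x, freqs_cis):
    # ...yet the complex64 freqs_cis is still an input crossing into the
    # compiled region, which per the report above seems to be enough to
    # trigger the random failure.
    return rotate(x, freqs_cis) * 2.0

freqs_cis = torch.polar(torch.ones(16), torch.linspace(0, 3.14, 16))
out = block(torch.randn(2, 32), freqs_cis)
```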

dxqb · Nov 18 '25 15:11

Z-Image is also affected.

dxqb · Dec 05 '25 20:12