Complex numbers in Qwen-Image/Z-Image pipelines incompatible with torch.compile
Describe the bug
Note: This might be something for the MVP program https://github.com/huggingface/diffusers/issues/12635 if there's anyone who already has a deep understanding of rotary embeddings and complex numbers. I don't.
The Qwen-Image pipeline calls https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/src/diffusers/models/transformers/transformer_qwenimage.py#L96 with use_real=False.
The function therefore operates on complex numbers.
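For context, with use_real=False the rotation is applied through torch.view_as_complex and a complex multiply, roughly like this (a simplified sketch of the pattern, not the exact diffusers code):

import torch

def apply_rotary_emb_complex(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: real tensor of shape (..., seq, dim); freqs_cis: complex tensor of
    # shape (seq, dim // 2), typically built with torch.polar. The complex
    # view and multiply below are the operators inductor cannot codegen.
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rot = x_c * freqs_cis  # complex multiply == per-pair 2D rotation
    return torch.view_as_real(x_rot).flatten(-2).type_as(x)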
If compiled, torch.compile warns about this: venv/lib/python3.12/site-packages/torch/_inductor/lowering.py:1890: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
Performance being worse than eager isn't a big deal, since this is not a performance-critical part of the model. However, due to a subtle torch.compile bug, it leads to random compile failures: https://github.com/pytorch/pytorch/issues/163876
Can the code path with real numbers be used instead?
Reproduction
I cannot provide reproduction code because the failure is random: it shows up mostly when a kernel is recompiled, but not consistently even then.
Multiple users are affected, though. It can be worked around by putting a torch.compiler.disable decorator around the function, but I don't like this solution because then you can no longer compile with fullgraph=True.
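For reference, the decorator workaround looks roughly like this (my own sketch of the monkey-patch, not something diffusers ships):

import torch
from diffusers.models.transformers import transformer_qwenimage

# Replace the module-level function with a compile-disabled wrapper. This keeps
# the complex ops out of the compiled graph, but any region compiled with
# fullgraph=True will now fail on the resulting graph break.
transformer_qwenimage.apply_rotary_emb_qwen = torch.compiler.disable(
    transformer_qwenimage.apply_rotary_emb_qwen
)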
Logs
packed_predicted_flow = model.transformer(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/
modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
File "/home/linux/KI/OneTrainer/src/diffusers/src/diffusers/models/transformers/transformer_qwenimage.py", line 629, in forward
encoder_hidden_states, hidden_states = block(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/__init__.py", line 2380, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2418, in compile_fx
return aot_autograd(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 109, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1199, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1140, in load
compiled_fn = dispatch_and_compile()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1184, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 1604, in aot_dispatch_autograd
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
return self.compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2250, in fw_compiler_base
return inner_compile(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 745, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 896, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1456, in codegen_and_compile
compiled_module = graph.compile_to_module()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2293, in compile_to_module
return self._compile_to_module()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2299, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2238, in codegen
self.scheduler.codegen()
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4598, in codegen
else self._codegen(self.nodes)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 4750, in _codegen
self.get_backend(device).codegen_node(node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 107, in codegen_node
return self._triton_scheduling.codegen_node(node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/codegen/simd.py", line 1363, in codegen_node
coalesce_analysis = analyze_memory_coalescing(node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 650, in analyze_memory_coalescing
norm_read_writes = extract_normalized_read_writes(fused_node)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 482, in extract_normalized_read_writes
if any(
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/_inductor/tiling_utils.py", line 483, in <genexpr>
(isinstance(var, sympy.Expr) and not var.is_constant())
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 724, in is_constant
b = expr._random(None, -1, 0, 1, 0)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/expr.py", line 562, in _random
nmag = abs(self.evalf(2, subs=reps))
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1654, in evalf
result = evalf(self, prec + 4, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
r = rf(x, prec, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in evalf_add
terms = [evalf(arg, prec + 10, options) for arg in v.args]
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 602, in <listcomp>
terms = [evalf(arg, prec + 10, options) for arg in v.args]
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1489, in evalf
r = rf(x, prec, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 650, in evalf_mul
result = evalf(arg, prec, options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/evalf.py", line 1493, in evalf
x = x.subs(evalf_subs(prec, options['subs']))
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1171, in subs
rv = rv._subs(old, new, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1285, in _subs
rv = fallback(self, old, new)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/basic.py", line 1262, in fallback
rv = self.func(*args)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 450, in __new__
return cls._new_(*args, **options) # type: ignore
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 472, in _new_
result = super().__new__(cls, *args, **options)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/cache.py", line 72, in wrapper
retval = cfunc(*args, **kwargs)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/sympy/core/function.py", line 309, in __new__
evaluated = cls.eval(*args)
File "/home/linux/KI/OneTrainer/conda_env/lib/python3.10/site-packages/torch/utils/_sympy/functions.py", line 488, in eval
assert p >= 0, p
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: -1470286036225387/1000000000000000
System Info
Various systems, all with torch 2.8.
Who can help?
@DN6 @yiyixuxu @sayakpaul
Yes, this is a bit of a known issue. We did try a few things in the past, and for this model it so happens that regional compilation provides performance gains similar to full model compilation (even with fullgraph=True). Regional compilation avoids compiling the RoPE block, so we sidestep this issue completely without compromising the speed benefits at all.
Docs: https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0#regional-compilation
Ccing @anijain2305 and @StrongerXI as well.
I should have mentioned, this is with only compiling the transformer blocks (manually, via transformer.transformer_blocks[i].compile(fullgraph=True)).
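Spelled out, that manual approach is just (a sketch, assuming a loaded pipeline pipe):

# Compile each repeated block individually; the outer forward stays uncompiled.
for block in pipe.transformer.transformer_blocks:
    block.compile(fullgraph=True)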
I agree that fullgraph=False gives basically the same performance, but only if your production code is already compilable everywhere except the parts you deliberately leave out. With fullgraph=False, you can unintentionally introduce graph breaks during development without noticing.
I am saying if you do pipe.transformer.compile_repeated_blocks(fullgraph=True), it should work as expected, providing speedup benefits similar to full model compilation.
This is how we test it: https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/tests/models/test_modeling_common.py#L2076
https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/tests/models/transformers/test_models_transformer_qwenimage.py#L95
What am I missing?
Had a closer look now. compile_repeated_blocks compiles only the transformer blocks (which is what I was doing too), but the complex numbers are used from inside the transformer blocks.
The embedding is created outside the transformer block: https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/src/diffusers/models/transformers/transformer_qwenimage.py#L633
but is used inside it: https://github.com/huggingface/diffusers/blob/01a56927f1603f1e89d1e5ada74d2aa75da2d46b/src/diffusers/models/transformers/transformer_qwenimage.py#L320
So I don't think that's a workaround for this issue.
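A minimal self-contained illustration of that boundary (my own schematic, not the actual diffusers code):

import torch

class Block(torch.nn.Module):
    def forward(self, x, freqs_cis):
        # The complex math happens *inside* the (compiled) block.
        x_c = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(x_c * freqs_cis).flatten(-2)

block = Block()
block.compile(fullgraph=True)  # analogous to compile_repeated_blocks

x = torch.randn(2, 8, 16)
# The complex rotary tensor is created *outside* any compiled region ...
freqs_cis = torch.polar(torch.ones(8, 8), torch.randn(8, 8))
# ... but crosses the compile boundary as an input, so inductor still has to
# deal with complex operators when it compiles the block.
out = block(x, freqs_cis)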
Yeah you're right. I did:
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.compile_repeated_blocks(fullgraph=True)

with (
    torch._inductor.utils.fresh_inductor_cache(),
    torch._dynamo.config.patch(error_on_recompile=True),
):
    _ = pipe("a dog", negative_prompt=" ", true_cfg_scale=5.0, num_inference_steps=5)
Can the code path with real numbers be used instead?
What did you have in mind when you suggested this?
I don't have a good theoretical background on this, but I suspect the same thing can be implemented without complex numbers: convert the embedding to real numbers and then call apply_rotary_emb_qwen with use_real=True?
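For what it's worth, the complex multiply itself has an exact real-valued rewrite over the same interleaved pairing. A small self-contained sketch of my own (whether diffusers' use_real=True path uses this exact pairing convention is a separate question):

import torch

x = torch.randn(2, 8, 16)
freqs_cis = torch.polar(torch.ones(8, 8), torch.randn(8, 8))

# Complex path (what use_real=False effectively computes).
x_c = torch.view_as_complex(x.reshape(2, 8, 8, 2))
ref = torch.view_as_real(x_c * freqs_cis).flatten(-2)

# Real path: the same rotation written with cos/sin only.
cos, sin = freqs_cis.real, freqs_cis.imag
x1, x2 = x[..., 0::2], x[..., 1::2]  # interleaved even/odd pairs
out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

torch.testing.assert_close(out, ref)  # identical up to float rounding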
I don't think that's possible if training didn't do something similar, but maybe I haven't fully thought that through either.
It's not doable without retraining for a bit; that's out of distribution.
Using @torch.compiler.disable isn't a workaround either, even with fullgraph=False. The issue still happens; I guess it's caused simply by passing complex numbers through a compiled region, even if all the calculations happen outside the compiled region.
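Schematically (my own sketch, not a reproducer, since the failure is random): even with the complex math excluded via the decorator, the complex tensor still enters the compiled region as an input:

import torch

@torch.compiler.disable
def rope(x, freqs_cis):
    # All complex math lives here, excluded from compilation.
    x_c = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    return torch.view_as_real(x_c * freqs_cis).flatten(-2)

@torch.compile  # fullgraph=False; a graph break occurs at rope()
def layer(x, freqs_cis):
    x = x * 2.0                # compiled
    return rope(x, freqs_cis)  # the complex tensor still crosses the boundary

out = layer(torch.randn(2, 8, 16), torch.polar(torch.ones(8, 8), torch.randn(8, 8)))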
Z-Image is also affected.