Error when training the proposed method
I followed the steps in the repo to set up the conda environment and then ran distill.sh on an A100. It raises the following error:
Traceback (most recent call last):
[rank0]: File "/root/CLEAR/distill.py", line 1242, in <module>
[rank0]: main(args)
[rank0]: File "/root/CLEAR/distill.py", line 1025, in main
[rank0]: teacher_pred = transformer_teacher(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 522, in forward
[rank0]: encoder_hidden_states, hidden_states = block(
[rank0]: ^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 180, in forward
[rank0]: attention_outputs = self.attn(
[rank0]: ^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 588, in forward
[rank0]: return self.processor(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/root/CLEAR/attention_processor.py", line 84, in __call__
[rank0]: query = apply_rotary_emb(query, image_rotary_emb)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 1208, in apply_rotary_emb
[rank0]: out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
[rank0]: ~~~~~~~~~~^~~~~
[rank0]: RuntimeError: The size of tensor a (4608) must match the size of tensor b (16896) at non-singleton dimension 2
Here is my environment info:
# Name Version Build Channel
_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
accelerate 1.3.0 pypi_0 pypi
annotated-types 0.7.0 pypi_0 pypi
bzip2 1.0.8 h5eee18b_6
ca-certificates 2024.12.31 h06a4308_0
certifi 2025.1.31 pypi_0 pypi
charset-normalizer 3.4.1 pypi_0 pypi
click 8.1.8 pypi_0 pypi
contourpy 1.3.1 pypi_0 pypi
cycler 0.12.1 pypi_0 pypi
deepspeed 0.16.3 pypi_0 pypi
diffusers 0.32.2 pypi_0 pypi
docker-pycreds 0.4.0 pypi_0 pypi
einops 0.8.1 pypi_0 pypi
expat 2.6.4 h6a678d5_0
filelock 3.17.0 pypi_0 pypi
fonttools 4.56.0 pypi_0 pypi
fsspec 2025.2.0 pypi_0 pypi
gitdb 4.0.12 pypi_0 pypi
gitpython 3.1.44 pypi_0 pypi
hjson 3.1.0 pypi_0 pypi
huggingface-hub 0.28.1 pypi_0 pypi
idna 3.10 pypi_0 pypi
importlib-metadata 8.6.1 pypi_0 pypi
jinja2 3.1.5 pypi_0 pypi
kiwisolver 1.4.8 pypi_0 pypi
ld_impl_linux-64 2.40 h12ee557_0
libffi 3.4.4 h6a678d5_1
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
libuuid 1.41.5 h5eee18b_0
markupsafe 3.0.2 pypi_0 pypi
matplotlib 3.10.0 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
msgpack 1.1.0 pypi_0 pypi
ncurses 6.4 h6a678d5_0
networkx 3.4.2 pypi_0 pypi
ninja 1.11.1.3 pypi_0 pypi
numpy 2.1.1 pypi_0 pypi
nvidia-cublas-cu12 12.4.5.8 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.4.127 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.4.127 pypi_0 pypi
nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
nvidia-cufft-cu12 11.2.1.3 pypi_0 pypi
nvidia-curand-cu12 10.3.5.147 pypi_0 pypi
nvidia-cusolver-cu12 11.6.1.9 pypi_0 pypi
nvidia-cusparse-cu12 12.3.1.170 pypi_0 pypi
nvidia-cusparselt-cu12 0.6.2 pypi_0 pypi
nvidia-nccl-cu12 2.21.5 pypi_0 pypi
nvidia-nvjitlink-cu12 12.4.127 pypi_0 pypi
nvidia-nvtx-cu12 12.4.127 pypi_0 pypi
opencv-python 4.11.0.86 pypi_0 pypi
opencv-python-headless 4.11.0.86 pypi_0 pypi
openssl 3.0.15 h5eee18b_0
packaging 24.2 pypi_0 pypi
pandas 2.2.3 pypi_0 pypi
pillow 11.1.0 pypi_0 pypi
pip 25.0 py312h06a4308_0
platformdirs 4.3.6 pypi_0 pypi
prodigyopt 1.1.2 pypi_0 pypi
protobuf 5.29.3 pypi_0 pypi
psutil 7.0.0 pypi_0 pypi
py-cpuinfo 9.0.0 pypi_0 pypi
pydantic 2.10.6 pypi_0 pypi
pydantic-core 2.27.2 pypi_0 pypi
pyparsing 3.2.1 pypi_0 pypi
python 3.12.9 h5148396_0
python-dateutil 2.9.0.post0 pypi_0 pypi
pytz 2025.1 pypi_0 pypi
pyyaml 6.0.2 pypi_0 pypi
readline 8.2 h5eee18b_0
regex 2024.11.6 pypi_0 pypi
requests 2.32.3 pypi_0 pypi
safetensors 0.5.2 pypi_0 pypi
scipy 1.15.1 pypi_0 pypi
seaborn 0.13.2 pypi_0 pypi
sentencepiece 0.2.0 pypi_0 pypi
sentry-sdk 2.21.0 pypi_0 pypi
setproctitle 1.3.4 pypi_0 pypi
setuptools 75.8.0 py312h06a4308_0
six 1.17.0 pypi_0 pypi
smmap 5.0.2 pypi_0 pypi
sqlite 3.45.3 h5eee18b_0
sympy 1.13.1 pypi_0 pypi
tk 8.6.14 h39e8969_0
tokenizers 0.21.0 pypi_0 pypi
torch 2.6.0 pypi_0 pypi
torchvision 0.21.0 pypi_0 pypi
tqdm 4.67.1 pypi_0 pypi
transformers 4.48.3 pypi_0 pypi
triton 3.2.0 pypi_0 pypi
typing-extensions 4.12.2 pypi_0 pypi
tzdata 2025.1 pypi_0 pypi
ultralytics 8.3.75 pypi_0 pypi
ultralytics-thop 2.0.14 pypi_0 pypi
urllib3 2.3.0 pypi_0 pypi
wandb 0.19.6 pypi_0 pypi
wheel 0.45.1 py312h06a4308_0
xz 5.6.4 h5eee18b_1
zipp 3.21.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_1
Hi,
I suspect the problem is the diffusers version. We used 0.31.0 in our experiments.
Hi @Huage001
Thanks for the quick reply. I downgraded diffusers as you suggested, but it now raises another error:
Traceback (most recent call last):
[rank0]: File "/root/xxx/CLEAR/distill.py", line 1242, in <module>
[rank0]: main(args)
[rank0]: File "/root/xxx/CLEAR/distill.py", line 1076, in main
[rank0]: accelerator.backward(loss)
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/accelerate/accelerator.py", line 2238, in backward
[rank0]: self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 261, in backward
[rank0]: self.engine.backward(loss, **kwargs)
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2053, in backward
[rank0]: self.optimizer.backward(loss, retain_graph=retain_graph)
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2062, in backward
[rank0]: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank0]: scaled_loss.backward(retain_graph=retain_graph)
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_tensor.py", line 626, in backward
[rank0]: torch.autograd.backward(
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]: _engine_run_backward(
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/autograd/function.py", line 307, in apply
[rank0]: return user_fn(self, *args)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1710, in backward
[rank0]: return impl_fn()
[rank0]: ^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1700, in impl_fn
[rank0]: out = CompiledFunction._backward_impl(ctx, all_args)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2037, in _backward_impl
[rank0]: CompiledFunction.compiled_bw = aot_config.bw_compiler(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 489, in __call__
[rank0]: return self.compiler_fn(gm, example_inputs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 54, in _wrapped_bw_compiler
[rank0]: return disable(disable(bw_compiler_fn)(*args, **kwargs))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]: return function(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1808, in bw_compiler
[rank0]: return inner_compile(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 569, in compile_fx_inner
[rank0]: return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
[rank0]: inner_compiled_fn = compiler_fn(gm, example_inputs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 675, in _compile_fx_inner
[rank0]: mb_compiled_graph = fx_codegen_and_compile(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
[rank0]: return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 979, in codegen_and_compile
[rank0]: graph.run(*example_inputs)
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 855, in run
[rank0]: return super().run(*args)
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/fx/interpreter.py", line 167, in run
[rank0]: self.env[node] = self.run_node(node)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1496, in run_node
[rank0]: result = super().run_node(n)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/fx/interpreter.py", line 230, in run_node
[rank0]: return getattr(self, n.op)(n.target, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1143, in call_function
[rank0]: raise LoweringException(e, target, args, kwargs).with_traceback(
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1133, in call_function
[rank0]: out = lowerings[target](*args, **kwargs) # type: ignore[index]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 409, in wrapped
[rank0]: out = decomp_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py", line 2361, in flex_attention_backward
[rank0]: broadcasted_grad_key = autotune_select_algorithm(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1909, in autotune_select_algorithm
[rank0]: return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1379, in __call__
[rank0]: raise NoValidChoicesError(
[rank0]: torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
[rank0]: target: flex_attention_backward
[rank0]: args[0]: TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]: ))
[rank0]: args[1]: TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]: ))
[rank0]: args[2]: TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_3', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]: ))
[rank0]: args[3]: TensorBox(StorageBox(
[rank0]: InputBuffer(name='getitem', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]: ))
[rank0]: args[4]: TensorBox(StorageBox(
[rank0]: DonatedBuffer(name='getitem_1', layout=FixedLayout('cuda:0', torch.float32, size=[2, 24, 4608], stride=[110592, 4608, 1]))
[rank0]: ))
[rank0]: args[5]: TensorBox(StorageBox(
[rank0]: InputBuffer(name='tangents_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 589824, 128, 1]))
[rank0]: ))
[rank0]: args[6]: TensorBox(StorageBox(
[rank0]: Pointwise(
[rank0]: 'cuda',
[rank0]: torch.float32,
[rank0]: def inner_fn(index):
[rank0]: i0, i1, i2 = index
[rank0]: tmp0 = ops.constant(0, torch.float32)
[rank0]: return tmp0
[rank0]: ,
[rank0]: ranges=[2, 24, 4608],
[rank0]: origin_node=full_default,
[rank0]: origins=OrderedSet([full_default])
[rank0]: )
[rank0]: ))
[rank0]: args[7]: Subgraph(name='fw_graph0', graph_module=<lambda>(), graph=None)
[rank0]: args[8]: Subgraph(name='joint_graph0', graph_module=<lambda>(), graph=None)
[rank0]: args[9]: (4608, 4608, TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_5', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]: )), TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_4', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]: )), TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_6', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]: )), TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_7', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]: )), TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_8', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]: )), TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_9', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]: )), TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_10', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]: )), TensorBox(StorageBox(
[rank0]: InputBuffer(name='primals_11', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]: )), 128, 128, Subgraph(name='mask_graph0', graph_module=<lambda>(), graph=None))
[rank0]: args[10]: 0.08838834764831845
[rank0]: args[11]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}
[rank0]: args[12]: ()
[rank0]: args[13]: ()
A little strange.
From the error message, the problem is in flex attention. Could you please try torch==2.5.0? That is the version we ran our experiments on. 2.6.0 may have introduced some incompatible changes.
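To rule out a version mismatch quickly, something like the following could be run inside the conda environment. It is only a sketch: the two pinned versions are the ones mentioned in this thread (diffusers 0.31.0 and torch 2.5.0), and the helper name `check_versions` is made up for illustration.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported in this thread as the ones used for the experiments.
EXPECTED = {"diffusers": "0.31.0", "torch": "2.5.0"}


def check_versions(expected=EXPECTED):
    """Return {package: (installed_or_None, wanted, matches)}."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        # Some torch wheels carry a local suffix such as "+cu124"; ignore it.
        base = installed.split("+")[0] if installed else None
        report[name] = (installed, wanted, base == wanted)
    return report


if __name__ == "__main__":
    for name, (installed, wanted, ok) in check_versions().items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name}: installed={installed} expected={wanted} [{status}]")
```

Any line marked MISMATCH shows a package that differs from the versions named above.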
Thanks, the issue above does seem to be caused by torch 2.6.0. But after switching to diffusers==0.31.0 and torch==2.5.0, I get the following error when running distill.sh:
[rank3]: File "/root/xxx/projects/GSPN/t2i/advanced/attention_processor.py", line 344, in __call__
[rank3]: hidden_states = self.flex_attn(query, key, value, scale=attention_scale)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank3]: return fn(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1263, in __call__
[rank3]: return hijacked_callback(
[rank3]: ^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
[rank3]: result = self._inner_convert(
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
[rank3]: return _compile(
[rank3]: ^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
[rank3]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
[rank3]: return _compile_inner(code, one_graph, hooks, transform)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
[rank3]: return function(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
[rank3]: out_code = transform_code_object(code, transform)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
[rank3]: transformations(instructions, code_options)
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
[rank3]: return fn(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
[rank3]: tracer.run()
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
[rank3]: super().run()
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank3]: while self.step():
[rank3]: ^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank3]: self.dispatch_table[inst.opcode](self, inst)
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
[rank3]: self._return(inst)
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
[rank3]: self.output.compile_subgraph(
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
[rank3]: self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
[rank3]: compiled_fn = self.call_user_compiler(gm)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
[rank3]: return self._call_user_compiler(gm)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
[rank3]: raise BackendCompilerFailed(self.compiler_fn, e) from e
[rank3]: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
[rank3]: NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph. Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.
[rank3]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
Any idea?
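The error message itself names one thing to try: setting `torch._dynamo.config.optimize_ddp=False`. Below is a minimal sketch of doing that. It assumes the call goes near the top of the training script, before the model is compiled or wrapped for distributed training. The helper name `disable_ddp_optimizer` is hypothetical, and this has not been confirmed to fix this particular run. As the message warns, turning the DDP optimizer off may cost some performance.

```python
def disable_ddp_optimizer():
    """Turn off Dynamo's DDPOptimizer; return True if the flag was set."""
    try:
        import torch._dynamo
    except ImportError:
        # torch (or its Dynamo module) is not available in this environment.
        return False
    # The setting named in the NotImplementedError message above.
    torch._dynamo.config.optimize_ddp = False
    return True


if __name__ == "__main__":
    print("DDP optimizer disabled:", disable_ddp_optimizer())
```

The flag is process-wide, so under a multi-GPU launch it would need to be set in every rank, which happens naturally if it sits at module level in the script each rank executes.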