
Error when training the proposed method

Open syf-fgnb opened this issue 10 months ago • 4 comments

I followed the steps shown in the repo to set up the conda environment and then ran distill.sh on an A100. It raised the following error:

Traceback (most recent call last): 
[rank0]:   File "/root/CLEAR/distill.py", line 1242, in <module>
[rank0]:     main(args)
[rank0]:   File "/root/CLEAR/distill.py", line 1025, in main
[rank0]:     teacher_pred = transformer_teacher(
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 522, in forward
[rank0]:     encoder_hidden_states, hidden_states = block(
[rank0]:                                            ^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 180, in forward
[rank0]:     attention_outputs = self.attn(
[rank0]:                         ^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 588, in forward
[rank0]:     return self.processor(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/root/CLEAR/attention_processor.py", line 84, in __call__
[rank0]:     query = apply_rotary_emb(query, image_rotary_emb)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/embeddings.py", line 1208, in apply_rotary_emb
[rank0]:     out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
[rank0]:            ~~~~~~~~~~^~~~~
[rank0]: RuntimeError: The size of tensor a (4608) must match the size of tensor b (16896) at non-singleton dimension 2
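
For context, the two mismatched sizes look like FLUX joint sequence lengths at two different resolutions. A rough sanity check on my part, assuming 512 T5 text tokens and one latent token per 16x16-pixel patch:

# My assumptions: 512 text tokens, one image token per 16x16 pixels
# (8x VAE downsampling followed by 2x2 latent packing)
text_tokens = 512
def joint_seq_len(resolution_px):
    return (resolution_px // 16) ** 2 + text_tokens
print(joint_seq_len(1024))  # 4608  -> matches tensor a (the query)
print(joint_seq_len(2048))  # 16896 -> matches tensor b (the rotary embedding)

So the rotary embeddings appear to be built for about 4x as many image tokens as the query actually has.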

Here is my environment info:


# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             5.1                       1_gnu
accelerate                1.3.0                    pypi_0    pypi
annotated-types           0.7.0                    pypi_0    pypi
bzip2                     1.0.8                h5eee18b_6
ca-certificates           2024.12.31           h06a4308_0
certifi                   2025.1.31                pypi_0    pypi
charset-normalizer        3.4.1                    pypi_0    pypi
click                     8.1.8                    pypi_0    pypi
contourpy                 1.3.1                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
deepspeed                 0.16.3                   pypi_0    pypi
diffusers                 0.32.2                   pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
einops                    0.8.1                    pypi_0    pypi
expat                     2.6.4                h6a678d5_0
filelock                  3.17.0                   pypi_0    pypi
fonttools                 4.56.0                   pypi_0    pypi
fsspec                    2025.2.0                 pypi_0    pypi
gitdb                     4.0.12                   pypi_0    pypi
gitpython                 3.1.44                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
huggingface-hub           0.28.1                   pypi_0    pypi
idna                      3.10                     pypi_0    pypi
importlib-metadata        8.6.1                    pypi_0    pypi
jinja2                    3.1.5                    pypi_0    pypi
kiwisolver                1.4.8                    pypi_0    pypi
ld_impl_linux-64          2.40                 h12ee557_0
libffi                    3.4.4                h6a678d5_1
libgcc-ng                 11.2.0               h1234567_1
libgomp                   11.2.0               h1234567_1
libstdcxx-ng              11.2.0               h1234567_1
libuuid                   1.41.5               h5eee18b_0
markupsafe                3.0.2                    pypi_0    pypi
matplotlib                3.10.0                   pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
msgpack                   1.1.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0
networkx                  3.4.2                    pypi_0    pypi
ninja                     1.11.1.3                 pypi_0    pypi
numpy                     2.1.1                    pypi_0    pypi
nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
opencv-python             4.11.0.86                pypi_0    pypi
opencv-python-headless    4.11.0.86                pypi_0    pypi
openssl                   3.0.15               h5eee18b_0
packaging                 24.2                     pypi_0    pypi
pandas                    2.2.3                    pypi_0    pypi
pillow                    11.1.0                   pypi_0    pypi
pip                       25.0            py312h06a4308_0
platformdirs              4.3.6                    pypi_0    pypi
prodigyopt                1.1.2                    pypi_0    pypi
protobuf                  5.29.3                   pypi_0    pypi
psutil                    7.0.0                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
pydantic                  2.10.6                   pypi_0    pypi
pydantic-core             2.27.2                   pypi_0    pypi
pyparsing                 3.2.1                    pypi_0    pypi
python                    3.12.9               h5148396_0
python-dateutil           2.9.0.post0              pypi_0    pypi
pytz                      2025.1                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
readline                  8.2                  h5eee18b_0
regex                     2024.11.6                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
safetensors               0.5.2                    pypi_0    pypi
scipy                     1.15.1                   pypi_0    pypi
seaborn                   0.13.2                   pypi_0    pypi
sentencepiece             0.2.0                    pypi_0    pypi
sentry-sdk                2.21.0                   pypi_0    pypi
setproctitle              1.3.4                    pypi_0    pypi
setuptools                75.8.0          py312h06a4308_0
six                       1.17.0                   pypi_0    pypi
smmap                     5.0.2                    pypi_0    pypi
sqlite                    3.45.3               h5eee18b_0
sympy                     1.13.1                   pypi_0    pypi
tk                        8.6.14               h39e8969_0
tokenizers                0.21.0                   pypi_0    pypi
torch                     2.6.0                    pypi_0    pypi
torchvision               0.21.0                   pypi_0    pypi
tqdm                      4.67.1                   pypi_0    pypi
transformers              4.48.3                   pypi_0    pypi
triton                    3.2.0                    pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tzdata                    2025.1                   pypi_0    pypi
ultralytics               8.3.75                   pypi_0    pypi
ultralytics-thop          2.0.14                   pypi_0    pypi
urllib3                   2.3.0                    pypi_0    pypi
wandb                     0.19.6                   pypi_0    pypi
wheel                     0.45.1          py312h06a4308_0
xz                        5.6.4                h5eee18b_1
zipp                      3.21.0                   pypi_0    pypi
zlib                      1.2.13               h5eee18b_1

syf-fgnb commented Feb 16 '25 08:02

Hi,

I suspect the problem is the diffusers version. We used 0.31.0 in our experiments.
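
If it helps, you can downgrade with pip install diffusers==0.31.0 and then confirm the environment actually picked it up before rerunning distill.sh:

import diffusers
print(diffusers.__version__)  # should print 0.31.0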

Huage001 commented Feb 16 '25 11:02

Hi @Huage001

Thanks for the quick reply. I downgraded diffusers as you suggested, but it raises another error:

Traceback (most recent call last):
[rank0]:   File "/root/xxx/CLEAR/distill.py", line 1242, in <module>
[rank0]:     main(args)
[rank0]:   File "/root/xxx/CLEAR/distill.py", line 1076, in main
[rank0]:     accelerator.backward(loss)
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/accelerate/accelerator.py", line 2238, in backward
[rank0]:     self.deepspeed_engine_wrapped.backward(loss, **kwargs)
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/accelerate/utils/deepspeed.py", line 261, in backward
[rank0]:     self.engine.backward(loss, **kwargs)
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/runtime/engine.py", line 2053, in backward
[rank0]:     self.optimizer.backward(loss, retain_graph=retain_graph)
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2062, in backward
[rank0]:     self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
[rank0]:     scaled_loss.backward(retain_graph=retain_graph)
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_tensor.py", line 626, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/autograd/function.py", line 307, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1710, in backward
[rank0]:     return impl_fn()
[rank0]:            ^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1700, in impl_fn
[rank0]:     out = CompiledFunction._backward_impl(ctx, all_args)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2037, in _backward_impl
[rank0]:     CompiledFunction.compiled_bw = aot_config.bw_compiler(
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 489, in __call__
[rank0]:     return self.compiler_fn(gm, example_inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 54, in _wrapped_bw_compiler
[rank0]:     return disable(disable(bw_compiler_fn)(*args, **kwargs))
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1808, in bw_compiler
[rank0]:     return inner_compile(
[rank0]:            ^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 569, in compile_fx_inner
[rank0]:     return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 102, in debug_wrapper
[rank0]:     inner_compiled_fn = compiler_fn(gm, example_inputs)
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 675, in _compile_fx_inner
[rank0]:     mb_compiled_graph = fx_codegen_and_compile(
[rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1129, in fx_codegen_and_compile
[rank0]:     return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 979, in codegen_and_compile
[rank0]:     graph.run(*example_inputs)
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 855, in run
[rank0]:     return super().run(*args)
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/fx/interpreter.py", line 167, in run
[rank0]:     self.env[node] = self.run_node(node)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1496, in run_node
[rank0]:     result = super().run_node(n)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/fx/interpreter.py", line 230, in run_node
[rank0]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1143, in call_function
[rank0]:     raise LoweringException(e, target, args, kwargs).with_traceback(
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/graph.py", line 1133, in call_function
[rank0]:     out = lowerings[target](*args, **kwargs)  # type: ignore[index]
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/lowering.py", line 409, in wrapped
[rank0]:     out = decomp_fn(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/kernel/flex_attention.py", line 2361, in flex_attention_backward
[rank0]:     broadcasted_grad_key = autotune_select_algorithm(
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1909, in autotune_select_algorithm
[rank0]:     return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1379, in __call__
[rank0]:     raise NoValidChoicesError(
[rank0]: torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice.
[rank0]:   target: flex_attention_backward
[rank0]:   args[0]: TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]:   ))
[rank0]:   args[1]: TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_2', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]:   ))
[rank0]:   args[2]: TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_3', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]:   ))
[rank0]:   args[3]: TensorBox(StorageBox(
[rank0]:     InputBuffer(name='getitem', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 128, 3072, 1]))
[rank0]:   ))
[rank0]:   args[4]: TensorBox(StorageBox(
[rank0]:     DonatedBuffer(name='getitem_1', layout=FixedLayout('cuda:0', torch.float32, size=[2, 24, 4608], stride=[110592, 4608, 1]))
[rank0]:   ))
[rank0]:   args[5]: TensorBox(StorageBox(
[rank0]:     InputBuffer(name='tangents_1', layout=FixedLayout('cuda:0', torch.bfloat16, size=[2, 24, 4608, 128], stride=[14155776, 589824, 128, 1]))
[rank0]:   ))
[rank0]:   args[6]: TensorBox(StorageBox(
[rank0]:     Pointwise(
[rank0]:       'cuda',
[rank0]:       torch.float32,
[rank0]:       def inner_fn(index):
[rank0]:           i0, i1, i2 = index
[rank0]:           tmp0 = ops.constant(0, torch.float32)
[rank0]:           return tmp0
[rank0]:       ,
[rank0]:       ranges=[2, 24, 4608],
[rank0]:       origin_node=full_default,
[rank0]:       origins=OrderedSet([full_default])
[rank0]:     )
[rank0]:   ))
[rank0]:   args[7]: Subgraph(name='fw_graph0', graph_module=<lambda>(), graph=None)
[rank0]:   args[8]: Subgraph(name='joint_graph0', graph_module=<lambda>(), graph=None)
[rank0]:   args[9]: (4608, 4608, TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_5', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]:   )), TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_4', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]:   )), TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_6', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]:   )), TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_7', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]:   )), TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_8', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]:   )), TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_9', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]:   )), TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_10', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36], stride=[36, 36, 1]))
[rank0]:   )), TensorBox(StorageBox(
[rank0]:     InputBuffer(name='primals_11', layout=FixedLayout('cuda:0', torch.int32, size=[1, 1, 36, 36], stride=[1296, 1296, 36, 1]))
[rank0]:   )), 128, 128, Subgraph(name='mask_graph0', graph_module=<lambda>(), graph=None))
[rank0]:   args[10]: 0.08838834764831845
[rank0]:   args[11]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}
[rank0]:   args[12]: ()
[rank0]:   args[13]: ()

syf-fgnb commented Feb 17 '25 02:02

That's a little strange.

From the error message, the problem is in flex attention. Could you please try torch==2.5.0? We ran our experiments on that version; 2.6.0 may have introduced some incompatible changes.
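
After switching, one way to check that flex attention itself works in your environment, independent of CLEAR's code, is a standalone smoke test along these lines (a minimal sketch; assumes a CUDA build of torch):

# Minimal flex attention smoke test; shapes are arbitrary (batch, heads, seq, head_dim)
import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(2, 24, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# flex attention is normally used through torch.compile
out = torch.compile(flex_attention)(q, k, v)
print(out.shape)  # torch.Size([2, 24, 128, 64])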

Huage001 commented Feb 17 '25 06:02

Thanks, it seems the above issue was indeed caused by torch 2.6.0. But after switching to diffusers==0.31.0 and torch==2.5.0, it raises the following error when I run distill.sh:

[rank3]:   File "/root/xxx/projects/GSPN/t2i/advanced/attention_processor.py", line 344, in __call__
[rank3]:     hidden_states = self.flex_attn(query, key, value, scale=attention_scale)
[rank3]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1263, in __call__
[rank3]:     return hijacked_callback(
[rank3]:            ^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
[rank3]:     result = self._inner_convert(
[rank3]:              ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
[rank3]:     return _compile(
[rank3]:            ^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
[rank3]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank3]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
[rank3]:     return _compile_inner(code, one_graph, hooks, transform)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
[rank3]:     return function(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
[rank3]:     out_code = transform_code_object(code, transform)
[rank3]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
[rank3]:     transformations(instructions, code_options)
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
[rank3]:     tracer.run()
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
[rank3]:     super().run()
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
[rank3]:     while self.step():
[rank3]:           ^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
[rank3]:     self.dispatch_table[inst.opcode](self, inst)
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
[rank3]:     self._return(inst)
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
[rank3]:     self.output.compile_subgraph(
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1117, in compile_subgraph
[rank3]:     self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
[rank3]:     compiled_fn = self.call_user_compiler(gm)
[rank3]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
[rank3]:     return self._call_user_compiler(gm)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/root/xxx/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
[rank3]:     raise BackendCompilerFailed(self.compiler_fn, e) from e
[rank3]: torch._dynamo.exc.BackendCompilerFailed: backend='compile_fn' raised:
[rank3]: NotImplementedError: DDPOptimizer backend: Found a higher order op in the graph. This is not supported. Please turn off DDP optimizer using torch._dynamo.config.optimize_ddp=False. Note that this can cause performance degradation because there will be one bucket for the entire Dynamo graph. Please refer to this issue - https://github.com/pytorch/pytorch/issues/104674.

[rank3]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
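
From the message itself, it sounds like the suggested workaround is to disable Dynamo's DDP optimizer before anything is compiled, e.g. near the top of distill.py (I have not verified that this is the right fix here):

# Workaround named in the error message: turn off the DDP optimizer.
# Per the linked PyTorch issue, this can degrade performance because the
# whole Dynamo graph ends up in a single bucket.
import torch._dynamo
torch._dynamo.config.optimize_ddp = False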

Any ideas?

syf-fgnb commented Feb 18 '25 10:02