Traceback (most recent call last):
[rank2]: File "/data/hanrui/CLEAR/distill.py", line 1245, in <module>
[rank2]: main(args)
[rank2]: File "/data/hanrui/CLEAR/distill.py", line 1028, in main
[rank2]: teacher_pred = transformer_teacher(
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 733, in forward
[rank2]: encoder_hidden_states, hidden_states = block(
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 456, in forward
[rank2]: attention_outputs = self.attn(
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 343, in forward
[rank2]: return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
[rank2]: File "/data/hanrui/CLEAR/attention_processor.py", line 84, in __call__
[rank2]: query = apply_rotary_emb(query, image_rotary_emb)
[rank2]: File "/data/hanrui/conda/pytorch/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 1224, in apply_rotary_emb
[rank2]: out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
[rank2]: RuntimeError: The size of tensor a (4608) must match the size of tensor b (16896) at non-singleton dimension 2
I followed the code's workflow exactly, but this error occurred.
It is most likely caused by a diffusers version mismatch; consider either rolling diffusers back to 0.31.0 or updating the code to match the new FluxPipeline logic.
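Whichever option is chosen, a small guard in the custom processor makes this failure mode explicit instead of surfacing as a broadcast error deep inside diffusers. `check_rotary_shapes` below is a hypothetical helper (not part of diffusers), and it assumes `image_rotary_emb` is a `(cos, sin)` pair shaped `(positions, dim)`:

```python
import torch

def check_rotary_shapes(query, image_rotary_emb):
    """Hypothetical guard: fail early, with a clearer message than the
    broadcast error, when the rotary table and the query disagree on
    sequence length. Assumes image_rotary_emb is a (cos, sin) pair,
    each of shape (positions, dim)."""
    cos, _sin = image_rotary_emb
    if cos.shape[0] != query.shape[-2]:
        raise ValueError(
            f"rotary table covers {cos.shape[0]} positions but the query has "
            f"{query.shape[-2]}; the installed diffusers version and "
            f"attention_processor.py likely build image_rotary_emb differently"
        )

# Shapes from the traceback: query seq_len 4608 vs rotary table 16896
query = torch.randn(1, 24, 4608, 128)
rotary = (torch.randn(16896, 128), torch.randn(16896, 128))
try:
    check_rotary_shapes(query, rotary)
except ValueError as e:
    print(e)
```

If pinning is acceptable, `pip install diffusers==0.31.0` restores the layout this processor was written against; otherwise the rotary embeddings must be built the way the installed FluxPipeline builds them.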