
ValueError: Trying to set a tensor of shape torch.Size([3072, 64]) in "weight" (which has shape torch.Size([3072, 384])), this looks incorrect.

Open haozhuoyuan opened this issue 9 months ago • 2 comments

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")

pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")
image = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    height=1632,
    width=1232,
    guidance_scale=30,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fill-dev.png")

Traceback (most recent call last):
  File "/workdir/flux/flask_flux_fill_demo.py", line 18, in <module>
    pipe = FluxFillPipeline.from_pretrained("/workdir/flux/checkpoints/flux_dev_fill", torch_dtype=torch.bfloat16).to("cuda")
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/diffusers/pipelines/pipeline_utils.py", line 924, in from_pretrained
    loaded_sub_model = load_sub_model(
  File "/opt/conda/lib/python3.11/site-packages/diffusers/pipelines/pipeline_loading_utils.py", line 725, in load_sub_model
    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/diffusers/models/modeling_utils.py", line 932, in from_pretrained
    accelerate.load_checkpoint_and_dispatch(
  File "/opt/conda/lib/python3.11/site-packages/accelerate/big_modeling.py", line 613, in load_checkpoint_and_dispatch
    load_checkpoint_in_model(
  File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 1821, in load_checkpoint_in_model
    set_module_tensor_to_device(
  File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 373, in set_module_tensor_to_device
    raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([3072, 64]) in "weight" (which has shape torch.Size([3072, 384])), this looks incorrect.

Could someone help me figure out what the problem is? Is it my environment? I have already updated diffusers and accelerate. Thank you!
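Whether this is an environment problem or a checkpoint problem can be narrowed down without loading the model at all. The ValueError says the model built from the config expects a [3072, 384] weight while the file on disk holds a [3072, 64] one, so a minimal sketch is to compare the transformer's `in_channels` in `config.json` with the shape actually stored in its `.safetensors` files. The sketch assumes a diffusers-style `transformer/` folder containing `config.json` plus one or more `*.safetensors` shards, and that the input projection is stored under the key `x_embedder.weight`; both are assumptions about the local checkpoint, and `read_safetensors_shapes` and `check_x_embedder` are hypothetical helper names. It parses the safetensors header directly (an 8-byte little-endian length followed by a JSON table of tensor names, dtypes and shapes), so it needs only the standard library.

```python
import json
import struct
from pathlib import Path


def read_safetensors_shapes(path):
    """Return {tensor_name: shape} from a .safetensors header without reading tensor data.

    The file starts with an unsigned 64-bit little-endian header length N,
    followed by N bytes of UTF-8 JSON mapping tensor names to
    {"dtype", "shape", "data_offsets"}, plus an optional "__metadata__" entry.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    return {name: info["shape"] for name, info in header.items() if name != "__metadata__"}


def check_x_embedder(transformer_dir, key="x_embedder.weight"):
    """Hypothetical helper: compare config in_channels with the stored weight shape.

    Assumes transformer_dir holds config.json and *.safetensors shards.
    Returns (config_in_channels, stored_shape), where stored_shape is None
    if no shard contains `key`.
    """
    transformer_dir = Path(transformer_dir)
    config = json.loads((transformer_dir / "config.json").read_text())
    stored_shape = None
    for shard in sorted(transformer_dir.glob("*.safetensors")):
        shapes = read_safetensors_shapes(shard)
        if key in shapes:
            stored_shape = shapes[key]
            break
    return config.get("in_channels"), stored_shape
```

For example, `check_x_embedder("/workdir/flux/checkpoints/flux_dev_fill/transformer")` (the `transformer` subfolder is assumed) returns the configured `in_channels` and the stored `[out, in]` shape. If the stored second dimension differs from `in_channels`, the weight files do not match their own config, which would suggest looking at the checkpoint files rather than at the installed diffusers or accelerate versions.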

haozhuoyuan, Feb 21 '25 09:02

I am having the same problem too. It comes from step 7 (the denoising loop) in pipeline_flux_fill.py:

            noise_pred = self.transformer(
                hidden_states=torch.cat((latents, masked_image_latents), dim=2),
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

They pass hidden_states with a size of 384 in dim 2. But in transformer_flux.py the embedder's input size comes from in_channels, which they set to 64: self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim) (around line 274).

So in the forward method, when they embed hidden_states with hidden_states = self.x_embedder(hidden_states), the shapes do not match: the embedder expects an input of size 64 in dim 2, but hidden_states has size 384. That is what causes the problem.
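The shape reasoning above can be reproduced in plain PyTorch. This is a minimal sketch, not the diffusers code: the widths 64 and 320 for `latents` and `masked_image_latents`, the 16-token sequence, and the helper names `load_succeeds` and `forward_succeeds` are illustrative assumptions, chosen so that concatenating along dim=2 yields the 384 channels mentioned above and the output width matches the 3072 in the error. It also shows that `nn.Linear(in_features, out_features)` stores its weight as `[out_features, in_features]`, so the same 64 vs. 384 mismatch can surface either at load time (as in the traceback in the original report) or at forward time (as described in this comment).

```python
import torch
from torch import nn


def load_succeeds(layer, state_dict):
    """True if state_dict loads into layer; False if PyTorch reports a size mismatch."""
    try:
        layer.load_state_dict(state_dict)
        return True
    except RuntimeError:
        return False


def forward_succeeds(layer, x):
    """True if x can pass through layer; False if its last dimension does not fit."""
    try:
        layer(x)
        return True
    except RuntimeError:
        return False


# Illustrative widths (assumptions): 64 + 320 = 384 channels after concatenation.
latents = torch.randn(1, 16, 64)                 # (batch, tokens, channels)
masked_image_latents = torch.randn(1, 16, 320)
hidden_states = torch.cat((latents, masked_image_latents), dim=2)
print(hidden_states.shape)  # torch.Size([1, 16, 384])

# nn.Linear(in_features, out_features) stores weight as [out_features, in_features].
embedder_64 = nn.Linear(64, 3072)    # weight shape [3072, 64]
embedder_384 = nn.Linear(384, 3072)  # weight shape [3072, 384]

# Forward-time mismatch: a 384-wide input cannot go through a 64-input layer.
print(forward_succeeds(embedder_64, hidden_states))   # False
print(forward_succeeds(embedder_384, hidden_states))  # True

# Load-time mismatch: a [3072, 64] weight cannot be copied into a [3072, 384] parameter.
# Plain PyTorch raises RuntimeError here; accelerate's loader raises ValueError instead.
print(load_succeeds(embedder_384, embedder_64.state_dict()))  # False
```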

siyamsajeebkhan, Mar 18 '25 14:03

I face the same issue when using the load_checkpoint_and_dispatch() method. How did you fix it?

MaxHeuillet, Jun 17 '25 20:06