Flux Fill fp8 load failed
Describe the bug
cc: https://huggingface.co/AlekseyCalvin/FluxFillDev_fp8_Diffusers/discussions/1
I want to run Flux Fill with fp8 weights for faster inference, but loading the model fails.
Reproduction
from diffusers import FluxTransformer2DModel, FluxFillPipeline
from transformers import T5EncoderModel
import torch
transformer = FluxTransformer2DModel.from_pretrained("AlekseyCalvin/FluxFillDev_fp8_Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
text_encoder_2 = T5EncoderModel.from_pretrained("AlekseyCalvin/FluxFillDev_fp8_Diffusers", subfolder="text_encoder_2", torch_dtype=torch.bfloat16).to("cuda")
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16).to("cuda")
or
import torch
from diffusers import FluxTransformer2DModel, FluxFillPipeline
from diffusers.utils import load_image
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize
bfl_repo = "black-forest-labs/FLUX.1-Fill-dev"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_single_file("/home/ubuntu/black-forest-labs_FLUX.1-Fill-dev_flux1-fill-dev_fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
# image/mask_image inpainting needs FluxFillPipeline; FluxPipeline does not accept mask_image
pipe = FluxFillPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup.png")
mask = load_image("https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/cup_mask.png")
image = pipe(
prompt="a white paper cup",
image=image,
mask_image=mask,
height=1632,
width=1232,
guidance_scale=30,
num_inference_steps=50,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fp8-dev.png")
Logs
ValueError: Trying to set a tensor of shape torch.Size([3072, 384]) in "weight" (which has shape torch.Size([3072, 64])), this looks incorrect.
System Info
diffusers 0.32.0.dev0, Python 3.10, NVIDIA A100
Who can help?
@sayakpaul @DN6
Looks like AlekseyCalvin/FluxFillDev_fp8_Diffusers has the config for FLUX.1-dev instead of FLUX.1-Fill-dev
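The shapes in the error line up with this diagnosis. The sketch below reproduces them with plain arithmetic, assuming FLUX.1's architecture (24 attention heads of dim 128, 16 VAE latent channels, 2x2 latent packing, VAE scale factor 8) and my reading of how `FluxFillPipeline` concatenates its inputs; the constant and function names are hypothetical:

```python
# Where the two shapes in the error message come from (rough sketch).
# Assumed FLUX.1 hyperparameters, not stated in this thread:
NUM_HEADS = 24
HEAD_DIM = 128
VAE_LATENT_CHANNELS = 16
PATCH = 2       # 2x2 latent patches are packed into the channel dimension
VAE_SCALE = 8   # each latent pixel covers an 8x8 block of image pixels


def packed(channels: int) -> int:
    """Channels after packing 2x2 spatial patches into the channel axis."""
    return channels * PATCH * PATCH


def x_embedder_weight_shape(in_channels: int) -> tuple:
    """nn.Linear(in_channels, inner_dim) stores its weight as (out, in)."""
    inner_dim = NUM_HEADS * HEAD_DIM
    return (inner_dim, in_channels)


# FLUX.1-dev: only the noisy image latents enter the transformer.
dev_in = packed(VAE_LATENT_CHANNELS)

# FLUX.1-Fill-dev: noisy latents + masked-image latents + the mask,
# where the mask is folded into VAE_SCALE * VAE_SCALE channels before packing.
fill_in = (
    packed(VAE_LATENT_CHANNELS)
    + packed(VAE_LATENT_CHANNELS)
    + packed(VAE_SCALE * VAE_SCALE)
)

print("dev: ", x_embedder_weight_shape(dev_in))   # (3072, 64)
print("fill:", x_embedder_weight_shape(fill_in))  # (3072, 384)
```

So a Fill checkpoint (384 input channels) loaded into a model built from the FLUX.1-dev config (64 input channels) fails on exactly this weight.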
Thanks! It worked for the first case above.
I finally tested both the fp8 version and the original version, and found that switching the model to fp8 does not make it run any faster.
I also want to test a quantized version. Would its speed be the same as the original even after quantizing?
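One likely reason for the unchanged speed: loading an fp8 checkpoint with `torch_dtype=torch.bfloat16` casts the weights to bf16, so the forward pass runs the same bf16 math as the original model (and the A100 has no fp8 tensor cores to benefit from anyway). To compare variants like-for-like, it also helps to exclude warmup and wait for GPU work before stopping the clock. A minimal, generic timing sketch; `benchmark` is a hypothetical helper, not a diffusers API:

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(
    fn: Callable[[], object],
    warmup: int = 1,
    runs: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock seconds of `fn()` over `runs` timed calls."""
    if runs < 1:
        raise ValueError("runs must be >= 1")
    for _ in range(warmup):
        fn()  # first calls include one-off costs (allocation, kernel selection)
    if sync is not None:
        sync()  # make sure warmup work has finished before timing starts
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # GPU work is asynchronous; wait for it before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)
```

With the pipelines above, usage could look like `benchmark(lambda: pipe(prompt="a white paper cup", image=image, mask_image=mask, num_inference_steps=50), sync=torch.cuda.synchronize)`, run once per model variant with everything else kept identical.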
@Suprhimp you could try the GGUF quantization support that was added very recently:
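import os
import numpy as np
import torch
from PIL import Image
from diffusers import FluxFillPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
# GGUF loading requires the `gguf` package (pip install gguf)
device = torch.device("cuda")
# `image` (an H x W x 3 array), `mask_2d` (an H x W array scaled to [0, 1])
# and `fill_prompt` are assumed to be defined earlier in the script
transformer = FluxTransformer2DModel.from_single_file(
"https://huggingface.co/YarvixPA/FLUX.1-Fill-dev-gguf/blob/main/flux1-fill-dev-Q4_0.gguf",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = FluxFillPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Fill-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
# model CPU offload moves each component to the GPU only while it runs,
# so the whole pipeline is not moved with .to(device) first
pipe.enable_model_cpu_offload()
mask_image = Image.fromarray((mask_2d * 255).astype(np.uint8))
print(f"Applying Flux Fill with prompt: '{fill_prompt}'")
flux_image = pipe(
prompt=fill_prompt,
image=Image.fromarray(image),
mask_image=mask_image,
height=image.shape[0],
width=image.shape[1],
guidance_scale=30,
num_inference_steps=50,
max_sequence_length=512,
generator=torch.Generator(device.type).manual_seed(0),
).images[0]
print("Saving Flux Fill result...")
os.makedirs("output", exist_ok=True)
flux_image.save("output/flux_fill_output.png")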
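On the speed question: with GGUF, diffusers keeps the weights quantized in memory and dequantizes them to `compute_dtype` (bfloat16 here) during each forward pass, so the gain is mainly lower memory rather than faster inference. A back-of-the-envelope sketch of transformer weight memory, assuming roughly 12B parameters and ignoring activations, the text encoders, the VAE, and any layers a GGUF file keeps at higher precision; the names are hypothetical:

```python
# Rough weight-only memory for the Flux transformer in different formats.
# Assumption: ~12e9 parameters (FLUX.1 is commonly described as a 12B model).
PARAMS = 12e9

BITS_PER_WEIGHT = {
    "bf16": 16,
    "fp8 (e4m3)": 8,
    # GGUF Q4_0: blocks of 32 4-bit weights plus one fp16 scale -> 18 bytes per 32 weights
    "gguf Q4_0": 18 * 8 / 32,
}


def weight_gib(params: float, bits: float) -> float:
    """Bytes needed for `params` weights at `bits` bits each, in GiB."""
    return params * bits / 8 / 2**30


for name, bits in BITS_PER_WEIGHT.items():
    print(f"{name:>10}: {weight_gib(PARAMS, bits):5.1f} GiB")
```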
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@kryali Thanks! This one works for me.
@Suprhimp Is it okay to close this issue? Looks like @kryali's solution should give you what you need?
transformer = FluxTransformer2DModel.from_single_file("/FLUX.1-Fill-dev_fp8.safetensors", local_files_only=True, torch_dtype=dtype, config="/FLUX.1-Fill-dev/transformer/")
This runs without errors, but the generated result is wrong.
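One thing worth ruling out in a setup like this is a mismatch between the checkpoint and the directory passed via `config=`. The tensor shapes can be read straight from the `.safetensors` header (an 8-byte little-endian length followed by a JSON header) without loading any weights, and the image-embedder input width compared with `in_channels` in that directory's `config.json` (384 for FLUX.1-Fill-dev, 64 for FLUX.1-dev). A stdlib-only sketch; the function names are hypothetical, the key names are assumptions about the two common checkpoint layouts, and a match does not rule out other causes of a wrong result:

```python
import json
import struct
from typing import Dict, Optional


def read_safetensors_header(path: str) -> Dict[str, dict]:
    """Read tensor metadata (dtype, shape) from a .safetensors file without loading weights."""
    with open(path, "rb") as f:
        # The file starts with an 8-byte little-endian header length, then a JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header


# Assumed key names for the image-input projection: "x_embedder.weight" in
# diffusers-format checkpoints, "img_in.weight" in the original BFL layout
# (sometimes behind a prefix such as "model.diffusion_model.").
EMBEDDER_SUFFIXES = ("x_embedder.weight", "img_in.weight")


def embedder_in_channels(path: str) -> Optional[int]:
    """Return the input width of the image-embedder weight, or None if not found."""
    for name, info in read_safetensors_header(path).items():
        if name.endswith(EMBEDDER_SUFFIXES):
            # Linear weights are stored as (out_features, in_features).
            return info["shape"][1]
    return None
```

The same header also shows each tensor's stored dtype (for example `F8_E4M3` for fp8 weights), which is a quick way to confirm what a downloaded file actually contains.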
@babyta I'm closing this issue as resolved. Please feel free to open a new one with a code example that reproduces the error.