
GLoRA inference fails with Flux due to weights being in bfloat16 precision

Open · mhirki opened this issue 1 year ago

My simple inference script is failing when calling wrapper.merge_to() with Flux Dev as the base model.

2024-09-21 19:27:53|[LyCORIS]-INFO: Loading Modules from state dict...
2024-09-21 19:27:54|[LyCORIS]-INFO: 504 Modules Loaded
Traceback (most recent call last):
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/output/models.bak_flux_sanna_marin_v0.4_fp8_multires_adan3_glora/inference2.py", line 12, in <module>
    wrapper.merge_to(0.5)
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/wrapper.py", line 567, in merge_to
    lora.merge_to(weight)
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/base.py", line 269, in merge_to
    weight, bias = self.get_merged_weight(
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 208, in get_merged_weight
    diff_w, _ = self.get_diff_weight(multiplier, shape, device)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 202, in get_diff_weight
    weight = self.make_weight(device) * multiplier
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nvme/home/mikaelh/Stable_Diffusion/bghira/SimpleTuner.latest/.venv/lib/python3.11/site-packages/lycoris/modules/glora.py", line 198, in make_weight
    w_wa2 = (orig @ wa1) @ wa2
             ~~~~~^~~~~
RuntimeError: expected m1 and m2 to have the same dtype, but got: c10::BFloat16 != float

orig is in bfloat16 precision, while wa1 and wa2 are in float32. I tried both upcasting orig to float32 and downcasting wa1 and wa2 to bfloat16, and there was very little difference in the end result. Upcasting to float32 ran much faster on CPU. I'm not sure which way you prefer to solve this, so I'm posting this as an issue.
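To make the mismatch concrete, here is a minimal sketch using plain PyTorch tensors rather than LyCORIS itself. Only the `(orig @ wa1) @ wa2` expression comes from the traceback; the shapes and the helper names `mixed`, `upcast` and `downcast` are hypothetical, chosen so the matrix products are valid. It reproduces the error and shows the two fixes described above.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes standing in for one GLoRA module; the real dimensions
# come from the Flux transformer and the trained adapter.
out_dim, in_dim, rank = 64, 32, 4
orig = torch.randn(out_dim, in_dim, dtype=torch.bfloat16)  # base weight, bf16 like Flux Dev
wa1 = torch.randn(in_dim, rank, dtype=torch.float32)       # adapter factor, float32
wa2 = torch.randn(rank, in_dim, dtype=torch.float32)       # adapter factor, float32


def mixed():
    # Same expression as the failing line in glora.py: mixing bf16 and
    # float32 operands in a matmul raises RuntimeError.
    return (orig @ wa1) @ wa2


def upcast():
    # Option 1: promote the base weight to float32 before multiplying.
    return (orig.float() @ wa1) @ wa2


def downcast():
    # Option 2: cast the adapter factors down to the base weight's dtype.
    return (orig @ wa1.to(orig.dtype)) @ wa2.to(orig.dtype)


up = upcast()
down = downcast()
# Relative difference between the two fixes (Frobenius norm).
rel_diff = ((up - down.float()).norm() / up.norm()).item()
```

Upcasting keeps the arithmetic in float32 at the cost of a temporary float32 copy of `orig`; downcasting keeps everything in bfloat16, matching the model's storage dtype but rounding the adapter factors first.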

Here's my inference script for reference:

import torch
from diffusers import FluxPipeline
from lycoris import create_lycoris_from_weights

torch.set_num_threads(16)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

adapter_id = 'pytorch_lora_weights_fixed.safetensors' # you will have to download this manually
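lora_scale = 1
# Build a LyCORIS wrapper for pipe.transformer from the adapter file
wrapper, _ = create_lycoris_from_weights(lora_scale, adapter_id, pipe.transformer)
# Merge the adapter into the bf16 base weights at strength 0.5; this call raises the RuntimeError above
wrapper.merge_to(0.5)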

pipe.enable_sequential_cpu_offload()

prompt = "sanna marin playing tennis"
generator = torch.Generator().manual_seed(1000)
out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    height=1280,
    width=832,
    num_inference_steps=20,
    generator=generator
).images[0]
out.save("image.png")

mhirki · Sep 21 '24 16:09

I will implement some type checks. In the meantime, you can use this workaround:

wrapper.apply_to()
# device and dtype should match the base model, e.g. torch.bfloat16 for Flux Dev
wrapper.to(device, dtype)
wrapper.restore()
wrapper.merge_to()
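The workaround relies on `.to(device, dtype)` casting the adapter's floating-point parameters to the base model's dtype before merging, which is what a standard torch `nn.Module` does; this sketch assumes the LyCORIS wrapper behaves the same way. `ToyGLoRAFactors` is a hypothetical stand-in, not a LyCORIS class, and the bfloat16 base weight mirrors the Flux Dev script.

```python
import torch
import torch.nn as nn


class ToyGLoRAFactors(nn.Module):
    """Hypothetical stand-in for the adapter factors held by a LyCORIS module."""

    def __init__(self, in_dim: int, rank: int):
        super().__init__()
        # Parameters default to float32, like the adapter weights in the traceback.
        self.wa1 = nn.Parameter(torch.randn(in_dim, rank))
        self.wa2 = nn.Parameter(torch.randn(rank, in_dim))

    def merge_term(self, orig: torch.Tensor) -> torch.Tensor:
        # Same expression as the failing line in glora.py.
        return (orig @ self.wa1) @ self.wa2


# bf16 base weight, as when Flux Dev is loaded with torch_dtype=torch.bfloat16.
base_weight = torch.randn(64, 32, dtype=torch.bfloat16)
factors = ToyGLoRAFactors(in_dim=32, rank=4)

# Cast every floating-point parameter to the base weight's device and dtype,
# analogous to wrapper.to(device, dtype) in the workaround.
factors.to(base_weight.device, base_weight.dtype)

with torch.no_grad():
    term = factors.merge_term(base_weight)  # no dtype mismatch now
```

After the cast, both operands of each matmul are bfloat16, so the product that previously raised `RuntimeError` succeeds.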

KohakuBlueleaf · Sep 21 '24 16:09

Should be fixed in 3.1.1.post1

KohakuBlueleaf · Dec 09 '24 13:12