DoRA loading does not load all keys from the state_dict
Describe the bug
When loading a DoRA model from a kohya state_dict, some keys in the state_dict are silently skipped.
DoRA loading was added in https://github.com/huggingface/diffusers/pull/7371.
This feature has not been released yet, so I am encountering this issue on main (commit: 6133d98ff70eafad7b9f65da50a450a965d1957f)
Reproduction
In this script, I try to load the same test DoRA that was used in the original DoRA PR (https://github.com/huggingface/diffusers/pull/7371).
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
variant="fp16",
).to("cuda")
pipe.load_lora_weights("streamize/test-dora")
# Load state_dict directly from local path.
state_dict = load_file(
"/home/ryan/.cache/huggingface/hub/models--streamize--test-dora/snapshots/2c73f1cccb75b19c0b597f7ebadb10624966cd3f/pytorch_lora_weights.safetensors"
)
key = "lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale"
print(f"State dict value at key: {key}")
print("-----")
val = state_dict[key]
print(f"val.shape: {val.shape}")
print(f"val[0, :5]: {val[0, :5]}")
print(f"\nModel tensor at key: {key}")
print("-----")
val = pipe.text_encoder.text_model.encoder.layers[0].mlp.fc1.lora_magnitude_vector["default_0"]
print(f"val.shape: {val.shape}")
print(f"val[:5]: {val[:5]}")
Output:
State dict value at key: lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale
-----
val.shape: torch.Size([1, 768])
val[0, :5]: tensor([ 0.0029, -0.0030, 0.0007, -0.0010, -0.0026], dtype=torch.float16)
Model tensor at key: lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale
-----
val.shape: torch.Size([3072])
val[:5]: tensor([0.4485, 0.4538, 0.4752, 0.4901, 0.4194], device='cuda:0',
grad_fn=<SliceBackward0>)
I am using "lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale" as an example, but the same behaviour is observed for many keys. The state_dict value does not get injected into the model. In fact, its shape is not even compatible with the target tensor where I'd expect it to be injected.
From the digging I have done so far, I currently suspect 2 issues:
- The conversion of keys from kohya format to peft format is not working correctly.
- The shape mismatch might be caused by column-wise vs. row-wise weight-norm calculations in the two DoRA implementations (see the sketch below).
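A minimal sketch of that second point, assuming fc1 in SDXL's first text encoder is a Linear(768, 3072) (so its weight is [out_features, in_features] = [3072, 768]); this would explain the [1, 768] vs. [3072] shapes in the output above:
import torch

# Weight with the same shape as text_encoder layers[0].mlp.fc1: [3072, 768].
weight = torch.randn(3072, 768)

# Column-wise norm (one value per *input* feature), matching the [1, 768]
# dora_scale tensor stored in the kohya state_dict.
col_norm = weight.norm(p=2, dim=0, keepdim=True)
print(col_norm.shape)  # torch.Size([1, 768])

# Row-wise norm (one value per *output* feature), matching the [3072]
# lora_magnitude_vector that PEFT holds for this layer.
row_norm = weight.norm(p=2, dim=1)
print(row_norm.shape)  # torch.Size([3072])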
To understand the problem better, I recommend setting a breakpoint here: https://github.com/huggingface/peft/blob/26726bf1ddee6ca75ed4e1bfd292094526707a78/src/peft/utils/save_and_load.py#L249
Inspecting the state before and after load_state_dict() makes it easy to see which state_dict keys diffusers is trying to inject and which ones are not being applied.
Beware of this issue with load_state_dict()'s handling of unexpected_keys: https://github.com/pytorch/pytorch/issues/123510. This threw me off when I was debugging.
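As a workaround while debugging, a sketch that diffs the key sets manually at that breakpoint instead of trusting the returned unexpected_keys (the variable names model and peft_model_state_dict are assumed from the linked peft source at that commit):
# Run at the linked breakpoint in peft/utils/save_and_load.py.
expected = set(model.state_dict().keys())
incoming = set(peft_model_state_dict.keys())
print("Keys diffusers tries to inject that the model does not expect:")
for k in sorted(incoming - expected):
    print(" ", k)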
Logs
No response
System Info
Testing with diffusers commit: 6133d98ff70eafad7b9f65da50a450a965d1957f
- diffusers version: 0.28.0.dev0
- Platform: Linux-5.15.0-82-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Huggingface_hub version: 0.20.2
- Transformers version: 4.39.3
- Accelerate version: 0.23.0
- xFormers version: 0.0.23.post1
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
@sayakpaul
Cc: @BenjaminBossan here since it seems like the problem is stemming from the state dict injection step.
Inspecting the state before and after load_state_dict() makes it easy to see which state_dict keys diffusers is trying to inject and which ones are not being applied.
Which load_state_dict() function are you referring to here?
This one:
To understand the problem better, I recommend setting a breakpoint here: https://github.com/huggingface/peft/blob/26726bf1ddee6ca75ed4e1bfd292094526707a78/src/peft/utils/save_and_load.py#L249
@yiyixuxu as per my understanding, the issue has roots in peft. Hence I cc'd @BenjaminBossan.
I think the main issue is the kohya key conversion in diffusers - it produces keys that do not exist.
Do not exist where? Could you give an example?
Using the example from the reproduction script, "lora_te1_text_model_encoder_layers_0_mlp_fc1.dora_scale" from the kohya state_dict gets converted by diffusers to "text_model.encoder.layers.0.mlp.fc1.lora_magnitude_vector.default_0.down.weight". There is no such module in the peft model, so it silently gets skipped.
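A quick way to see this (a sketch; pipe comes from the reproduction script above, and the converted key is the one quoted in the previous sentence):
model_keys = set(pipe.text_encoder.state_dict().keys())
converted_key = "text_model.encoder.layers.0.mlp.fc1.lora_magnitude_vector.default_0.down.weight"
print(converted_key in model_keys)  # False -> strict=False loading drops it silently
# What the PEFT-injected text encoder actually registers for this layer's magnitude vector:
print([k for k in model_keys if "layers.0.mlp.fc1.lora_magnitude_vector" in k])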
I am just using this key as an example, but the same is true for many keys. I have not gone to the effort of checking them all.
I see many occurrences of lora_magnitude_vector here: https://github.com/search?q=repo%3Ahuggingface%2Fpeft%20lora_magnitude_vector&type=code. Perhaps @BenjaminBossan could help clarify this.
I am just using this key as an example, but the same is true for many keys. I have not gone to the effort of checking them all.
Unique examples will be appreciated.
It's hard for me to understand what is going on here.
From the PEFT side of things, we don't really do anything special with the DoRA parameters, so treating them in the same fashion as the other LoRA parameters should be correct. What makes this difficult is that the adapters were trained with another LoRA/DoRA implementation (LyCORIS, I assume), not with PEFT, so there could be differences that make it hard to load their weights onto a PEFT model. We don't have any control over that, and we don't even know if the format is stable over time (short of tracking all the code changes there).
To get to the bottom of this, we would need to understand what differentiates this checkpoint from the previous ones that seemed to work correctly with #7371 added.
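One way to start would be to summarize the key suffixes and dora_scale shapes of this checkpoint and diff them against a checkpoint that loaded correctly (a sketch; the file path is a placeholder for the local snapshot of each checkpoint):
from collections import Counter
from safetensors.torch import load_file

sd = load_file("pytorch_lora_weights.safetensors")  # placeholder path
# Distribution of key suffixes (e.g. lora_down.weight, lora_up.weight, alpha, dora_scale).
print(Counter(k.rsplit(".", 1)[-1] for k in sd))
# Shapes of all dora_scale tensors, to compare across checkpoints.
print({k: tuple(v.shape) for k, v in sd.items() if k.endswith(".dora_scale")})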
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.