[WIP] Add LoRA multihead attention module
First stab at adding LoRA support for nn.MultiheadAttention. See #761.
Todos:
- [x] ~~For now, only works with `_qkv_same_embed_dim=True` -- make it work with `False` too.~~ `_qkv_same_embed_dim=False` is out of scope for this PR and can be added in a later PR if needed.
- [x] Show that it works in a real world test: See user feedback on the issue.
- [x] Unit tests
- [ ] ~~Docs~~ Apart from docstrings, I don't think anything else needs to be added
Update: I now also apply LoRA to the out_proj.
This is a simple test that I ran successfully with the PR in its current state:
import open_clip
import requests
import torch
from torch import nn
from peft import LoraConfig, get_peft_model
from PIL import Image
from peft.tuners.lora.layer import MultiheadAttention as PeftMha
model, preprocess = open_clip.create_model_from_pretrained('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
tokenizer = open_clip.get_tokenizer('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')
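# NOTE: `config` was not defined in the original snippet; assuming a LoraConfig that targets the
# nn.MultiheadAttention modules, which are named "attn" in open_clip's transformer blocks:
config = LoraConfig(target_modules=["attn"])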
peft_model = get_peft_model(model, config)
opt = torch.optim.SGD(peft_model.parameters(), 0.1)
print(len([m for m in peft_model.modules() if isinstance(m, PeftMha)])) # 64 PEFT MHA layers
peft_model.print_trainable_parameters() # trainable params: 2,588,672 || all params: 1,055,873,793 || trainable%: 0.24516869508096598
# text encoder
text = tokenizer(["a diagram", "a dog", "a cat"])
text_features = peft_model.encode_text(text)
loss = text_features.sum()
loss.backward()
opt.step()
# image encoder
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image = preprocess(image).unsqueeze(0)
image_features = model.encode_image(image)
image_features.sum().backward()
opt.step()
@younesbelkada Could I address all your concerns?
I pinged the user who wanted to test it on their case. When it comes to docs, I didn't really find a place where we list all supported layers, so no update needed really.
Note: The test test_merge_layers for MHA fails. This is most likely because of an existing bug in how merging is implemented, see PR #1355. Once that is merged, the test should pass.
Just want to bump a bunch of the issues I've mentioned in #761, specifically the problem with requires_grad, which is reproducible in this repo.
Just wanted to bump this one because it's really the only way to tune CLIP models after they are released.
@bghira Do you happen to have a use case where you could test if this PR works and is working well enough speed-wise? I think the implementation could be ready to be merged but ideally we'd have someone with a real use case give it a try.
I do and I may be able to test it. Stupid question, but is the code example above complete? I don't see the hinge loss function.
Stupid question, but is the code example above complete? I don't see the hinge loss function.
You mean the code right at the top? No, it's not complete at all, just a quick test to show that MHA is applied and that the backward pass does not fail. It is neither proper nor complete training code.
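For readers who want to turn that snippet into an actual fine-tuning step, here is a small hypothetical sketch (not part of this PR) of a CLIP-style contrastive training step. It assumes paired, preprocessed image and tokenized text batches, the peft_model and opt from the snippet above, that attribute access like encode_image/logit_scale is forwarded to the underlying open_clip model, and it uses the standard symmetric InfoNCE loss rather than a hinge loss:

import torch
import torch.nn.functional as F

def clip_contrastive_step(peft_model, images, texts, opt):
    # Encode both modalities and L2-normalize the features, as in standard CLIP training.
    image_features = F.normalize(peft_model.encode_image(images), dim=-1)
    text_features = F.normalize(peft_model.encode_text(texts), dim=-1)

    # Temperature-scaled similarity matrix between all image/text pairs in the batch.
    logit_scale = peft_model.logit_scale.exp()
    logits_per_image = logit_scale * image_features @ text_features.t()
    logits_per_text = logits_per_image.t()

    # Matching pairs lie on the diagonal; symmetric cross-entropy over rows and columns.
    labels = torch.arange(images.shape[0], device=logits_per_image.device)
    loss = (F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_text, labels)) / 2

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()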
@BenjaminBossan Hi Ben,
Thank you for directing me here; it seems like the exact issue I am looking for. Since this function has not been officially merged into the main branch yet, could you kindly let me know what the config will look like for the multihead LoRA? (peft_model = get_peft_model(model, config)).
I hope to receive some instructions and test this function soon! I'm very much looking forward to it!
Here are the current issues I've run into.
I run my code with
lora_config = LoraConfig(
    r=12,
    lora_alpha=24,
    target_modules=["attn"],
    lora_dropout=0.05,
    bias="none",
)
I found a few warnings, and the performance degradation was extremely dramatic. I will dive into this issue.
Loading evaluator: Classification
No checkpoint found, train from scratch
Initialize tensorboard (log_dir=......./tensorboard)
/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/tuners_utils.py:711: UserWarning: All adapters are already merged, nothing to do.
  warnings.warn("All adapters are already merged, nothing to do.")
/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/lora/layer.py:439: UserWarning: Already unmerged. Nothing to do.
  warnings.warn("Already unmerged. Nothing to do.")
epoch [1/5] batch [20/204] time 0.177 (0.218) data 0.000 (0.011) loss 0.6667 (0.8238) lr 1.0000e-05 eta 0:03:37
...
epoch [1/5] batch [200/204] time 0.178 (0.183) data 0.000 (0.001) loss 0.4894 (0.7809) lr 1.0000e-05 eta 0:02:29
epoch [2/5] batch [20/204] time 0.178 (0.188) data 0.000 (0.009) loss 3.9319 (1.8420) lr 3.5000e-03 eta 0:02:29
...
epoch [2/5] batch [200/204] time 0.178 (0.180) data 0.000 (0.001) loss 4.0289 (3.8440) lr 3.5000e-03 eta 0:01:50
epoch [3/5] batch [20/204] time 0.182 (0.189) data 0.000 (0.010) loss 4.0316 (4.0368) lr 3.1658e-03 eta 0:01:51
...
epoch [3/5] batch [200/204] time 0.183 (0.181) data 0.000 (0.001) loss 4.0406 (4.0340) lr 3.1658e-03 eta 0:01:14
epoch [4/5] batch [20/204] time 0.184 (0.189) data 0.000 (0.010) loss 4.0245 (4.0325) lr 2.2908e-03 eta 0:01:13
...
epoch [4/5] batch [200/204] time 0.179 (0.186) data 0.000 (0.001) loss 4.0240 (4.0313) lr 2.2908e-03 eta 0:00:38
epoch [5/5] batch [20/204] time 0.182 (0.201) data 0.000 (0.010) loss 4.0220 (4.0251) lr 1.2092e-03 eta 0:00:36
...
epoch [5/5] batch [200/204] time 0.179 (0.186) data 0.000 (0.001) loss 4.0228 (4.0256) lr 1.2092e-03 eta 0:00:00
Not sure if it's a problem with the lora_alpha parameter, since it works fine when lora_alpha=1. However, choosing 2*rank seems to destroy the model's performance. Perhaps a bigger alpha is not a good fit for the standard CLIP model.
Honestly, I don't know what lora_alpha value works well with MHA. Thanks for testing it out. Is the performance on par with your expectation for lora_alpha=1?
About these warnings:
/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/tuners_utils.py:711: UserWarning: All adapters are already merged, nothing to do.
/home/isaac/anaconda3/envs/lib/python3.8/site-packages/peft/tuners/lora/layer.py:439: UserWarning: Already unmerged. Nothing to do.
Those are definitely strange, as it means the script tried to merge and unmerge some layers, which normally shouldn't happen. You should check your training code for suspicious lines related to merging.
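To illustrate what typically emits these messages, here is a small hypothetical sketch (not taken from the user's training code); it assumes a PEFT-wrapped model and the merge_adapter/unmerge_adapter API on the underlying LoRA model:

# Calling merge or unmerge when there is nothing to do should emit exactly these warnings.
peft_model.base_model.merge_adapter()    # merges the LoRA weights into the base layers
peft_model.base_model.merge_adapter()    # should warn: "All adapters are already merged, nothing to do."
peft_model.base_model.unmerge_adapter()  # undoes the merge
peft_model.base_model.unmerge_adapter()  # should warn: "Already unmerged. Nothing to do."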
Hi Ben,
Thank you again for your prompt reply!
For lora_alpha=1, it indeed improves the model's classification performance, so I think MHA may only work well with a small scaling factor. I also did some additional research; if you've heard of MultiLoRA (paper link), the scaling factor, even when trained from zero, can cause problems. I guess MHA is extremely sensitive to scaling factors.
As for the warnings, I'm unsure if they are related to my training code, since I use the Dassl framework for the training process. But they do not seem to affect how LoRA improves performance.
Really appreciate this work!
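For reference, PEFT's LoRA layers scale the adapter output by lora_alpha / r (when rank-stabilized LoRA is not used), so the two settings discussed above differ by a factor of 24; a quick illustration:

# LoRA scaling factor is lora_alpha / r.
r = 12
print(24 / r)  # 2.0    -> lora_alpha=24 applies the adapter update at twice its raw magnitude
print(1 / r)   # ~0.083 -> lora_alpha=1 applies it at roughly 1/12th, a much gentler perturbation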
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
not stale
@BenjaminBossan it seems like out_proj is being merged twice during MultiheadAttention.forward - once in self.merge (https://github.com/BenjaminBossan/peft/blob/7e91712075d97bed014e64ea3388de4c34f97b29/src/peft/tuners/lora/layer.py#L1203) and then again in forward (https://github.com/BenjaminBossan/peft/blob/7e91712075d97bed014e64ea3388de4c34f97b29/src/peft/tuners/lora/layer.py#L1292). It doesn't seem to be problematic, but it is spamming warnings during inference.
Thanks for reporting this @damian0815. I had forgotten about this PR. The issue should now be fixed, as well as a few other issues. If you find that LoRA for MHA is helpful and works for your use case, please let me know.
@BenjaminBossan definitely useful, yes.
I'm running into an issue, though, trying to generate validation metrics on the model during training - for legacy reasons we're saving the model's state dict using restore_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}, running inference on the model for evaluation, and then restoring the state dict using model.load_state_dict(restore_state_dict). However, the restoring fails with
Unexpected key(s) in state_dict: "base_model.model.text.transformer.resblocks.8.attn.base_layer.in_proj_weight", "base_model.model.text.transformer.resblocks.8.attn.base_layer.out_proj.base_layer.weight", "base_model.model.text.transformer.resblocks.9.attn.base_layer.in_proj_weight", "base_model.model.text.transformer.resblocks.9.attn.base_layer.out_proj.base_layer.weight", "base_model.model.text.transformer.resblocks.10.attn.base_layer.in_proj_weight", "base_model.model.text.transformer.resblocks.10.attn.base_layer.out_proj.base_layer.weight", "base_model.model.text.transformer.resblocks.11.attn.base_layer.in_proj_weight", "base_model.model.text.transformer.resblocks.11.attn.base_layer.out_proj.base_layer.weight".
It's only happening after calling .forward() on the model (restoring the state dict before that works fine). Moreover, if I put a breakpoint on the line where the failing restore happens and execute set(model.state_dict().keys()).symmetric_difference(restore_state_dict.keys()) in the debugger, the result is an empty set().
definitely useful, yes.
That's good to hear. Hopefully this PR can be merged some day so that we can have MHA support in PEFT proper; it's just that multihead attention is implemented in a way that makes applying LoRA very difficult and requires some hacks. To wit:
However, the restoring fails with
I think this is related to this part:
https://github.com/huggingface/peft/pull/1324/files#diff-24a141c266b7b714ae8fcc470f31bc283f7b0f5a671bbf6d5f092741fc374104R1290-R1294
Could you check if calling _restore_weights manually would solve the error?
The code would be something along these lines:
import peft

for module in model.modules():
    if isinstance(module, peft.tuners.lora.MultiheadAttention):
        module._restore_weights()
Yes, that solved it - thanks (but I had to use peft.tuners.lora.layer.MultiheadAttention for the fully qualified module class).
Great, thanks for confirming @damian0815, and sorry for the wrong path.
I tried to create a unit test based on the description you provided; I think I could reproduce your error. Could you quickly check if the test captures your situation?
@pytest.mark.xfail(strict=True)
def test_mha_load_init_model_first():
    # this test fails as it currently requires a workaround to pass, see test below
    # https://github.com/huggingface/peft/pull/1324#issuecomment-2252473980
    inputs = torch.rand(10, 10, 10)
    model = ModelMha()
    config = LoraConfig(target_modules=["mha"], init_lora_weights=False)
    model = get_peft_model(model, config).eval()
    restore_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    del model

    model = ModelMha()
    # inferencing with PEFT model first is necessary to trigger the error in load_state_dict
    model = get_peft_model(model, config)
    model(inputs)
    model.load_state_dict(restore_state_dict)


def test_mha_load_init_model_first_with_workaround():
    import peft

    inputs = torch.rand(10, 10, 10)
    model = ModelMha()
    config = LoraConfig(target_modules=["mha"], init_lora_weights=False)
    model = get_peft_model(model, config).eval()
    with torch.inference_mode():
        output_before = model(inputs)
    restore_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    del model

    model = ModelMha()
    model = get_peft_model(model, config)
    model(inputs)
    # workaround, see test above
    for module in model.modules():
        if isinstance(module, peft.tuners.lora.layer.MultiheadAttention):
            module._restore_weights()
    model.load_state_dict(restore_state_dict)
    with torch.inference_mode():
        output_after = model(inputs)
    assert torch.allclose(output_before, output_after)
Unfortunately, I could not find a way to hook into load_state_dict to automatically call _restore_weights, since load_state_dict is not recursive, so the PEFT MultiheadAttention is never directly invoked :( I hope this is enough of an edge case that I can ignore it for now.
Looks about right - we're not deleting/reloading the model in between though, simply messing with the weights (doing a blend with the base model -- which is in fact disabled when LoRA training is active, but the save/restore logic runs anyway) and then restoring the weights by loading the restore_state_dict in place.
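For anyone following this thread, the workaround discussed above can be wrapped in a small helper that runs right before restoring a previously captured state dict; a sketch (the function name is mine):

import peft

def load_state_dict_with_mha_workaround(model, restore_state_dict):
    # Re-register the original MHA weights so the captured keys match again,
    # then load the saved state dict as usual.
    for module in model.modules():
        if isinstance(module, peft.tuners.lora.layer.MultiheadAttention):
            module._restore_weights()
    model.load_state_dict(restore_state_dict)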