LoRA PiSSA init: GPT-2 not supported
System Info
- peft 0.13.0
- transformers 4.44.2
- torch 2.4.0
- Python 3.12.4
Who can help?
No response
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder
- [X] My own task or dataset (give details below)
Reproduction
import os
os.environ["WANDB_DISABLED"] = "true"

from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset
from torchsummary import summary
import torch
from datasets import load_dataset, config
from trl import SFTTrainer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
lora_config = LoraConfig(
r=8,
lora_alpha=8,
lora_dropout=0,
target_modules=["attn.c_attn"],
init_lora_weights="pissa",
fan_in_fan_out=True,
bias="none"
)
model = get_peft_model(model, lora_config)
dataset = load_dataset("imdb", split="train[:1%]")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=128,
    tokenizer=tokenizer,
)
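Training is then launched, which is where the dimension mismatch shows up (this last call is my assumption to complete the snippet; the SFTTrainer defaults are used otherwise):

trainer.train()  # fails with a dimension mismatch in the LoRA forward pass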
Expected behavior
Hello, I found that the current PiSSA init code does not take the fan_in_fan_out parameter into account when it should transpose the weight matrix, which makes GPT-2 training fail with a dimension mismatch (GPT-2's attention projections use transformers' Conv1D, whose weight is stored transposed relative to nn.Linear; a short illustration follows the patch). I have fixed the bug with the following code:
def pissa_init(self, adapter_name, init_lora_weights):
    weight = self.get_base_layer().weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize PiSSA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    # Transpose according to fan_in_fan_out so Conv1D-style weights (e.g. GPT-2)
    # are brought into the (out_features, in_features) layout the SVD below expects.
    weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
    if init_lora_weights == "pissa":
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}
        V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
        Vr = V[:, : self.r[adapter_name]]
        Sr = S[: self.r[adapter_name]]
        Sr /= self.scaling[adapter_name]
        Uhr = Uh[: self.r[adapter_name]]
    elif len(init_lora_weights.split("_niter_")) == 2:
        Vr, Sr, Ur = svd_lowrank(
            weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
        )
        Sr /= self.scaling[adapter_name]
        Uhr = Ur.t()
    else:
        raise ValueError(
            f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
        )
    lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
    lora_B = Vr @ torch.diag(torch.sqrt(Sr))
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B
    weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
    # Transpose back to the base layer's original layout before writing the residual.
    weight = transpose(weight.to(dtype), self.fan_in_fan_out)
    self.get_base_layer().weight.data = weight
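For context, here is a minimal check (a sketch, not part of the fix) of why the transpose is needed: GPT-2's attn.c_attn is a transformers Conv1D module, which stores its weight as (in_features, out_features), i.e. transposed relative to nn.Linear, which is exactly what fan_in_fan_out=True signals:

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
c_attn = model.transformer.h[0].attn.c_attn
print(type(c_attn).__name__)   # Conv1D
print(c_attn.weight.shape)     # torch.Size([768, 2304]) -> (in_features, out_features)
# An equivalent nn.Linear would store its weight as (2304, 768), which is the
# layout the SVD in pissa_init assumes; hence the transpose before and after.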
Thanks for reporting this bug and providing a potential solution. Would you be interested in creating a PR with your fix?
Thanks! I'd be happy to submit a PR with my fix.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
This issue is resolved via #2104.