
LoRA PiSSA init: gpt2 not supported

suyang160 opened this issue 1 year ago

System Info

peft 0.13.0
transformers 4.44.2
torch 2.4.0
Python 3.12.4

Who can help?

No response

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder
  • [X] My own task or dataset (give details below)

Reproduction

import os

os.environ["WANDB_DISABLED"] = "true"

import torch
from datasets import load_dataset, config
from torch.utils.data import Dataset
from torchsummary import summary
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0,
    target_modules=["attn.c_attn"],
    init_lora_weights="pissa",
    fan_in_fan_out=True,
    bias="none",
)

model = get_peft_model(model, lora_config)

dataset = load_dataset("imdb", split="train[:1%]")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=128,
    tokenizer=tokenizer,
)
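
For context, here is a minimal sketch (not part of the original report, assuming the standard GPT-2 implementation in transformers) of why this script trips over gpt2 specifically: the attention projection attn.c_attn is a transformers Conv1D module, which stores its weight as (in_features, out_features), whereas nn.Linear stores (out_features, in_features). That layout is exactly what fan_in_fan_out=True is meant to signal.

# Sketch showing the weight layout that requires fan_in_fan_out=True
# and a transpose during PiSSA init.
import torch
from transformers import GPT2LMHeadModel

base = GPT2LMHeadModel.from_pretrained("gpt2")
c_attn = base.transformer.h[0].attn.c_attn  # transformers.pytorch_utils.Conv1D

# Conv1D keeps the weight as (in_features, out_features) ...
print(type(c_attn).__name__, tuple(c_attn.weight.shape))  # -> Conv1D (768, 2304)

# ... while nn.Linear keeps it as (out_features, in_features), which is what the
# PiSSA SVD code assumes. Without transposing first, the SVD factors get the
# wrong shapes for lora_A / lora_B and init fails with a dimension mismatch.
linear = torch.nn.Linear(768, 2304)
print(tuple(linear.weight.shape))  # -> (2304, 768)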

Expected behavior

Hello, I found that the current PiSSA init code forgets to take the fan_in_fan_out parameter into account and transpose the weight matrix, which makes GPT-2 training fail with a dimension mismatch. I have fixed the bug with the following code:

def pissa_init(self, adapter_name, init_lora_weights):
    weight = self.get_base_layer().weight
    dtype = weight.dtype
    if dtype not in [torch.float32, torch.float16, torch.bfloat16]:
        raise TypeError(
            "Please initialize PiSSA under float32, float16, or bfloat16. "
            "Subsequently, re-quantize the residual model to help minimize quantization errors."
        )
    # Transpose when fan_in_fan_out is set (e.g. GPT-2's Conv1D layers store the
    # weight as (in_features, out_features)), so the SVD sees an (out, in) matrix.
    weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
    if init_lora_weights == "pissa":
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}
        V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
        Vr = V[:, : self.r[adapter_name]]
        Sr = S[: self.r[adapter_name]]
        Sr /= self.scaling[adapter_name]
        Uhr = Uh[: self.r[adapter_name]]
    elif len(init_lora_weights.split("_niter_")) == 2:
        Vr, Sr, Ur = svd_lowrank(
            weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1])
        )
        Sr /= self.scaling[adapter_name]
        Uhr = Ur.t()
    else:
        raise ValueError(
            f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead."
        )

    lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr
    lora_B = Vr @ torch.diag(torch.sqrt(Sr))
    self.lora_A[adapter_name].weight.data = lora_A
    self.lora_B[adapter_name].weight.data = lora_B
    # Subtract the rank-r component, then transpose back to the original layout.
    weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
    weight = transpose(weight.to(dtype), self.fan_in_fan_out)
    self.get_base_layer().weight.data = weight
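
As a sanity check on the transpose logic, here is a standalone sketch (not part of the patch above, using plain torch rather than PEFT internals, with made-up shapes): for a Conv1D-style weight of shape (in, out), transposing before the SVD and transposing back after subtracting the rank-r component reconstructs the original weight exactly as the residual plus the scaled low-rank adapter.

# Standalone sketch mirroring the patched pissa_init above; shapes are illustrative.
import torch

in_features, out_features, r, scaling = 768, 2304, 8, 1.0
W = torch.randn(in_features, out_features)            # Conv1D-style layout (in, out)

Wt = W.T                                              # transpose -> (out, in), as nn.Linear
V, S, Uh = torch.linalg.svd(Wt, full_matrices=False)
Sr = S[:r] / scaling
lora_A = torch.diag(torch.sqrt(Sr)) @ Uh[:r]          # (r, in)
lora_B = V[:, :r] @ torch.diag(torch.sqrt(Sr))        # (out, r)

residual = (Wt - scaling * lora_B @ lora_A).T         # transpose back -> (in, out)

# The residual plus the low-rank adapter reproduces the original weight.
assert torch.allclose(residual.T + scaling * lora_B @ lora_A, W.T, atol=1e-4)
print(lora_A.shape, lora_B.shape, residual.shape)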

suyang160 · Sep 26 '24 13:09

Thanks for reporting this bug and providing a potential solution. Would you be interested in creating a PR with your fix?

BenjaminBossan · Sep 26 '24 14:09

Thanks! I'd be happy to submit a PR with my fix.

suyang160 · Sep 26 '24 14:09

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] · Oct 26 '24 15:10

This issue is resolved via #2104.

BenjaminBossan · Oct 28 '24 10:10