
High loss when initializing with `AdaLora`

Open • zwhe99 opened this issue 1 year ago • 3 comments

System Info

torch==2.0.1 transformers==4.40.2 flash-attn==2.3.3 peft==0.11.1 accelerate==0.28.0 deepspeed==0.14.0 bitsandbytes==0.43.0 datasets==2.18.0 trl==0.8.6

python==3.10.10

Who can help?

@BenjaminBossan

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder
  • [X] My own task or dataset (give details below)

Reproduction

run.py:

import peft
from peft import AdaLoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", verbose=False)
text = """Fine-tuning large pretrained models is often prohibitively costly due to their scale. Parameter-Efficient Fine-Tuning (PEFT) methods enable efficient adaptation of large pretrained models to various downstream applications by only fine-tuning a small number of (extra) model parameters instead of all the model's parameters. This significantly decreases the computational and storage costs. Recent state-of-the-art PEFT techniques achieve performance comparable to fully fine-tuned models.

PEFT is integrated with Transformers for easy model training and inference, Diffusers for conveniently managing different adapters, and Accelerate for distributed training and inference for really big models."""
inputs = tokenizer(text, return_tensors="pt").to("cuda:0")

# print loss for original model
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
print(loss.item())

# print loss for adalora
adalora_config = AdaLoraConfig(
    target_r=8,
    init_r=12,
    target_modules=["q_proj", "v_proj"],
    lora_alpha=8,
    task_type=peft.utils.peft_types.TaskType.CAUSAL_LM
)

model = get_peft_model(model, adalora_config, adapter_name="adalora")
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
print(loss.item())

output:

2.1954474449157715
3.3121755123138428

Expected behavior

The two losses should be the same. According to the paper, the diagonal elements of the AdaLoRA initialization (lora_E) should all be 0, so adding the adapter should not change the results.

zwhe99, May 24 '24
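The expected behavior follows from the form of the AdaLoRA update. Below is a minimal pure-Python sketch for a single linear layer. It assumes the adapter adds (lora_B @ diag(lora_E) @ lora_A @ x) · scaling / ranknum to the base output, with lora_A of shape (r, in_features), lora_E of shape (r, 1) and lora_B of shape (out_features, r), and that lora_alpha is used directly as the scaling factor. This is a reading of PEFT's AdaLoRA layer, not a copy of it; the helper names, toy dimensions and seed are hypothetical.

```python
import random


def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def adalora_delta(x, lora_A, lora_E, lora_B, scaling, ranknum):
    """Hypothetical helper: B @ diag(E) @ A @ x, scaled by scaling / ranknum.

    Assumed to mirror the AdaLoRA forward pass; the 1e-5 guards against ranknum == 0.
    """
    ax = matvec(lora_A, x)                     # project down to rank r
    eax = [e * v for e, v in zip(lora_E, ax)]  # scale each component by the diagonal E
    bx = matvec(lora_B, eax)                   # project back up to out_features
    return [v * scaling / (ranknum + 1e-5) for v in bx]


rng = random.Random(0)
in_features, out_features, r = 16, 16, 4  # toy sizes, not Llama-2-7B's
lora_A = [[rng.gauss(0.0, 0.02) for _ in range(in_features)] for _ in range(r)]
lora_B = [[rng.gauss(0.0, 0.02) for _ in range(r)] for _ in range(out_features)]
x = [rng.gauss(0.0, 1.0) for _ in range(in_features)]

# Paper's initialization: E = 0, so the update vanishes exactly.
e_zero = [0.0] * r
delta_zero = adalora_delta(x, lora_A, e_zero, lora_B, scaling=8, ranknum=r)

# Initialization in PEFT 0.4-0.11: E ~ N(0, 0.02^2), so the update is small but nonzero.
e_normal = [rng.gauss(0.0, 0.02) for _ in range(r)]
delta_normal = adalora_delta(x, lora_A, e_normal, lora_B, scaling=8, ranknum=r)

print(max(abs(v) for v in delta_zero))  # 0.0
print(max(abs(v) for v in delta_normal))
```

With lora_E at zero the update is exactly zero for any lora_A and lora_B, so the adapted layer reproduces the base layer's output; with the normal initialization the update is tiny but not zero.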

Not sure why you closed the issue, but I can replicate it.

Thanks for bringing this to our attention. Indeed, the lora_E parameter should be initialized to zeros, but it is not. I did some archaeology: this used to be the case, but it was later changed in this PR:

  • before: nn.init.zeros_(self.lora_E[adapter_name])
  • after: nn.init.normal_(self.lora_E[adapter_name], mean=0.0, std=0.02)

I strongly suspect this was done by accident, as the line was probably copy-pasted from lora_A or lora_B, where this initialization is correct. Pinging @younesbelkada in case he remembers.

It would probably be best to revert to the correct initialization, even though that changes the behavior of AdaLoRA compared to PEFT versions 0.4-0.11.

BenjaminBossan, May 24 '24

Thanks for your response. I closed this issue because the difference in losses is due to the orthogonal regularization term that AdaLoRA adds to the loss.

zwhe99, May 24 '24
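The size of the gap is consistent with that explanation. Here is a rough pure-Python estimate under several assumptions: that PEFT's AdaLoRA model adds orth_reg_weight times the mean, over all adapted lora_A and lora_B matrices, of the Frobenius norms of (A Aᵀ − I) and (Bᵀ B − I); that orth_reg_weight defaults to 0.5; and that lora_A and lora_B are drawn from N(0, 0.02²). For Llama-2-7B, q_proj and v_proj are 4096 × 4096 and the config above uses init_r=12. Since every adapted matrix has the same shape and distribution, one lora_A and one lora_B stand in for the average over all layers. The helper names are hypothetical.

```python
import math
import random


def gram(rows):
    """Return the r x r Gram matrix rows @ rows.T for a list of r vectors."""
    return [[sum(a * b for a, b in zip(u, v)) for v in rows] for u in rows]


def orth_penalty(rows):
    """Frobenius norm of (rows @ rows.T - I), i.e. distance from orthonormal rows."""
    g = gram(rows)
    n = len(g)
    return math.sqrt(sum((g[i][j] - (1.0 if i == j else 0.0)) ** 2
                         for i in range(n) for j in range(n)))


rng = random.Random(0)
r, hidden, std = 12, 4096, 0.02  # init_r, Llama-2-7B hidden size, assumed init std
orth_reg_weight = 0.5            # assumed AdaLoraConfig default

# lora_A has shape (r, hidden): the penalty is on A @ A.T.
lora_A = [[rng.gauss(0.0, std) for _ in range(hidden)] for _ in range(r)]
# lora_B has shape (hidden, r): the penalty is on B.T @ B, the Gram matrix of
# B's columns, so B is stored transposed as r rows of length hidden.
lora_B_T = [[rng.gauss(0.0, std) for _ in range(hidden)] for _ in range(r)]

mean_penalty = (orth_penalty(lora_A) + orth_penalty(lora_B_T)) / 2
extra_loss = orth_reg_weight * mean_penalty

print(round(extra_loss, 3))                                # estimated regularization term
print(round(3.3121755123138428 - 2.1954474449157715, 2))  # 1.12 (observed gap)
```

Each Gram matrix has diagonal entries near 4096 × 0.02² ≈ 1.64, so each norm is roughly √12 × 0.64 ≈ 2.2, and half of that is about 1.1, close to the observed gap between the two losses.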


I see, thanks for explaining. Let's still keep this open, as I think the line I mentioned was changed by accident and should be restored to the initialization described in the paper.

BenjaminBossan, May 27 '24

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot], Jun 23 '24