
Request to integrate sine-LoRA


Feature request

This request proposes integrating sine-LoRA, a simple yet effective drop-in method that can be applied to low-rank matrices; the paper has been accepted at ICLR 2025. The project page is available at https://samy-ji.github.io/sine_activated_PEL/, and the paper can be found at https://arxiv.org/abs/2403.19243.

Motivation

Low-rank decomposition has emerged as a vital tool for enhancing parameter efficiency in neural network architectures, gaining traction across diverse applications in machine learning. These techniques significantly lower the number of parameters, striking a balance between compactness and performance. However, a common challenge has been the trade-off between parameter efficiency and model accuracy: reduced parameter counts often lead to diminished accuracy compared to full-rank counterparts. Our method introduces a sine activation function on the low-rank matrices to boost their rank without adding extra parameters, and we provide a mathematical proof of this in the paper.

Our method is easy to implement and only needs a one-line change!

Line 379 of https://github.com/samy-ji/Sine-Low-Rank/blob/main/llm/peft/src/peft/tuners/sinlora.py demonstrates how the sine function, frequency, and scaling factor are applied in our method.
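
As a rough sketch of the one-line change and of the rank boost it buys without adding parameters (the freq and s values below are arbitrary placeholders for illustration, not the defaults from the paper):

import torch

d_out, d_in, r = 64, 64, 4
A = torch.randn(r, d_in)
B = torch.randn(d_out, r)
freq, s = 200.0, r ** 0.5  # illustrative values only

delta_lora = B @ A                          # standard low-rank update, rank <= r
delta_sine = torch.sin(freq * (B @ A)) / s  # sine-activated update, same parameter count

print(torch.linalg.matrix_rank(delta_lora))  # 4
print(torch.linalg.matrix_rank(delta_sine))  # typically much higher than r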

Below, I also show how our method can be applied to LoRA and DoRA.

## LoRA forward pass 

def forward(self, x: torch.Tensor):
    # output of the frozen base layer
    result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

    # standard LoRA update: dropout(x) @ A^T @ B^T, scaled
    result += ((self.lora_dropout(x.to(self.lora_A.weight.dtype)) @ self.lora_A.weight.T) @ self.lora_B.weight.T) * self.scaling
    return result

## Sine LoRA forward pass
def forward(self, x: torch.Tensor):
    # output of the frozen base layer
    result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    dropout_x = self.lora_dropout(x.to(self.lora_A.weight.dtype))

    # sine-LoRA update: apply sin(freq * A^T B^T) to the low-rank product, divide by s, then scale
    result += (dropout_x @ torch.sin(self.freq * self.lora_A.weight.T @ self.lora_B.weight.T)) / self.s * self.scaling
    return result

## DoRA forward pass
def forward(self, x: torch.Tensor):
    # output of the frozen base layer (bias is added at the end)
    base_result = F.linear(x, transpose(self.weight, self.fan_in_fan_out))
    dropout_x = self.lora_dropout(x)

    # DoRA: recompute the column norms of the updated weight and rescale accordingly
    new_weight_v = self.weight + (self.lora_B.weight @ self.lora_A.weight) * self.scaling
    norm_scale = self.weight_m_wdecomp.weight.view(-1) / (torch.linalg.norm(new_weight_v, dim=1)).detach()
    result = base_result + (norm_scale - 1) * (F.linear(dropout_x, transpose(self.weight, self.fan_in_fan_out)))
    result += (norm_scale * (self.lora_B(self.lora_A(dropout_x.to(self.lora_A.weight.dtype))))) * self.scaling
    if self.bias is not None:
        result += self.bias.view(1, -1).expand_as(result)
    return result

## Sine DoRA forward pass
def forward(self, x: torch.Tensor):
    # output of the frozen base layer (bias is added at the end)
    base_result = F.linear(x, transpose(self.weight, self.fan_in_fan_out))
    dropout_x = self.lora_dropout(x)

    # sine-DoRA: the sine activation is applied to the low-rank update before the DoRA renormalization
    new_weight_v = self.weight + torch.sin(self.freq * (self.lora_B.weight @ self.lora_A.weight)) / self.s * self.scaling
    norm_scale = self.weight_m_wdecomp.weight.view(-1) / (torch.linalg.norm(new_weight_v, dim=1)).detach()
    result = base_result + (norm_scale - 1) * (F.linear(dropout_x, transpose(self.weight, self.fan_in_fan_out)))
    result += (norm_scale * torch.sin(self.freq * (self.lora_B(self.lora_A(dropout_x.to(self.lora_A.weight.dtype))))) / self.s) * self.scaling
    if self.bias is not None:
        result += self.bias.view(1, -1).expand_as(result)
    return result

We evaluate our method on top of LoRA and DoRA with Llama 3 8B on commonsense reasoning benchmarks.

(Results figure: commonsense benchmark results for LoRA and DoRA with and without the sine activation, using Llama 3 8B.)

Your contribution

We have implemented sine-LoRA and sine-DoRA at https://github.com/samy-ji/Sine-Low-Rank/tree/main/llm

yipingji · Mar 18 '25

Thanks for opening the issue and suggesting adding this new method to PEFT. @githubnemo and I have already looked at sine-LoRA and we're currently discussing the best way to integrate it. It could be a standalone method, but that would mean a lot of copy-pasting of LoRA code with only minimal changes. Or it could be integrated directly into LoRA, the same way DoRA is right now, but that makes the LoRA code more complex and there can be complicated interactions between different LoRA variants. We will get back to you, hopefully very soon, once we've come to a decision.

BenjaminBossan · Mar 20 '25

Thanks for the update! Please let me know if you need any help.

Best regards, Yiping


yipingji · Mar 20 '25

Hey @samy-ji,

With https://github.com/huggingface/peft/pull/2443 merged, we can proceed to implement sine-LoRA and sine-DoRA as LoRA variants. Since DoRA is already implemented as a variant, it would lend itself to building the sine-DoRA implementation on top of that. For the sine-LoRA implementation you'd need to write a new variant, but presumably a rather simple one.

Would you be interested in taking this on?

You could use the changes in https://github.com/huggingface/peft/pull/2443 as a rough guide. Of course, we are here to answer any questions that may arise :)

githubnemo · Mar 25 '25

Hi @githubnemo,

I agree that it would be best to implement it as a LoRA variation as you described.

I'm not sure of the best way to implement this (I've had a look at #2443, but I'm not confident about how to make the changes). Could you provide some advice on the specific changes, or would it be possible for you to check the implementation? :)

yipingji · Mar 26 '25

No worries. In general, just write it the way you think is correct and submit a draft PR so we can discuss the solution together.

In general I think the solution would look something like this:

class SineLoraLinearVariant(LoraVariant):
    @staticmethod
    def init(module: Linear, adapter_name: str) -> None:
        module.freq = <somehow get the freq, we're currently discussing how>

    @staticmethod
    def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
        <the way you would implement sine lora forward>

and then you'd have to update resolve_lora_variant in the specific layer types (e.g., in class Linear of tuners/lora/layers.py) to react to a new config variable (e.g., use_sinelora).
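
For what it's worth, here is a rough sketch of what such a variant might look like, based on the pseudocode above and on the sine-LoRA forward pass posted at the top of this issue. The sinelora_freq and sinelora_s attributes, their placeholder values, the use_sinelora flag, and the import locations are assumptions for illustration only; the exact LoraVariant hooks and where freq/s come from would need to follow whatever #2443 and the config end up defining.

import torch
import torch.nn.functional as F
from peft.tuners.lora.layer import Linear, LoraVariant  # assumed import location, check against #2443


class SineLoraLinearVariant(LoraVariant):
    @staticmethod
    def init(module: Linear, adapter_name: str) -> None:
        # Placeholder values; how freq and s are configured is still being discussed.
        module.sinelora_freq = 200.0
        module.sinelora_s = module.r[adapter_name] ** 0.5

    @staticmethod
    def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]

        # sine-LoRA: apply the sine activation to the low-rank weight product,
        # divide by s, then use it like a regular (scaled) delta weight
        delta_w = torch.sin(module.sinelora_freq * (lora_B.weight @ lora_A.weight)) / module.sinelora_s
        result = result + F.linear(dropout(x.to(lora_A.weight.dtype)), delta_w) * scaling
        return result


# and, hypothetically, in the Linear layer, mirroring how use_dora is handled:
#
#     def resolve_lora_variant(self, *, use_sinelora: bool = False, **kwargs) -> Optional[LoraVariant]:
#         if not use_sinelora:
#             return None
#         return SineLoraLinearVariant()

The forward above mirrors the sine-LoRA forward pass from earlier in this thread (sine applied to the weight product, divided by s, then scaled as usual), just re-expressed against the per-adapter containers of the LoRA layer.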

githubnemo · Mar 26 '25

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] · Apr 19 '25

not stale

BenjaminBossan · Apr 22 '25