litgpt

Adding DoRA (Weight-Decomposed Low-Rank Adaptation) to improve LoRA

[Open] rasbt opened this issue 6 months ago • 18 comments

The new DoRA method (https://arxiv.org/abs/2402.09353) is a super promising improvement over LoRA.

[Screenshot 2024-02-18 at 1.41.03 PM]

If there's interest, we could add it to Lit-GPT, perhaps as an argument setting for the existing LoRA scripts.

The authors haven't released their code yet, but I have a working from-scratch implementation here: https://github.com/rasbt/dora-from-scratch

rasbt, Feb 18 '24

These names ... I guess the next will be Nora 😆.


Thanks @rasbt. It looks like it doesn't require many changes, so I think it's worth adding. The only thing: to make sure this isn't "snake oil", I would like to see plenty of benchmarks accompany the PR.


I have only one question related to the code. In your blog you said that:

> DoRA method aims to apply LoRA only to the directional component, V

while in the implementation:

```python
import torch.nn as nn
import torch.nn.functional as F

# LoRALayer comes from the from-scratch implementation linked above (not shown here)

class LinearWithDoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

        self.m = nn.Parameter(
            self.linear.weight.norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        lora = self.lora.A @ self.lora.B
        combined_weight = self.linear.weight + self.lora.alpha*lora.T
        column_norm = combined_weight.norm(p=2, dim=0, keepdim=True)
        V = combined_weight / column_norm
        new_weight = self.m * V
        return F.linear(x, new_weight, self.linear.bias)
```

you apply LoRA to the pretrained weights, then transform the result into a directional vector $V$ and apply a magnitude vector $m$ to it.

P.S. I haven't read the paper yet, so maybe I'm missing something.

Andrei-Aksionov, Feb 19 '24

Thanks for the feedback!

Regarding your question: yeah, the way they motivate DoRA and the way it's implemented is a bit misleading. In their Eq 1, they formulate the pretrained weight decomposition as follows:

$$W_0 = m \frac{V}{||V||_c} $$

And the DoRA weight update via Eq 2 is:

$$W^{\prime}=m \frac{V+BA}{||V+BA||_c}$$ (where A and B are the LoRA matrices)

So, in Eq 2, they actually set $V = W_0$ when they implement LoRA:

$$W^{\prime}=m \frac{V+\Delta V}{||V+\Delta V||_c}=m \frac{W_0+BA}{||W_0+BA||_c}$$
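A quick numerical sanity check of Eq 1 and Eq 2, as a minimal sketch assuming PyTorch. It uses the paper's $BA$ ordering, with $B$ of shape `(out_features, rank)` and $A$ of shape `(rank, in_features)` (the transpose of the `A @ B` convention in the code above), and the same column norm (`dim=0` on a weight stored as `(out_features, in_features)`). The sizes and random matrices are illustrative only.

```python
import torch

torch.manual_seed(0)

def column_norm(W):
    # ||W||_c: L2 norm of each column, shape (1, in_features)
    return W.norm(p=2, dim=0, keepdim=True)

out_features, in_features, rank = 6, 4, 2
W0 = torch.randn(out_features, in_features)   # stand-in for a pretrained weight

# Eq 1: W0 = m * V / ||V||_c with V = W0 and m = ||W0||_c
m = column_norm(W0)
W0_reconstructed = m * W0 / column_norm(W0)

# Eq 2: W' = m * (W0 + BA) / ||W0 + BA||_c for an arbitrary low-rank update BA
B = torch.randn(out_features, rank)
A = torch.randn(rank, in_features)
V_new = W0 + B @ A
W_prime = m * V_new / column_norm(V_new)

print(torch.allclose(W0_reconstructed, W0, atol=1e-5))
print(torch.allclose(column_norm(W_prime), m, atol=1e-5))
```

Both checks should hold: whatever $BA$ is, it only changes the direction of each column, while the column norms of $W'$ stay equal to $m$.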

The code should be correct with respect to Eq 2. It's perhaps easier to see if I change the variable names:

```python
import torch.nn as nn
import torch.nn.functional as F

# LoRALayer comes from the from-scratch implementation linked above (not shown here)

class LinearWithDoRAMerged(nn.Module):

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )
        self.m = nn.Parameter(
            self.linear.weight.norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        lora = self.lora.A @ self.lora.B
        numerator = self.linear.weight + self.lora.alpha*lora.T   # W_0 + BA
        denominator = numerator.norm(p=2, dim=0, keepdim=True)    # ||W_0 + BA||_c
        directional_component = numerator / denominator
        new_weight = self.m * directional_component
        return F.linear(x, new_weight, self.linear.bias)
```
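A small usage check for the renamed class, as a sketch. `LoRALayer` is not shown in this thread, so the version below is a hypothetical stand-in written to match how the class indexes `self.lora.A`, `self.lora.B`, and `self.lora.alpha` (A of shape `(in_dim, rank)`, B of shape `(rank, out_dim)` initialized to zeros). The class is repeated so the snippet runs on its own. Because B starts at zero and `m` starts at the pretrained column norms, the wrapped layer should initially reproduce the original `nn.Linear` output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer(nn.Module):
    # Hypothetical minimal LoRA layer: A is (in_dim, rank), B is (rank, out_dim) and starts at zero
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        return self.alpha * (x @ self.A @ self.B)


class LinearWithDoRAMerged(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        # Magnitude vector initialized to the column norms of the pretrained weight
        self.m = nn.Parameter(self.linear.weight.norm(p=2, dim=0, keepdim=True))

    def forward(self, x):
        lora = self.lora.A @ self.lora.B                          # (in, out)
        numerator = self.linear.weight + self.lora.alpha * lora.T  # W_0 + BA
        denominator = numerator.norm(p=2, dim=0, keepdim=True)
        directional_component = numerator / denominator
        new_weight = self.m * directional_component
        return F.linear(x, new_weight, self.linear.bias)


torch.manual_seed(123)
linear = nn.Linear(8, 5)
dora = LinearWithDoRAMerged(linear, rank=2, alpha=4)
x = torch.randn(3, 8)

# With B = 0, DoRA reproduces the pretrained layer (up to floating-point error)
print(torch.allclose(dora(x), linear(x), atol=1e-5))
```

Once training updates B, the direction of each weight column changes, while `m` is learned separately.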

> The only thing, just to make sure that this is not a "snake oil", I would like to see a lot of benchmarks to come along with the PR.

Sure, I would of course run this on a 7B model and compare the LoRA and DoRA results. So far, I have only tested it on DistilBERT with Hugging Face, where it indeed worked a tad better than LoRA, judging by the model's classification accuracy after finetuning.

rasbt, Feb 19 '24

Thanks for the explanation 🤗. Now I see why it's implemented in that way.

Cool, looks like that's settled.

Andrei-Aksionov, Feb 19 '24

Hi @rasbt, I tried finetuning DistilBERT base on the IMDB dataset using LoRA and DoRA. Using various checkpoints, I plotted the magnitude and directional differences. For LoRA, the results seem to be in line with what's reported in the paper, but I get a similar visualization for DoRA as well. Not sure if I'm missing something.

```python
import torch

# n_layers (the number of layers to compare) is assumed to be defined elsewhere

def get_decomposed_dora_weights(layer):
    # Returns [magnitude m, directional component V] of a DoRA layer
    lora_weights = layer.lora.A @ layer.lora.B
    numerator = layer.linear.weight + layer.lora.alpha*lora_weights.T
    denominator = numerator.norm(p=2, dim=0, keepdim=True)
    directional_component = numerator / denominator
    return [layer.m, directional_component]

def merged_lora_weights(layer):
    # Returns W_0 + BA for a plain LoRA layer
    lora_weights = layer.lora.A @ layer.lora.B
    combined_weights = layer.linear.weight + layer.lora.alpha*lora_weights.T
    return combined_weights

def delta_magnitude(pt_weights, ft_weights):
    # Mean absolute difference between the magnitude vectors, per layer
    layer_wise_delta_m = {}
    for i in range(1, n_layers):
        a = ft_weights.get(f"query_layer_{i}")[0]
        b = pt_weights.get(f"query_layer_{i}")[0]
        k = b.shape[1]
        d_m = torch.sum(abs(a - b)) / k
        layer_wise_delta_m[f"layer_{i}"] = round(d_m.item(), 6)
    return layer_wise_delta_m

def delta_direction(pt_weights, ft_weights):
    # Mean (1 - cosine similarity) between matching columns, per layer
    layer_wise_delta_d = {}
    for i in range(1, n_layers):
        a = ft_weights.get(f"query_layer_{i}")[1]
        b = pt_weights.get(f"query_layer_{i}")[1]
        k = b.shape[1]  # number of columns compared
        sim = torch.nn.functional.cosine_similarity(a.t(), b.t(), dim=1)
        d_m = torch.sum(1 - sim) / k
        layer_wise_delta_d[f"layer_{i}"] = round(d_m.item(), 6)
    return layer_wise_delta_d
```
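Written out, `delta_magnitude` and `delta_direction` compute the following per-layer quantities (my reading of the code, which appears to follow the magnitude and direction metrics described in the DoRA paper). Here $m_{\text{ft}}$ and $m_{\text{pt}}$ are the magnitude vectors at a finetuned and the pretrained checkpoint, $V_{\text{ft}}^{(n)}$ and $V_{\text{pt}}^{(n)}$ are the $n$-th columns of the corresponding directional components, and $k$ is the number of columns:

$$\Delta M = \frac{1}{k}\sum_{n=1}^{k}\left|m_{\text{ft}}^{(n)} - m_{\text{pt}}^{(n)}\right|, \qquad \Delta D = \frac{1}{k}\sum_{n=1}^{k}\left(1 - \cos\left(V_{\text{ft}}^{(n)}, V_{\text{pt}}^{(n)}\right)\right)$$

Since cosine similarity is scale-invariant, comparing against the normalized pretrained direction gives the same $\Delta D$ as comparing against the raw pretrained columns.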

[Plots: scatter_plot, update_plot_lora]

shreyassks, Feb 26 '24

Whoa, thanks for doing this! I agree, this is unexpected. Perhaps it's an artifact of the relatively small DistilBERT model, but something else could also be going on. I'll keep that in mind when I come back to this PR and the Lit-GPT integration.

Also, considering how LoRA is currently implemented in Lit-GPT, the following alternative to the LinearWithDoRAMerged class I posted earlier may be preferable:

```python
import torch
import torch.nn as nn

# LoRALayer comes from the from-scratch implementation (not shown here)

class LinearWithLoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        # Frozen linear output plus the low-rank update branch
        return self.linear(x) + self.lora(x)


class LinearWithDoRA(nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)
        self.m = nn.Parameter(torch.ones(1, linear.out_features))

    def forward(self, x):
        linear_output = self.linear(x)
        lora_output = self.lora(x)
        # Normalize the LoRA branch output, then rescale it with the learned magnitude m
        lora_output_norm = lora_output / (lora_output.norm(p=2, dim=1, keepdim=True) + 1e-9)
        dora_modification = self.m * lora_output_norm
        return linear_output + dora_modification
```
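The unmerged `LinearWithLoRA.forward` above (`self.linear(x) + self.lora(x)`) computes the same result as a single linear layer whose weight has the update folded in, $W + \alpha (AB)^\top$. A quick check, as a sketch with plain tensors standing in for the module attributes; the shapes of `A` `(in, rank)` and `B` `(rank, out)` are assumed to match the `LoRALayer` convention used in this thread.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features, rank, alpha = 8, 5, 2, 4.0

W = torch.randn(out_features, in_features)   # pretrained weight, (out, in)
b = torch.randn(out_features)                # pretrained bias
A = torch.randn(in_features, rank)           # LoRA A, (in, rank), assumed convention
B = torch.randn(rank, out_features)          # LoRA B, (rank, out); nonzero so the check is meaningful
x = torch.randn(3, in_features)

# Unmerged path, as in LinearWithLoRA: linear output plus the scaled low-rank branch
unmerged = F.linear(x, W, b) + alpha * (x @ A @ B)

# Merged path: fold the low-rank update into the weight once
merged = F.linear(x, W + alpha * (A @ B).T, b)

print(torch.allclose(unmerged, merged, atol=1e-4))
```

Keeping the branch separate leaves the pretrained weight untouched during training; the update can still be merged into `W` afterwards for inference.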

rasbt, Feb 26 '24

Hi, thanks for providing the updated code. `dora_output` doesn't seem to include `dora_modification`. Am I missing something here?

shreyassks, Feb 26 '24

Ah yes, that was incorrect. I updated it and think it should be correct now (I need to think about it a bit more before actually implementing it).

rasbt, Feb 26 '24

Sure. I'll try to incorporate these changes and see if it comes out as expected.

shreyassks, Feb 26 '24

Here is the updated visualization for DoRA. Now it looks as expected. One more observation, though: for any given layer, all checkpoints have very similar magnitude and directional differences (the points are grouped by layer). In the paper, by contrast, all layers of a given checkpoint are grouped together.

[Plot: scatter_plot (1)]

shreyassks, Feb 27 '24

Thanks for the update. I am currently wrapping up another project but hope to get back to this soon and integrate it into Lit-GPT for experiments on other LLMs like Llama and Gemma. Btw, is the plot above for the updated code? I agree that the slope looks better, but like you said, the grouping seems weird. I need to carefully double-check that I didn't have a typo in that code.

rasbt, Feb 27 '24

Yes, it's for the updated code. I'll run a few more experiments with the norm computed at dim=0 as well. Let me know if there's any update on your end. I'll be waiting eagerly!
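As a reminder of what the `dim` argument changes when the norm is taken over a weight matrix (assuming that is the norm meant here): PyTorch stores an `nn.Linear` weight as `(out_features, in_features)`, so `dim=0` gives one norm per column (per input feature), which is what `LinearWithDoRAMerged` uses, while `dim=1` gives one norm per row (per output feature). A minimal sketch with an arbitrary layer size:

```python
import torch
import torch.nn as nn

layer = nn.Linear(in_features=4, out_features=3)
W = layer.weight                               # shape (3, 4) = (out_features, in_features)

col_norms = W.norm(p=2, dim=0, keepdim=True)   # (1, 4): one norm per column / input feature
row_norms = W.norm(p=2, dim=1, keepdim=True)   # (3, 1): one norm per row / output feature

print(col_norms.shape, row_norms.shape)        # torch.Size([1, 4]) torch.Size([3, 1])
```

Dividing `W` by `col_norms` makes every column unit-length; dividing by `row_norms` makes every row unit-length, so the two choices define different "directions" and different magnitude vectors.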

shreyassks, Feb 27 '24

Hi @rasbt, great work! I was wondering if there is a way to experiment with your DoRA implementation in LitGPT without having to go through the entire LitGPT codebase and make extensive configuration changes. Is it possible to import it externally and try it out? I know the pull request is going to be merged, but I'm asking because I also want to try other LoRA variants that aren't being added to LitGPT but have standalone implementations, like the ones on your GitHub. What would be a quick workaround to make this work with Llama- or Mistral-based models?
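One generic PyTorch workaround, sketched here without relying on any LitGPT API: walk the model with `named_modules()` and swap selected `nn.Linear` layers for a wrapper such as `LinearWithDoRAMerged`. The helper `replace_linear`, its `name_filter` argument, and the `ScaledLinearWrapper` stand-in are all hypothetical; with a real model you would filter on whatever layer names that model actually uses.

```python
import torch
import torch.nn as nn


class ScaledLinearWrapper(nn.Module):
    # Hypothetical stand-in for a wrapper like LinearWithDoRAMerged
    def __init__(self, linear):
        super().__init__()
        self.linear = linear
        self.scale = nn.Parameter(torch.ones(linear.out_features))

    def forward(self, x):
        return self.linear(x) * self.scale


def replace_linear(model, wrapper_fn, name_filter=lambda name: True):
    # Collect targets first so the module tree is not modified while iterating over it
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and name_filter(name)
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, wrapper_fn(module))
    return model


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for p in model.parameters():
    p.requires_grad = False          # freeze the "pretrained" weights first
model = replace_linear(model, ScaledLinearWrapper, name_filter=lambda n: n == "0")

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)                     # only the wrapper's new parameters are trainable
```

Freezing before wrapping means only the parameters the wrapper adds stay trainable, which mirrors how LoRA-style finetuning is usually set up.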

monk1337, Mar 05 '24

Thanks for your interest in this. I think it would require a few adjustments to use it. We would have to make those carefully to ensure everything works correctly. I hope to find some time for that in the next few weeks.

rasbt, Mar 06 '24

Just curious: will DoRA be integrated into lit-gpt?

cmhungsteve, Mar 26 '24

> will DoRA be integrated into lit-gpt?

Sorry for the late response. I currently have some other work projects I have to get to first, but yeah, ultimately, I am hoping to add it to LitGPT!

We recently did a big API overhaul of LitGPT that we wanted to get done first so that LitGPT can be used as a CLI, e.g.

```bash
litgpt finetune lora --config config.yaml --other_arg xyz
```

(See the new Zero to LitGPT guide for more examples)

It could then be

```bash
litgpt finetune dora
```

I don't have a timeline yet, but I'm hoping to add it sometime in the upcoming weeks.

rasbt, Apr 02 '24

@rasbt I was able to reproduce the magnitude and directional update visualizations for DoRA and LoRA using Llama 2 7B and recently wrote a blog post about it: https://shreyassk.substack.com/p/visualising-dora-weight-decomposed

shreyassks, Apr 03 '24

This looks awesome! Thanks for sharing. I just saved it and sent it to my e-reader, and I plan to read it carefully in the coming days!

rasbt, Apr 04 '24

> @rasbt I was able to reproduce the magnitude and directional update visualizations for DoRA and LoRA using Llama 2 7B, wrote a blog post recently on that. Link attached below. https://shreyassk.substack.com/p/visualising-dora-weight-decomposed

@shreyassks: Won't the negative correlation for the query (Q) and value (V) matrices seen in your DoRA visualization (per the link above) affect the causal behavior of the masked self-attention mechanism? Both Q and V are used in the decoder's attention equation, $\mathrm{Dropout}\big(\mathrm{Softmax}(QK^\top)\,V\big)$, with the causal mask applied to $QK^\top$. Did you look into this?
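For reference, the masked (causal) self-attention the question refers to can be sketched as below. This is an illustration of the standard scaled dot-product form with an upper-triangular mask, not LitGPT's implementation; dropout is omitted (as at inference) and the $1/\sqrt{d}$ scaling is the usual convention rather than something stated in the question. In this form, causality comes from the mask on the scores, which does not depend on how the projection weights were finetuned.

```python
import math
import torch


def causal_self_attention(Q, K, V):
    # Q, K, V: (seq_len, d). Returns outputs (seq_len, d) and attention weights (seq_len, seq_len).
    seq_len, d = Q.shape
    scores = Q @ K.T / math.sqrt(d)
    # Mask out future positions (j > i) so each token attends only to itself and earlier tokens
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ V, weights


torch.manual_seed(0)
Q, K, V = (torch.randn(4, 8) for _ in range(3))
out, weights = causal_self_attention(Q, K, V)
print(weights[0])  # the first token can only attend to itself
```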

akramIOT, Apr 09 '24