[WIP] Update `LoraConfig` for KaSA implementation
cc @BenjaminBossan
I was delayed in updating the code because I was focusing on company work, but now I'm planning to resume the project in earnest. If I have any questions about implementing the code, may I continue to ask you?
I apologize for opening a new pull request, as the previous one was closed 🥲 Thank you for your understanding.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
gentle ping @nsbg
Thank you for the reminder!
I spent some time looking over the KaSA paper and code to get ready for more serious work, but it does seem pretty difficult 🥲 My goal is to upload code that's ready for review before the end of September, so I'm going to try even harder.
Right now, I'm stuck at the 'Extend LoRA variant resolution' stage you mentioned. Honestly, this seems like the most important part, but it's hard for me to figure out where to start—specifically, which file and class I should work on first. Could you help me with this?
That's great to see, thanks for picking this back up.
> Right now, I'm stuck at the 'Extend LoRA variant resolution' stage you mentioned. Honestly, this seems like the most important part, but it's hard for me to figure out where to start—specifically, which file and class I should work on first. Could you help me with this?
You're already on the right track: you added `KasaLinearVariant`, which is the most important step. There are definitely some changes required there, as there is some code that is only relevant for DoRA and can be removed for KaSA. But we can leave that as is for now.
Next, about resolving the variants. As a first step, let's revert the changes you made to `lora/layer.py` and start fresh. We don't need a `self.use_kasa` attribute; we only have `self.use_dora` for backwards compatibility, as we didn't have LoRA variants when we first implemented DoRA.
Then let's look at these lines in lora.Linear:
https://github.com/huggingface/peft/blob/a3197b1ec54baa107e2efc4a30c19209d9ea489a/src/peft/tuners/lora/layer.py#L636-L642
Here we need to extend the functionality to add KaSA. The updated method could be something like:
```python
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool, **kwargs) -> Optional[LoraVariant]:
    if use_dora and use_kasa:
        raise ValueError("Cannot use DoRA and KaSA at the same time, please choose only one.")

    variant = None
    if use_dora:
        from .variants import DoraLinearVariant

        variant = DoraLinearVariant()
    elif use_kasa:
        ...
    return variant
```
Does that make sense? Similarly, we'd have to update the resolve_lora_variant methods of other LoRA layers, depending on whether they work with KaSA or not (I'm not sure if KaSA works with Conv2d etc.).
I would suggest that you work on this as a next step, then we'll see what else needs to be done.
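For illustration, the `elif use_kasa:` branch could mirror the DoRA branch; this is just a sketch, assuming the variant class is called `KasaLinearVariant` and lives in `.variants`:

```python
    elif use_kasa:
        # Sketch: mirror the DoRA branch, assuming KasaLinearVariant is defined
        # in .variants next to DoraLinearVariant.
        from .variants import KasaLinearVariant

        variant = KasaLinearVariant()
```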
Wow, I really appreciate your sincere feedback. I'll read your advice carefully and then move forward 🤗
@BenjaminBossan I modified the code in the files below based on what you explained. Please give me feedback if there are parts that still need fixing, and then we can discuss the next steps.
1. `variants.py`
   - Completed updates to the methods in the `KasaLinearVariants` class
2. `layer.py`
   - In the `LoraLayer` class, added `self.use_kasa[adapter_name] = use_kasa` inside the `update_layer` method
   - In the `Linear` class, added KaSA handling logic inside the `get_delta_weight` method
@BenjaminBossan oh I didn't mean to close the branch, but it seems to have closed while I was merging with the main branch. I guess I'll have to open a new PR, right? 😰
+) when I tried to sync with the main branch, I ended up discarding all my commits, so did that cause it to close?
> oh I didn't mean to close the branch, but it seems to have closed while I was merging with the main branch. I guess I'll have to open a new PR, right? 😰
>
> +) when I tried to sync with the main branch, I ended up discarding all my commits, so did that cause it to close?
I don't know what happened, but I could re-open the PR and there are some changes visible. Can you double check that everything looks as expected? If for some reason it's not what's expected, you can create a new PR and push your local branch.
I usually handle merges in the terminal, and I suspect the pull request was closed because I accidentally wiped the commit history while using the 'Sync fork' feature on GitHub. I'll be more careful in the future. Thanks for reopening it.
I'll review the changes and open a new PR if needed. Sorry to keep bothering you with this.
> I'll review the changes and open a new PR if needed. Sorry to keep bothering you with this.
No worries. If the diff on this PR looks good, let me know and I'll do a review. Only open a new PR if, for some reason, the code here does not correspond to what it should be.
@BenjaminBossan I checked `layer.py`/`variants.py` and found that the `KasaLinearVariants` class in `variants.py` had been removed. I added it back and updated the file based on your minor feedback, so I think we can continue the discussion in this PR.
BTW, I ran the `make style` command and got this error:
```
make style
ruff check --fix src tests examples docs scripts docker
process_begin: CreateProcess(NULL, ruff check --fix src tests examples docs scripts docker, ...) failed.
```
I ran the `pip install -e .[test]` command from https://huggingface.co/docs/peft/install#source, but I got the same error. Do I just run that command directly, without needing to set up a virtual environment?
Maybe the `make style` related error was fixed. After running the command, quite a few files changed. Is it okay to just push them? Also, what exactly does `make style` do?
> Maybe the `make style` related error was fixed. After running the command, quite a few files changed. Is it okay to just push them? Also, what exactly does `make style` do?
No, let's not push any changes to unrelated files. If `make style` changes unrelated files, it's often due to one of these reasons:
- Wrong `ruff` version: check that v0.12.12 is installed in your virtual environment.
- Not picking up the config: there are some settings for ruff in the `pyproject.toml` file, ensure that it's there when you run `make style`.
@BenjaminBossan
> Also, we should update the `resolve_lora_variant` methods of the other layer types like `lora.Embedding.resolve_lora_variant` to accept the `use_kasa` argument but raise an error if it's True. Otherwise, users may add it to non-supported layers and not notice that it doesn't actually do anything there.
I referred to your explanation and added the `use_kasa` parameter to the `resolve_lora_variant` method in the classes below.
- Linear (line 702)
- Embedding (line 934)
- _ConvNd (line 1263)
- Conv2d (line 1507)
- Conv1d (line 1524)
- Conv3d (line 1541)
The logic for raising errors in each layer hasn't been applied yet, but I committed first to check whether adding the parameter in this way matches what you meant. Excluding the `Linear` class, it seems like an error should be raised when `use_kasa` is true in the other classes. However, I might be mistaken, so please feel free to give me feedback anytime. Also, I noticed there's no part that calls `KasaLinearVariant`. Should this be called inside the `Linear` class? I'm a bit confused about this part.
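For the non-supported layers, something like this sketch is roughly what I have in mind (using `Embedding` as the example; the rest of the existing method body in `layer.py` would stay as it is):

```python
def resolve_lora_variant(self, *, use_dora: bool, use_kasa: bool = False, **kwargs) -> Optional[LoraVariant]:
    # Sketch: KaSA is only wired up for Linear, so fail loudly here instead of
    # silently ignoring the flag on an unsupported layer type.
    if use_kasa:
        raise ValueError("KaSA is not supported for lora.Embedding layers.")
    # ... keep the existing DoRA handling of this method unchanged ...
```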
@BenjaminBossan
I removed unnecessary parts and added the `_get_delta_weight` static method specifically for the `KasaLinearVariants` class. If adding `_get_delta_weight` in this way is correct, I plan to update all occurrences of `module.get_delta_weight` within the current `KasaLinearVariants` class to use the modified method. Please review and share your thoughts.
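Conceptually, this delta weight should match the forward computation quoted further down in this thread, i.e. `(B @ diag @ A) * scaling`. A rough sketch (not necessarily the code in the PR; it assumes `torch` is imported and reuses the `lora_diag`/`scaling` attribute names from that forward snippet):

```python
@staticmethod
def _get_delta_weight(module, active_adapter: str) -> torch.Tensor:
    # Sketch: merged KaSA delta corresponding to lora_B(lora_A(x) @ diag) * scaling
    lora_A = module.lora_A[active_adapter].weight        # (r, in_features)
    lora_B = module.lora_B[active_adapter].weight        # (out_features, r)
    diag = torch.diag(module.lora_diag[active_adapter])  # (r, r)
    scaling = module.scaling[active_adapter]
    return (lora_B @ diag @ lora_A) * scaling
```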
I've made the changes you commented on and added the following to the test file. I used the name 'Vanilla MLP LoRA with KaSA' for the lines added to the test file, which is different from what you suggested. Please let me know if this wasn't your intention.
Also, I have a question: in the `forward` method,
```python
def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    lora_A = module.lora_A[active_adapter]
    lora_B = module.lora_B[active_adapter]
    dropout = module.lora_dropout[active_adapter]
    scaling = module.scaling[active_adapter]
    diag = torch.diag(module.lora_diag[active_adapter])

    # KASA calculation
    # see https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
    lora_output = lora_B(torch.einsum("ijk,kl->ijl", lora_A(x), diag)) * scaling

    return result + lora_output
```
the `dropout` variable isn't being used. Should I remove it or just leave it as is?
> the `dropout` variable isn't being used. Should I remove it or just leave it as is?
Good catch. I think the correct way is to apply it to x, like in the original code: https://github.com/juyongjiang/KaSA/blob/f85e88c22d0fa4cb8ab2923d7c2bf1bbec152da3/peft/src/peft/tuners/lora/layer.py#L602C21-L602C110
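In the snippet above, that would be a one-line change (sketch):

```python
# Apply the adapter dropout to the input before the KaSA computation,
# as in the reference implementation linked above.
lora_output = lora_B(torch.einsum("ijk,kl->ijl", lora_A(dropout(x)), diag)) * scaling
```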
I've added up to this point. Please share your thoughts if there are any further modifications needed in the work so far, and we can discuss the next steps for the testing phase.
I have a question after local testing. I'll briefly share the test results and process.
1. 10 failures occurred
   - 8 dtype mismatch issues, 2 `AssertionError`s
2. dtype issues (all resolved ✅)
   - When testing cases where the model's data type changes, like in the `test_forward_bfloat16` method, dtype mismatch issues occurred.
   - This was resolved by adding a part to match the dtype in a method within the `KasaLinearVariants` class (a sketch of this kind of dtype cast is at the end of this comment).
3. `AssertionError` issues
   - An `assert False` error occurred in the `test_disable_adapters` method.
   - So, I added the following conditional statement to the `forward` method so that `outputs_before == outputs_disabled` holds true in `test_disable_adapters`, by skipping the KaSA calculation and returning only the base model output when adapters are disabled:

     ```python
     # Check if adapters are disabled
     if module.disable_adapters:
         return result
     ```

   - However, the `assert False` is still occurring.
But now that I think about it, looking at the current `KasaLinearVariants` class, it applies SVD to the base weights and does that calculation in its `init` method. So I'm wondering whether the `assert False` is actually expected behavior, or whether I just tried to solve it in the wrong way.
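For reference, the dtype fix from item 2 above was about matching dtypes inside the variant; one way to do that in the `forward` method is a cast like the sketch below (this is just an illustration, not necessarily the exact change I committed):

```python
# Sketch: cast the KaSA output back to the dtype of the base result so that
# bfloat16/float16 base models don't hit dtype mismatches (e.g. in test_forward_bfloat16).
lora_output = lora_B(torch.einsum("ijk,kl->ijl", lora_A(dropout(x)), diag)) * scaling
return result + lora_output.to(result.dtype)
```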
Sorry for the delayed update; I've been working on this in bits and pieces due to personal scheduling. Since some time has passed, I'll summarize the feedback you previously gave me and my related work:
1. `_skip_test_disable_adapters` function
   - Added skip logic for test functions that do not work with KaSA adapters (you can check it in `tests/test_custom_models.py`).
2. `TestKasaInitialization` class
   - `test_kasa_mixed_adapter_error` method: when existing adapters are not KaSA, adding a new KaSA adapter should raise a `ValueError`.
   - `test_kasa_mixed_adapter_error_reverse` method: when existing adapters are KaSA, adding a non-KaSA adapter should raise a `ValueError`.

   Both of these test functions passed because they successfully raised a `ValueError`. My understanding is that "PASSED" here means the `ValueError` was raised, which is the intended behavior for these functions. Please correct me if my understanding is wrong. (A sketch of how such a test is typically written is at the end of this comment.)
3. Logic to avoid repetitive SVD processes (Just an idea 💡)
I agree with your opinion that repeatedly modifying the base weights for the same KaSA adapter is not ideal. My current idea is to apply a cache, as shown below, and check if SVD has already been applied by checking the existence of this attribute. What do you think of this approach?
```python
if not hasattr(module, "_kasa_svd_cache"):
    # First KaSA adapter: perform SVD and cache the result
    weight = module.get_base_layer().weight
    dtype = weight.dtype
    weight = weight.to(torch.float32)
    U, S, Vh = torch.linalg.svd(weight.data, full_matrices=False)
    module._kasa_svd_cache = (U, S, Vh)
else:
    # Reuse cached SVD result
    U, S, Vh = module._kasa_svd_cache
    dtype = module.get_base_layer().weight.dtype
```
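For reference on the PASSED question in item 2: if those tests use `pytest.raises(ValueError)`, then a PASSED result does mean the `ValueError` was raised. A hypothetical sketch of such a test (the `base_model` fixture, the `use_kasa` config flag, and the adapter name are made up for illustration):

```python
import pytest
from peft import LoraConfig, get_peft_model


def test_kasa_mixed_adapter_error(base_model):
    # Hypothetical sketch: the first adapter is plain LoRA, the second uses KaSA.
    # With pytest.raises, this test only passes if the ValueError is actually raised.
    model = get_peft_model(base_model, LoraConfig(r=8))
    with pytest.raises(ValueError):
        model.add_adapter("kasa", LoraConfig(r=8, use_kasa=True))
```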