Implement adaption prompt from Llama-Adapter paper
Issue: #235
Paper: https://arxiv.org/abs/2303.16199
Official implementation: https://github.com/ZrrSkywalker/LLaMA-Adapter
This implementation is quite specific to Llama, although in principle the same technique can be used for other transformer models. It works by replacing some of the `LlamaAttention` modules with `AdaptedAttention` modules that wrap the original ones but bypass the original `forward()`, so that we can inject trainable tokens and a gate. I can add a notebook example in a subsequent PR when I have the time.
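For illustration only, here is a rough sketch of the wrapping idea. It is not the code in this PR: the class and attribute names (`q_proj`, `k_proj`, `v_proj`, the tuple return of the wrapped module) are assumptions about the attention interface, and rotary embeddings, the output projection, and KV caching are all omitted.

```python
import torch
import torch.nn as nn


class AdaptedAttentionSketch(nn.Module):
    """Illustrative wrapper: keeps the frozen attention module and adds a
    learnable adaption prompt plus a zero-initialized gate, so the wrapped
    model starts out identical to the unmodified base model."""

    def __init__(self, base_attn: nn.Module, adapter_len: int, hidden_size: int, num_heads: int):
        super().__init__()
        self.base_attn = base_attn  # original (frozen) attention module
        self.adapter_len = adapter_len
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Trainable prompt tokens that act as extra keys/values.
        self.adaption_prompt = nn.Parameter(torch.randn(1, adapter_len, hidden_size) * 0.02)
        # Zero-initialized gate: the adapter contributes nothing at step 0.
        self.adaption_gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states, **kwargs):
        # Run the original attention untouched.
        base_out, *rest = self.base_attn(hidden_states, **kwargs)
        bsz, q_len, _ = hidden_states.shape

        # Project the prompt with the base module's own k/v projections (assumed names).
        k = self.base_attn.k_proj(self.adaption_prompt)
        v = self.base_attn.v_proj(self.adaption_prompt)
        k = k.view(1, self.adapter_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(1, self.adapter_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Recompute queries from the input (rotary embeddings omitted here;
        # see the discussion about query_states further down the thread).
        q = self.base_attn.q_proj(hidden_states)
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(2, 3)) / (self.head_dim ** 0.5)
        # Gate the adapter attention so it can be learned from zero.
        scores = self.adaption_gate * scores.softmax(dim=-1)
        adapter_out = torch.matmul(scores, v.expand(bsz, -1, -1, -1))
        adapter_out = adapter_out.transpose(1, 2).reshape(bsz, q_len, -1)

        # Add the gated adapter output to the base attention output.
        return (base_out + adapter_out, *rest)
```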
Fixed `black` formatting and merge conflicts.
CI test failures were caused by Llama dependencies not being available in earlier versions of the `transformers` library.
To ensure backwards compatibility, I've done the following:
- Added import guards for the Llama classes from `transformers`. This allows `peft` to be used with earlier versions of `transformers` without breaking (see the sketch after this list).
- Skip unit tests that require those classes if the classes are not found. CI tests should pass.
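As an illustration of the guard-and-skip pattern (not the exact code in this PR; the flag name and test class are made up, and the import path assumes a `transformers` release that ships Llama):

```python
import unittest

# Guarded import: older transformers releases do not ship the Llama classes.
try:
    from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel

    is_llama_available = True
except ImportError:
    LlamaAttention, LlamaModel = None, None
    is_llama_available = False


@unittest.skipUnless(is_llama_available, "requires a transformers version that includes Llama")
class AdaptionPromptTest(unittest.TestCase):
    def test_placeholder(self):
        # Real tests would build a small Llama config and check that the adapted
        # model's only trainable parameters are the prompts and gates.
        self.assertTrue(is_llama_available)
```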
Rebased on main and added support for multiple adapters.
Thanks @pacman100 for the review! Regarding your comment about not using `_modified_forward()`:

> To make the implementation generic, having `attention_target_modules` as another config which is a `List[nn.Module]`. The `AdaptedAttention` module's new forward should call the normal attention layer's forward without any changes (so no modified forward). Then use the `attention_target_modules` to get the `query_states`, `adapter_k` and `adapter_v`. Here, if there is only a single fused qkv as in GPT-2/Bloom, just use `outputs.split(3, dim=2)` and do the necessary assignments.
`query_states` is not actually returned by the original `forward()` method; it is a local variable within that function. So if we want to reuse the original `forward()` without modification, I think we will have to recompute `query_states`, which incurs a performance penalty, and to do that we will still need to import the Llama-specific `apply_rotary_pos_emb()` to apply to the `query_states`. So the implementation for Llama will still be Llama-specific. From my understanding, we have the following two options:
- Refactor `LlamaAttention` in `transformers` to return `query_states`. This would allow us to keep this library generic, but we might have to refactor other models in `transformers` to conform to our interface too.
- Have a separate code path for Llama and other models. Two options:
  a. `_modified_forward()`, which completely bypasses the original `forward()`.
  b. Recompute `query_states`, which will require importing the Llama-specific `apply_rotary_pos_emb()` (less duplication of code but with the perf penalty of recomputation; a rough sketch follows below).
How do you think we should proceed?
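To make option (b) concrete, a sketch of the recomputation might look like the following. This is an illustration only: it assumes the `LlamaAttention` attribute names (`q_proj`, `num_heads`, `head_dim`, `rotary_emb`) and the `apply_rotary_pos_emb()` signature of the `transformers` versions current around the time of this PR, not the code actually merged.

```python
import torch
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def compute_query_states(attn, hidden_states, position_ids):
    """Re-derive the query_states that the original forward() computed internally."""
    bsz, q_len, _ = hidden_states.shape
    query_states = (
        attn.q_proj(hidden_states)
        .view(bsz, q_len, attn.num_heads, attn.head_dim)
        .transpose(1, 2)
    )
    # The Llama-specific part: rotary position embeddings must be re-applied.
    value_placeholder = torch.zeros_like(query_states)  # only dtype/device are used
    cos, sin = attn.rotary_emb(value_placeholder, seq_len=q_len)
    query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin, position_ids)
    return query_states
```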
> Recompute `query_states`, which will require importing the Llama-specific `apply_rotary_pos_emb()` (less duplication of code but with the perf penalty of recomputation).
I am leaning towards this, as it would stay generic and the recomputation penalty would be minimal. There could be a model-to-postprocessing-function mapping, used for things like `apply_rotary_pos_emb`. Refer to this code block for one such example:
https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py#L246-L249
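Such a mapping might look roughly like this. It is a sketch: the dictionary name, the function names, and the idea of dispatching on `config.model_type` are assumptions for illustration, not code from this PR or from the linked file.

```python
# Hypothetical dispatch table from model_type to a query post-processing function.
def postprocess_llama(attn, query_states, position_ids):
    # For Llama, rotary position embeddings would be re-applied here
    # (e.g. along the lines of the compute_query_states sketch above).
    ...


def postprocess_identity(attn, query_states, position_ids):
    # Models that do not use rotary embeddings need no extra work.
    return query_states


MODEL_TYPE_TO_QUERY_POSTPROCESSING = {
    "llama": postprocess_llama,
    "gpt2": postprocess_identity,
    "bloom": postprocess_identity,
}

# Usage: look up the function from the model config, e.g.
# fn = MODEL_TYPE_TO_QUERY_POSTPROCESSING[model.config.model_type]
```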
@pacman100 I've made the changes you suggested and rebased on main.
@yeoedward have you been able to validate the implementation against the paper's results? I gave your branch a quick whirl with the same hyperparameters as the paper, but it definitely doesn't seem to converge as quickly.
@winglian I have validated the implementation on smaller problems that run on a CPU but haven't tried reproducing the paper's results at full problem size. When you say it doesn't converge as quickly, are you referring to number of epochs trained? Happy to look into it if you can provide a notebook / script, although I don't have access to 8 x A100s.
@pacman100 I've made the suggested changes and rebased on main.
Could you please run `make style` and `make quality` to fix the quality issues?
@pacman100 I couldn't reproduce the style error (which was in a file unrelated to this PR), but I rebased on `main` and fixed some conflicts, which was perhaps why the style checker failed. I've re-run the unit tests and quality checks, so hopefully this is good to go. Thanks for reviewing!
Thank you @yeoedward! 🤗
It took me a bit of time to find that the paper has been implemented, given that it is present in the list of supported methods. However, I haven't seen any documentation on the Tuners page here. If I am correct, it has been implemented in this file. The method explained in the paper is model-agnostic, so will it support Falcon instead of Llama? There has been a recent implementation of this method which can be used for models other than Llama. You can check it here.