
Implement adaption prompt from Llama-Adapter paper

yeoedward opened this pull request 1 year ago • 12 comments

Issue: #235
Paper: https://arxiv.org/abs/2303.16199
Official implementation: https://github.com/ZrrSkywalker/LLaMA-Adapter

This implementation is quite specific to Llama, although in principle the same technique can be applied to other transformer models. It works by replacing some of the LlamaAttention modules with AdaptedAttention modules that wrap the originals but bypass the original forward() so that trainable prompt tokens and a gate can be injected. I can add a notebook example in a subsequent PR when I have the time.
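Roughly, the wrapper pattern looks like the sketch below (class and attribute names are simplified and the adapter attention itself is elided; this is not the exact PR code):

```python
import torch
import torch.nn as nn


class AdaptedAttention(nn.Module):
    """Wraps an existing attention module and adds trainable adaption prompts."""

    def __init__(self, model: nn.Module, adapter_len: int, hidden_size: int):
        super().__init__()
        # Keep a reference to the original (frozen) attention module.
        self.model = model
        # Trainable prompt tokens injected into the attention computation.
        self.adaption_prompt = nn.Parameter(torch.empty(1, adapter_len, hidden_size).normal_())
        # Zero-initialized gate so training starts from the frozen model's behaviour.
        self.adaption_gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states, **kwargs):
        # Delegate to the wrapped attention module (assumed to return a tuple,
        # as LlamaAttention does).
        output, *rest = self.model(hidden_states, **kwargs)
        # A full implementation also attends over `adaption_prompt` here and adds
        # the result to `output`, scaled by `adaption_gate`, e.g.:
        #   output = output + self.adaption_gate * adapter_attention_output
        return (output, *rest)
```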

yeoedward avatar Apr 05 '23 21:04 yeoedward

The documentation is not available anymore as the PR was closed or merged.

Fixed black formatting and resolved merge conflicts.

yeoedward avatar Apr 06 '23 07:04 yeoedward

CI test failures were caused by Llama dependencies not being available in earlier versions of the transformers library. To ensure backwards compatibility, I've done the following:

  1. Added import guards for the Llama classes from transformers (sketched below). This allows peft to be used with earlier versions of transformers without breaking.
  2. Skip unit tests that require those classes if the classes are not found, so the CI tests should pass.
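For illustration, the guards and skips look roughly like this (names are illustrative, not the exact PR code):

```python
import unittest

# Guarded import: older transformers releases do not ship the Llama classes.
try:
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

    is_llama_available = True
except ImportError:
    is_llama_available = False


@unittest.skipIf(not is_llama_available, "Llama is not available in this version of transformers")
class AdaptionPromptTest(unittest.TestCase):
    def test_llama_classes_importable(self):
        # Placeholder; the real tests exercise the adaption prompt tuner end to end.
        self.assertTrue(is_llama_available)
```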

yeoedward avatar Apr 06 '23 12:04 yeoedward

Rebased on main and added support for multiple adapters.

yeoedward avatar Apr 07 '23 17:04 yeoedward

Thanks @pacman100 for the review! Regarding your comment about not using _modified_forward():

  1. To make the implementation generic, add attention_target_modules as another config option, which is a List[nn.Module]. The AdaptedAttention module's new forward should call the normal attention layer's forward without any changes (so no modified_forward), then use attention_target_modules to get the query_states, adapter_k and adapter_v. If there is only a single fused qkv projection, as in gpt-2/bloom, just use outputs.split(3, dim=2) and do the necessary assignments.
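For reference, a fused-QKV model such as GPT-2 produces a single projection output that can be split into query/key/value chunks, whereas Llama has separate q_proj/k_proj/v_proj (illustrative sketch only, not proposed peft code):

```python
import torch
import torch.nn as nn

hidden_size = 768
# Fused projection as in GPT-2's c_attn.
c_attn = nn.Linear(hidden_size, 3 * hidden_size)
hidden_states = torch.randn(2, 16, hidden_size)

fused = c_attn(hidden_states)              # (batch, seq, 3 * hidden_size)
query, key, value = fused.chunk(3, dim=2)  # three (batch, seq, hidden_size) tensors
```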

query_states is not actually returned by the original forward() method; it is a local variable within that function. So if we want to reuse the original forward() without modification, I think we will have to recompute query_states, which incurs a performance penalty, and to do that we will still need to import the Llama-specific apply_rotary_pos_emb() to apply to the query_states. So the implementation will still be Llama-specific either way. As I see it, we have the following two options:

  1. Refactor LlamaAttention in transformers to return query_states. This would keep this library generic, but we might have to refactor other models in transformers to conform to our interface too.
  2. Have a separate code path for Llama and other models, with two sub-options: (a) keep _modified_forward(), which completely bypasses the original forward(); or (b) recompute query_states, which will require importing the Llama-specific apply_rotary_pos_emb() (less duplication of code but with the perf penalty of recomputation; see the sketch below).
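For concreteness, option (b) could look something like the sketch below, assuming the LlamaAttention internals of transformers at the time (q_proj, rotary_emb, and the apply_rotary_pos_emb signature may differ in other versions):

```python
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def compute_llama_query_states(model, hidden_states, position_ids):
    """Recompute the rotated query states of a wrapped LlamaAttention module."""
    bsz, q_len, _ = hidden_states.size()
    query_states = (
        model.q_proj(hidden_states)
        .view(bsz, q_len, model.num_heads, model.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        model.v_proj(hidden_states)
        .view(bsz, q_len, model.num_heads, model.head_dim)
        .transpose(1, 2)
    )
    cos, sin = model.rotary_emb(value_states, seq_len=q_len)
    # apply_rotary_pos_emb rotates both q and k; only the rotated queries are needed.
    query_states, _ = apply_rotary_pos_emb(query_states, query_states, cos, sin, position_ids)
    return query_states
```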

How do you think we should proceed?

yeoedward avatar Apr 11 '23 08:04 yeoedward

Recompute query_states, which will require importing the Llama-specific apply_rotary_pos_emb() (less duplication of code but with the perf penalty of recomputation).

I am leaning towards this, as it would be generic and the recomputation penalty would be minimal. There could be a model-to-postprocessing-function mapping used for things like apply_rotary_pos_emb. Refer to this code block for one such example:

https://github.com/huggingface/peft/blob/main/src/peft/peft_model.py#L246-L249
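Something along the lines of the sketch below (the dict name and the gpt2 entry are hypothetical, not the code linked above; the apply_rotary_pos_emb signature is the one current at the time):

```python
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

# Map each supported architecture to its query-state postprocessing function;
# architectures that need no postprocessing map to None.
MODEL_TYPE_TO_QUERY_POSTPROCESS_FN = {
    "llama": apply_rotary_pos_emb,
    "gpt2": None,
}


def postprocess_query_states(model_type, query_states, key_states, cos, sin, position_ids):
    fn = MODEL_TYPE_TO_QUERY_POSTPROCESS_FN.get(model_type)
    if fn is None:
        return query_states
    rotated_q, _ = fn(query_states, key_states, cos, sin, position_ids)
    return rotated_q
```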

pacman100 avatar Apr 11 '23 15:04 pacman100

@pacman100 I've made the changes you suggested and rebased on main.

yeoedward avatar Apr 12 '23 15:04 yeoedward

@yeoedward have you been able to validate the implementation results against the paper's results? I gave your branch a quick whirl with the same hyperparameters as the paper, but it definitely doesn't seem to converge as quickly.

winglian avatar Apr 13 '23 11:04 winglian

@winglian I have validated the implementation on smaller problems that run on a CPU, but haven't tried reproducing the paper's results at full scale. When you say it doesn't converge as quickly, are you referring to the number of epochs trained? Happy to look into it if you can provide a notebook/script, although I don't have access to 8 x A100s.

yeoedward avatar Apr 13 '23 18:04 yeoedward

@pacman100 I've made the suggested changes and rebased on main.

yeoedward avatar Apr 14 '23 16:04 yeoedward

Could you please run make style and make quality to fix the quality issues?

pacman100 avatar Apr 20 '23 08:04 pacman100

@pacman100 I couldn't reproduce the style error (which was in a file unrelated to this PR), but I rebased on main and fixed some conflicts, which was perhaps why the style check failed. I've re-run the unit tests and quality checks, so hopefully this is good to go. Thanks for reviewing!

yeoedward avatar Apr 20 '23 14:04 yeoedward

Thank you @yeoedward! 🤗

pacman100 avatar Apr 25 '23 06:04 pacman100

It took me a while to discover that this paper has been implemented, even though it appears in the list of supported methods. However, I haven't seen any documentation on the Tuners page here. If I am correct, it has been implemented in this file. The method described in the paper is model-agnostic, so will it support Falcon instead of Llama? There has been a recent implementation of this method that can be used with models other than Llama; you can check it here.
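For reference, this is roughly how the implemented adaption prompt tuner is used in peft (the model name and hyperparameter values are illustrative; support for architectures other than Llama depends on the tuner's per-model configuration mapping):

```python
from transformers import AutoModelForCausalLM
from peft import AdaptionPromptConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
peft_config = AdaptionPromptConfig(
    adapter_len=10,      # number of trainable prompt tokens per adapted layer
    adapter_layers=30,   # number of topmost layers that receive adaption prompts
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```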

tathagata-raha avatar Jul 17 '23 06:07 tathagata-raha