
Integrating ReFT

Open raven38 opened this issue 10 months ago • 8 comments

Paper link: https://arxiv.org/abs/2404.03592

This PR integrates ReFT by creating a new tuner model type. Please see #1654.

Changes

I defined a new class, ReFTModel, as a subclass of LycorisTuner. src/peft/tuners/reft/layer.py contains the actual operations for the ReFT methods.

Status

  • [ ] Discuss specification
  • [ ] Refactor implementation
  • [ ] Test implementation
  • [ ] Add documentation for methods.
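
For orientation, the core LoReFT operation from the paper is phi(h) = h + R^T (W h + b - R h), where R is constrained to have orthonormal rows. The snippet below is only a minimal, self-contained sketch of that computation, not the actual code in this PR's layer.py:

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

class LoReftSketch(nn.Module):
    # Illustrative LoReFT edit: phi(h) = h + R^T (W h + b - R h)
    def __init__(self, hidden_size: int, rank: int):
        super().__init__()
        # The paper keeps the rows of R orthonormal; torch's orthogonal
        # parametrization maintains that constraint during training.
        self.R = orthogonal(nn.Linear(hidden_size, rank, bias=False))
        self.W = nn.Linear(hidden_size, rank)  # learned projection W h + b
    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (..., hidden_size) hidden states at the intervened positions
        return h + (self.W(h) - self.R(h)) @ self.R.weight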

raven38 avatar Apr 16 '24 09:04 raven38

@raven38 We just merged VeRA, which resulted in a bunch of merge conflicts. They should be very straightforward to resolve, let me know if you have any questions. Also ping me once this is ready for the next review.

BenjaminBossan avatar Apr 19 '24 09:04 BenjaminBossan

The authors describe four important hyperparameters in Section 4.1. The number of prefix and suffix positions is exposed via the loc parameters; I borrowed the pyreft implementation for parsing the positions. As you say, the set of layers to target is covered by target_modules. Untied interventions can be implemented by defining two ReFT models, the same way pyreft does; the default ReFT model uses tied interventions.
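
To make that mapping concrete, here is a hypothetical sketch of how these hyperparameters could surface in the config. The names loc and tied_intervention (and the exact form of ReftConfig) are illustrative assumptions, not the finalized API of this PR:

from peft import get_peft_model
from peft import ReftConfig  # hypothetical: the config class this PR would add

peft_config = ReftConfig(
    r=4,                        # low-rank dimension of the intervention
    target_modules=["o_proj"],  # set of layers/modules to intervene on
    loc="f3+l3",                # hypothetical name: first 3 and last 3 positions, pyreft-style parsing
    tied_intervention=True,     # hypothetical name: tied (default) vs. untied interventions
)
model = get_peft_model(model, peft_config)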

raven38 avatar Apr 28 '24 16:04 raven38

Sorry, I performed the wrong operation: syncing my fork discarded my commits, which closed this pull request. Can I reopen it?

raven38 avatar May 20 '24 06:05 raven38

I think the code is ready for review.

raven38 avatar May 21 '24 00:05 raven38

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@raven38 Thanks for the updates, I'm currently reviewing the PR, but it'll take some time. Meanwhile, could you please run make style on your code so that the CI passes?

BenjaminBossan avatar May 27 '24 12:05 BenjaminBossan

I modified the code so that make style passes, but make style still raises the following error. Should it be fixed in this PR?

examples/boft_controlnet/test_controlnet.py:40:1: E402 Module level import not at top of file
Found 1 error.
make: *** [style] Error 1

raven38 avatar May 27 '24 12:05 raven38

Hmm, code quality checks still fail:

ruff src tests examples docs scripts docker
ruff format --check src tests examples docs scripts docker
Would reformat: src/peft/tuners/reft/config.py
Would reformat: src/peft/tuners/reft/layer.py
Would reformat: src/peft/tuners/reft/model.py
Would reformat: tests/test_custom_models.py
Would reformat: tests/test_stablediffusion.py
Would reformat: tests/testing_common.py
6 files would be reformatted, 159 files already formatted

These failures do not appear to be related to anything in test_controlnet, whose error is independent of your PR anyway. Is it possible that you use a different ruff version? The CI uses v0.2.2.

BenjaminBossan avatar May 27 '24 12:05 BenjaminBossan

hey! thanks again for the PR (i am one of the authors of the ReFT paper). i skimmed through the code change, and it looks very promising.

two quick questions:

  1. with the current peft config, could we target any residual stream with loreft? if so, what would the config look like?
  2. if i add a loreft adapter and we only intervene on the prompt tokens, the adapter will have no effect during the decoding process, right?

thanks!

frankaging avatar May 30 '24 08:05 frankaging

@frankaging Thanks a lot for taking a look. Regarding your questions, I'll let @raven38 provide the final answer. Personally, I think 2 is right and I don't quite understand 1, maybe you can give an example?

BenjaminBossan avatar May 30 '24 08:05 BenjaminBossan

@frankaging Thanks a lot for taking a look. Regarding your questions, I'll let @raven38 provide the final answer. Personally, I think 2 is right and I don't quite understand 1, maybe you can give an example?

Thanks for the update! For (1), i was wondering: for example, when targeting the attention output matrix of a Llama model with LoRA, we might have a config like this:

from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=4, lora_alpha=32, target_modules=["o_proj"], layers_to_transform=[15],
    use_rslora=True, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(model, peft_config)

where target_modules=["o_proj"]. Since for LoReFT, we want to target the residual stream (transformer block/layer output), what should we put for the target_modules in that case? Thanks!

frankaging avatar May 30 '24 18:05 frankaging

> Since for LoReFT, we want to target the residual stream (transformer block/layer output), what should we put for the target_modules in that case? Thanks!

I see, I think I get what you mean. Good question. As is, I think it would not work: we would have to target the attention module itself (the exact type depending on which attention implementation is used); targeting a Linear layer would not be enough. I think this is what you're getting at, right?

If you have some pointers on how this is implemented in pyreft/pyvene, please share.

BenjaminBossan avatar May 31 '24 12:05 BenjaminBossan

> Since for LoReFT, we want to target the residual stream (transformer block/layer output), what should we put for the target_modules in that case? Thanks!
>
> I see, I think I get what you mean. Good question. As is, I think it would not work: we would have to target the attention module itself (the exact type depending on which attention implementation is used); targeting a Linear layer would not be enough. I think this is what you're getting at, right?
>
> If you have some pointers on how this is implemented in pyreft/pyvene, please share.

Thanks for your reply. Yes, we need to target the whole block module itself (e.g., model.layers[15]). Will this be doable? If not, I think it is still valuable to have the current form where we target submodules (e.g., attention output or MLP output).

With pyreft or pyvene, we use register_module_forward_hook to place a callback on the existing computation graph, instead of overwriting the current modules with additional adapter modules. This is quite different, so it might not be that helpful here.

frankaging avatar May 31 '24 18:05 frankaging

> Thanks for your reply. Yes, we need to target the whole block module itself (e.g., model.layers[15]). Will this be doable?

Yes, I don't see why not. We would have to add a new layer type, similar to class Linear(LoReftLayer), but for the whole attention block layer. The issue there is probably how well this generalizes across different attention types and architectures. With Linear, we know they're used everywhere, but there are quite a few different attention layer implementations, so we'd probably have to focus on a few popular ones, like LlamaAttention.

> With pyreft or pyvene, we use register_module_forward_hook to place a callback on the existing computation graph, instead of overwriting the current modules with additional adapter modules.

We do use forward hooks in PEFT when necessary but if there is an easier way, we generally prefer that. Do you have a pointer in your repos on how it's implemented?
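
To illustrate the wrapper idea from the first paragraph above, here is a rough sketch of a module that wraps a whole decoder block (or attention module) and edits its output. It is only an illustration, not this PR's implementation, and it skips all of PEFT's adapter bookkeeping:

import torch.nn as nn

class ReftBlockWrapper(nn.Module):
    # Wraps e.g. a LlamaDecoderLayer or LlamaAttention instance and applies an
    # intervention module to the hidden states it returns.
    def __init__(self, base_layer: nn.Module, intervention: nn.Module):
        super().__init__()
        self.base_layer = base_layer
        self.intervention = intervention
    def forward(self, *args, **kwargs):
        output = self.base_layer(*args, **kwargs)
        # Transformer blocks usually return a tuple whose first element is the
        # hidden states; edit only that element and pass the rest through.
        if isinstance(output, tuple):
            return (self.intervention(output[0]),) + output[1:]
        return self.intervention(output)

Swapping a block such as model.model.layers[15] for this kind of wrapper would give a block-level target, at the cost of having to know each architecture's output convention.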

BenjaminBossan avatar Jun 03 '24 09:06 BenjaminBossan

@BenjaminBossan Thanks! It would be great if we could support LlamaAttention for residual stream interventions.

On the hook-based impl: ReFT is one of the use-cases of our pyvene library (and actually a simple one). It might be hard for me to lay out the full details of pyvene, but essentially we get an intervention config from the user (e.g., which layer, which stream, what rank, etc.), and then we manage these interventions for the user by registering hooks on (and removing hooks from) the model and memorizing activations (e.g., https://github.com/stanfordnlp/pyvene/blob/main/pyvene/models/intervenable_base.py#L963 <- this is the line where we register our hooks).

For a residual stream intervention with r=4 on layer 15 of a llama-based LM, the config will be something like:

import pyreft

reft_config = pyreft.ReftConfig(representations={
    "component": "model.layers[15].output",
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size,
        low_rank_dimension=4,
    ),
})
reft_model = pyreft.get_reft_model(model, reft_config)

And then internally, we basically call model.layers[15].register_module_forward_hook(pyreft.LoreftIntervention(...), ...) with the intervention torch module.
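
For readers unfamiliar with the hook-based pattern, here is a minimal plain-PyTorch sketch of the same idea. This is not pyvene's code; it uses the standard per-module register_forward_hook API, and the attribute path and the intervention module passed in are assumptions:

def make_reft_hook(intervention):
    # Returns a forward hook that edits the hidden states a block emits.
    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        edited = intervention(hidden)
        if isinstance(output, tuple):
            return (edited,) + output[1:]
        return edited
    return hook

# Attach to the residual stream after layer 15 of a Llama-style model.
handle = model.model.layers[15].register_forward_hook(make_reft_hook(intervention))
# ... train or run inference with the hook in place ...
handle.remove()  # detach the intervention afterwards

A full implementation would also restrict the edit to the selected prefix/suffix token positions rather than applying it at every position.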

frankaging avatar Jun 03 '24 22:06 frankaging

Thanks for the pointer (here is the permalink, just in case). Let's wait for @raven38's opinion on this.

BenjaminBossan avatar Jun 04 '24 12:06 BenjaminBossan

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] avatar Jun 28 '24 15:06 github-actions[bot]