
Adding adapters to SpeechBrain (Code from Samsung AI Center Cambridge)

Open · TParcollet opened this issue 9 months ago · 10 comments

What does this PR do?

Based on #2526, this PR is a first attempt at adding Adapters to any SB model. This only works if the PreTrainer is used and not the checkpointer: the checkpointer would try to reload after the state_dict has been modified. So you need to 1. instantiate the Brain, 2. call the PreTrainer, 3. add the adapters, 4. call fit. An example is:

```python
from speechbrain.utils.distributed import run_on_main

asr_brain = ASR(
    modules=hparams["modules"],
    opt_class=hparams["Adam"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Adding objects to trainer:
asr_brain.tokenizer = hparams["tokenizer"]

# Load the pretrained model.
run_on_main(hparams["pretrainer"].collect_files)
hparams["pretrainer"].load_collected()

from speechbrain.lobes.models.Adapters import (
    HoulsbyAdapterLinear,
    add_adapters_to_linear_in_model,
)

# Replace the linear layers of the Transformer with adapter-augmented versions.
add_adapters_to_linear_in_model(
    model=asr_brain.modules.Transformer,
    adapter_class=HoulsbyAdapterLinear,
    projection_size=32,
)

# Training
asr_brain.fit(
    asr_brain.hparams.epoch_counter, train_data, valid_data,
)
```
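
As a side note, adapter fine-tuning normally freezes the pretrained weights and trains only the adapter parameters. The snippet below is a minimal, hypothetical sketch of that step (not part of this PR): it assumes the adapters inserted above can be identified with an `isinstance` check on `HoulsbyAdapterLinear`, and the helper name is illustrative.

```python
import torch.nn as nn

def freeze_all_but_adapters(model: nn.Module, adapter_class=HoulsbyAdapterLinear):
    """Hypothetical helper: freeze the whole model, then unfreeze adapter modules."""
    for p in model.parameters():
        p.requires_grad = False
    for module in model.modules():
        if isinstance(module, adapter_class):
            # Depending on how the adapter class wraps the original linear layer,
            # the wrapped pretrained weights may need to stay frozen here as well.
            for p in module.parameters():
                p.requires_grad = True

freeze_all_but_adapters(asr_brain.modules.Transformer)
```

Whether this freezing should live inside `add_adapters_to_linear_in_model` itself is a design choice the PR would need to settle.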

TParcollet avatar Apr 30 '24 18:04 TParcollet

I don't think this is quite the right approach because I don't think it allows for stopping/restarting which is part of the point of checkpointing. Instead, the checkpointer should store the LoRA'd model, not the pretrained model. Ideally it would even only store the LoRA weights (and any updated weights) and not the whole model, making for very small checkpoints and faster saving. Example:

```yaml
add_adapters: !name:speechbrain.lobes.models.Adapters.add_adapters_to_linear_in_model
  adapter_class: !name:speechbrain.lobes.models.Adapters.HoulsbyAdapterLinear
  projection_size: 32

pretrainer: !new:speechbrain....Pretrainer
  loadables:
    transformer: !ref <Transformer>
```

```python
# Load the pretrained model.
run_on_main(hparams["pretrainer"].collect_files)
hparams["pretrainer"].load_collected()
hparams["add_adapters"](hparams["Transformer"])

asr_brain = ASR(
    modules=hparams["modules"],
    opt_class=hparams["Adam"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],  # Checkpointer loads LoRA weights only and applies them
)

asr_brain.fit(
    asr_brain.hparams.epoch_counter, train_data, valid_data,
)
```
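
To make the "very small checkpoints" point concrete, here is an editorial sketch (plain `torch`, not part of either proposal) of persisting only the adapter tensors; the `"adapter"` substring used to select parameters is an assumption about how the adapter class would register its weights.

```python
import torch

def save_adapter_checkpoint(model, path, key_filter="adapter"):
    # Keep only the tensors whose names look like adapter parameters (assumption).
    adapter_state = {
        name: tensor
        for name, tensor in model.state_dict().items()
        if key_filter in name
    }
    torch.save(adapter_state, path)

def load_adapter_checkpoint(model, path):
    # strict=False leaves the already-pretrained base weights untouched.
    model.load_state_dict(torch.load(path), strict=False)

save_adapter_checkpoint(hparams["Transformer"], "adapters.ckpt")
load_adapter_checkpoint(hparams["Transformer"], "adapters.ckpt")
```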

pplantinga avatar May 02 '24 15:05 pplantinga

It does allow for stop and restart because you are altering the object, i.e. the checkpointer keeps track of it! The only problem is indeed that you store the whole model; however, I don't think it's an issue because, ultimately, you may simply not know where to put the pre-trained adapters in the model if they are not applied to every linear layer, for instance. I'd be happy to see a functional example of something else, though.
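
As an editorial aside, the "checkpointer keeps track of it" argument can be sanity-checked by comparing `state_dict()` keys before and after the adapters are added (reusing the PR's function from the example above; the exact new keys depend on how `HoulsbyAdapterLinear` registers its parameters):

```python
# The checkpointer saves whatever the tracked module exposes via state_dict(),
# so new adapter parameters are included once the module tree is modified in place.
before = set(asr_brain.modules.Transformer.state_dict().keys())
add_adapters_to_linear_in_model(
    model=asr_brain.modules.Transformer,
    adapter_class=HoulsbyAdapterLinear,
    projection_size=32,
)
after = set(asr_brain.modules.Transformer.state_dict().keys())
print(sorted(after - before))  # the newly added adapter parameter keys
```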

TParcollet avatar May 02 '24 17:05 TParcollet

Tbh I think PEFT handles this perfectly, perhaps we should lift their code wholesale.

pplantinga avatar May 02 '24 22:05 pplantinga

You mean depend on another Huggingface library?

TParcollet avatar May 03 '24 08:05 TParcollet

My opinion is we should just add it as a dependency, but I understand the objections to it. So instead we could just copy the parts of the code that make sense into speechbrain.

pplantinga avatar May 03 '24 12:05 pplantinga

If you could give me a neat example of an integration of PEFT, I could be convinced.
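
For reference, a minimal editorial sketch of what a PEFT-based integration might look like, following PEFT's documented support for custom `torch.nn` modules; the `target_modules` names are placeholders and would have to match the actual submodule names of the SpeechBrain Transformer:

```python
from peft import LoraConfig, get_peft_model

# Illustrative LoRA configuration; r, alpha, and target_modules are assumptions.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["linear1", "linear2"],
    lora_dropout=0.05,
)

# Wrap the pretrained module; only the LoRA weights require gradients afterwards.
asr_brain.modules.Transformer = get_peft_model(
    asr_brain.modules.Transformer, lora_config
)
asr_brain.modules.Transformer.print_trainable_parameters()
```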

TParcollet avatar May 03 '24 12:05 TParcollet

The problem with PEFT comes when we want to load the model from the SpeechBrain checkpoint. It is a mess to make it work, and it could also cause problems when using different versions of PEFT. But maybe we could find a way to do this in a cleaner way.
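
One possible direction (editorial sketch, not something proposed in this thread): PEFT ships helpers to extract and re-apply only the adapter weights, which could be saved next to, or instead of, the full SpeechBrain checkpoint; whether this interacts cleanly with the Checkpointer is exactly the open question raised here.

```python
import torch
from peft import get_peft_model_state_dict, set_peft_model_state_dict

peft_model = asr_brain.modules.Transformer  # assumed to be a PEFT-wrapped module

# Save only the LoRA/adapter tensors (small file, independent of the base weights).
torch.save(get_peft_model_state_dict(peft_model), "adapter_state.pt")

# Later, after re-creating the same PEFT-wrapped module, restore the adapters.
set_peft_model_state_dict(peft_model, torch.load("adapter_state.pt"))
```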

poonehmousavi avatar May 03 '24 15:05 poonehmousavi

@poonehmousavi @pplantinga @mravanelli @Adel-Moumen @asumagic what should we do here?

TParcollet avatar May 23 '24 17:05 TParcollet

I'm working on this, but I've been traveling this week and last. I hope to have a reasonable PR by the end of next week.

pplantinga avatar May 23 '24 17:05 pplantinga

@TParcollet I am also a bit busy with the NeurIPS deadline (June 5th), but after that I could actively work on it.

poonehmousavi avatar May 24 '24 17:05 poonehmousavi

Superseded by #2563

pplantinga avatar Sep 10 '24 18:09 pplantinga