
Not an issue, a question - PEFT/LoRA finetuning a possibility?

Open jamesd256 opened this issue 2 years ago • 14 comments

Another noob question... Is it possible to reduce the resource burden of fine-tuning by using PEFT/LoRA techniques?

If not, will it be possible in the future with MPT models?

jamesd256 avatar May 07 '23 21:05 jamesd256

Been thinking of this too. A quick test of the following was successful, although I haven't actually run the training yet. Maybe you could give it a try, based on the alpaca-lora project:

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model1 = get_peft_model(model, config)
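If it helps, after wrapping you can confirm that only the LoRA weights are trainable using PEFT's built-in helper (assuming `model` here is the MPT model loaded via `AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`):

model1.print_trainable_parameters()
# prints something like: trainable params: ... || all params: ... || trainable%: ...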

lkluo avatar May 08 '23 10:05 lkluo

I am trying the same using the Hugging Face Trainer, but get an error while training: `TypeError: MPTForCausalLM.forward() got an unexpected keyword argument 'inputs_embeds'`.

PrakharSaxena24 avatar May 09 '23 05:05 PrakharSaxena24

Hi all, we don't have any immediate plans for PEFT/LoRA support, but if there is an easy way to edit our MPT models to fit the HF ecosystem for these workflows, we would be happy to do so (or accept a PR that does!).

It looks like we may need to add a couple more kwargs to our forward function; is there anything else? Any additional info / links would be useful.
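For reference, a rough sketch of the kind of signature change being discussed (this is not the actual llm-foundry code; it just mirrors the HF-style kwargs that Trainer/PEFT pass to the model):

# Hypothetical sketch, not the real MPTForCausalLM.forward: accept the extra
# HF-style kwargs so PEFT / Trainer can call the model without a TypeError.
def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None,
            labels=None, output_attentions=None, output_hidden_states=None,
            return_dict=None, **kwargs):
    if inputs_embeds is not None:
        # MPT builds its embeddings from input_ids internally in this sketch
        raise NotImplementedError('inputs_embeds is not supported here')
    ...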

abhi-mosaic avatar May 09 '23 16:05 abhi-mosaic

Thanks for the reply! I could get around the issue using:

import torch.nn as nn

class IgnoreEmbedsWrapper(nn.Module):
    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base_model = base_model
        self.config = base_model.config

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None,
                labels=None, output_attentions=None, output_hidden_states=None,
                return_dict=None, **kwargs):
        # Drop inputs_embeds, which the base model does not accept
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

model = IgnoreEmbedsWrapper(model)

However, while training, the loss does not decrease. Training script:

global_step = 0
for epoch in range(num_epochs):
    model.train()
    print(f"Epoch: {epoch + 1}")

    for i, batch in tqdm_notebook(enumerate(train_dataloader), total=len(train_dataloader)):
        input_ids = batch["input_ids"].to("cuda")
        attention_mask = batch["attention_mask"].to("cuda")
        labels = batch["labels"].to("cuda")

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels, inputs_embeds=None)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"  Step {i}, Loss: {loss.item()}")
            wandb.log({"Train_loss": loss.item()})

        if global_step % eval_interval == 100:
            avg_eval_loss = evaluate()
            print(f"  Evaluation loss: {avg_eval_loss}")

        global_step += 1

I cannot figure out why the training loss is not decreasing.
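A quick sanity check that might help here (hypothetical snippet, reusing the `model` and `optimizer` from the loop above): confirm that the LoRA parameters still require grad after wrapping and that the optimizer actually contains them.

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(f"{len(trainable)} trainable tensors, e.g. {trainable[:3]}")
print("params in optimizer:", sum(p.numel() for g in optimizer.param_groups for p in g["params"]))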

PrakharSaxena24 avatar May 10 '23 05:05 PrakharSaxena24

Which file did you modify? @PrakharSaxena24

Smart-Tom avatar May 10 '23 05:05 Smart-Tom

@Smart-Tom I don't understand your question; I did not modify any file.

PrakharSaxena24 avatar May 10 '23 05:05 PrakharSaxena24

Even if we open a PR against this GitHub repository, it will not sync with the source code inside the Hugging Face repository, so I opened a PR against the Hugging Face repo instead: https://huggingface.co/mosaicml/mpt-7b/discussions/11

huseinzol05 avatar May 10 '23 08:05 huseinzol05

If you need a solution to finetune with LoRA / 8-bit loading, we did it in OpenNMT-py. See here: https://forum.opennmt.net/t/finetuning-llama-7b-or-mosaicml-mpt-7b-reproduce-vicuna-alpaca/5272/20

vince62s avatar May 12 '23 09:05 vince62s

Y'all may find this script helpful:

import argparse
import loralib as lora
import transformers
from tqdm import tqdm

def lora_process(model_name, max_seq_len, attn_impl, r_emb, r):
    print("Loading model configurations...")
    config = transformers.AutoConfig.from_pretrained(
        f'mosaicml/{model_name}',
        trust_remote_code=True
    )
    config.attn_config['attn_impl'] = attn_impl
    config.update({"max_seq_len": max_seq_len})
    config.update({"alibi": True})

    print("Loading model...")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        f'mosaicml/{model_name}',
        config=config,
        trust_remote_code=True
    )

    print("LoRAfying the embeddings...")
    wte = lora.Embedding(
        model.transformer.wte.num_embeddings,
        model.transformer.wte.embedding_dim,
        r = r_emb
    )
    wte.weight.data = model.transformer.wte.weight.data.clone()
    model.transformer.wte = wte

    print("LoRAfying the attention layers...")
    for i, block in tqdm(enumerate(model.transformer.blocks), total=len(model.transformer.blocks), desc="Processing blocks"):
        Wqkv = lora.MergedLinear(
            block.attn.d_model,
            block.attn.d_model * 3,
            r=r,
            enable_lora=[True, False, True]
        )
        Wqkv.weight.data = block.attn.Wqkv.weight.data.clone()
        block.attn.Wqkv = Wqkv

        out_proj = lora.Linear(
            block.attn.out_proj.in_features,
            block.attn.out_proj.out_features,
            r=r
        )

        out_proj.weight.data = block.attn.out_proj.weight.data.clone()
        block.attn.out_proj = out_proj

        up_proj = lora.Linear(
            block.ffn.up_proj.in_features,
            block.ffn.up_proj.out_features,
            r=r
        )
        down_proj = lora.Linear(
            block.ffn.down_proj.in_features,
            block.ffn.down_proj.out_features,
            r=r
        )

        up_proj.weight.data = block.ffn.up_proj.weight.data.clone()
        down_proj.weight.data = block.ffn.down_proj.weight.data.clone()

        block.ffn.up_proj = up_proj
        block.ffn.down_proj = down_proj

    print("Process completed successfully.")
    return model

def main():
    parser = argparse.ArgumentParser(description='Script to LoRAfy a transformer model.')
    parser.add_argument('-m', '--model', default='mpt-7b', type=str, help='The name of the pretrained model.')
    parser.add_argument('-s', '--seq_len', default=8192, type=int, help='Maximum sequence length.')
    parser.add_argument('-a', '--attn_impl', default='torch', type=str, help='Attention implementation (e.g. torch).')
    parser.add_argument('-re', '--r_emb', default=256, type=int, help='LoRA rank for the embeddings.')
    parser.add_argument('-r', '--r', default=32, type=int, help='LoRA rank for the attention layers and MLP.')
    args = parser.parse_args()

    # Pass attn_impl through so the call matches lora_process's signature
    lora_process(args.model, args.seq_len, args.attn_impl, args.r_emb, args.r)

if __name__ == "__main__":
    main()
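One thing the script doesn't do yet is freeze the base weights; before training you'd typically mark only the LoRA parameters as trainable. A minimal sketch using loralib's helpers (the arguments and filename below are just illustrative):

import torch
import loralib as lora

model = lora_process('mpt-7b', 2048, 'torch', 256, 32)
lora.mark_only_lora_as_trainable(model)   # freeze every non-LoRA weight
# ... run your training loop here ...
torch.save(lora.lora_state_dict(model), 'mpt7b_lora.pt')   # save only the LoRA weights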

EDIT:

FYI --

The wte embedding layer is used for de-embedding as well, which is fine, except that it's applied in a somewhat nonstandard way in the fwd pass of the transformer module. This means that this particular LoRA "exoskeleton" won't be updating the de-embedding layer, which could hurt performance? Seems fixable...

tginart avatar May 15 '23 07:05 tginart

Hi all, just want to follow up and say we are working to make the MPT model code class more amenable to PEFT/LoRA, and should have updates ~ next week.

abhi-mosaic avatar May 17 '23 22:05 abhi-mosaic

except that it's applied in a somewhat nonstandard way in the fwd pass of the transformer module

@tginart Can you say more about this?

samhavens avatar May 23 '23 21:05 samhavens

Are there any updates on this? thanks

sasaadi avatar May 26 '23 07:05 sasaadi

except that it's applied in a somewhat nonstandard way in the fwd pass of the transformer module

@tginart Can you say more about this?

https://github.com/mosaicml/llm-foundry/blob/86864e90e0063651177837e831fe48e80618b969/llmfoundry/models/mpt/modeling_mpt.py#LL485C1-L487C1

@samhavens ^ in these lines, MPT uses F.linear directly with the layer weights instead of calling self.transformer.wte(x), which would invoke the PyTorch forward that loralib "hijacks" --- so the de-embedding step (I believe) will always use the base embeddings and ignore the LoRA embeddings. Again, not 100% sure, but that is my understanding of it.
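A tiny illustration of that point (toy sizes, not the actual MPT code): with loralib, the low-rank update lives in separate parameters and is only applied inside the module's forward, so reading `.weight` directly only sees the frozen base matrix.

import torch
import torch.nn.functional as F
import loralib as lora

emb = lora.Embedding(1000, 64, r=8)   # toy vocab / hidden sizes, not MPT's
h = torch.randn(2, 64)

# Equivalent of the de-embedding step linked above: F.linear reads the raw
# base weight, so the LoRA delta (kept in separate low-rank parameters and
# only added inside emb's forward) never contributes to the logits.
logits = F.linear(h, emb.weight)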

tginart avatar May 26 '23 18:05 tginart

Are there any updates on this? @abhi-mosaic Thanks ~

benam2 avatar Jun 21 '23 14:06 benam2

Pretty sure this has just been added here https://github.com/mosaicml/llm-foundry/pull/346

alextrott16 avatar Jun 27 '23 22:06 alextrott16

Closing as we have added PEFT/LoRA support with #346

abhi-mosaic avatar Jul 06 '23 00:07 abhi-mosaic