
PEFT is much slower than the original model when doing inference.

Open SefaZeng opened this issue 3 years ago • 7 comments

I am trying to fine-tune the bloomz-1b7 model for translation using PEFT LoRA, and the fine-tuned model without LoRA is twice as fast as the one with LoRA. I use the TextGenerationPipeline to generate the results.

SefaZeng avatar Apr 13 '23 02:04 SefaZeng

Hi @SefaZeng, you should consider merging the LoRA layers and running the merged model as a standalone transformers model:

model = model.merge_and_unload()

Related: https://github.com/huggingface/peft/issues/217#issuecomment-1506224612
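The speedup from merging comes from folding the low-rank update into the base weight once, ahead of time, so inference needs no extra matmuls per layer. A minimal numerical sketch of what `merge_and_unload` does to each adapted linear layer (the dimensions, rank, and scaling here are illustrative, not the bloomz-1b7 values):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 8, 8, 2, 32

W = rng.standard_normal((d_out, d_in))     # frozen base weight
A = rng.standard_normal((r, d_in)) * 0.01  # lora_A
B = rng.standard_normal((d_out, r)) * 0.01 # lora_B
x = rng.standard_normal((3, d_in))

# Unmerged adapter: two extra matmuls on every forward pass (the overhead).
y_unmerged = x @ W.T + (x @ A.T) @ B.T * (alpha / r)

# Merged: fold the low-rank update into W once; forward is a single matmul.
W_merged = W + (alpha / r) * (B @ A)
y_merged = x @ W_merged.T

assert np.allclose(y_unmerged, y_merged)
```

The two paths are mathematically identical, which is why merging is safe for inference but only useful once training is finished.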

younesbelkada avatar Apr 13 '23 08:04 younesbelkada

> Hi @SefaZeng, you should consider merging the LoRA layers and running the merged model as a standalone transformers model:
>
> model = model.merge_and_unload()
>
> Related: #217 (comment)

Thanks for your reply! Do I need to do this during training, or just for inference?

SefaZeng avatar Apr 13 '23 09:04 SefaZeng

Thanks! This would be relevant for inference only. For training, I am currently not aware of any optimization tricks; we might explore that soon.

younesbelkada avatar Apr 13 '23 09:04 younesbelkada

Hello @SefaZeng, can you provide a minimal script for us to deep-dive into? Are you comparing a fine-tuned fp16/bf16 model against int8+PEFT? I want to make sure that the only diff is PEFT and nothing else.

pacman100 avatar Apr 13 '23 11:04 pacman100

> Hello @SefaZeng, can you provide a minimal script for us to deep-dive into? Are you comparing a fine-tuned fp16/bf16 model against int8+PEFT? I want to make sure that the only diff is PEFT and nothing else.

The whole script is a bit complicated. I just run run_clm.py from the transformers examples and add these lines to it:

    from peft import LoraConfig, TaskType, get_peft_model

    if model_args.lora:
        logger.info("Using LoRA fine-tuning")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

The first thing I hit is that the model saved under output_path is very small and looks like it only contains the new PEFT parameters. There is also no adapter_config.json, so when I load the model for inference it reports the following error:

Traceback (most recent call last):
  File "/workspace/inference/test_new.py", line 132, in <module>
    config = PeftConfig.from_pretrained(modelpath)
  File "/opt/anaconda3/lib/python3.8/site-packages/peft/utils/config.py", line 101, in from_pretrained
    raise ValueError(f"Can't find config.json at '{pretrained_model_name_or_path}'")
ValueError: Can't find config.json at '/workspace/bloom_1b7_ft_ja2zh_lora/checkpoint-12000'
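For reference, calling `save_pretrained` on the PeftModel after training is what normally writes `adapter_config.json` next to the adapter weights; the Trainer checkpoint alone does not produce it. As a fallback, such a file can be reconstructed by hand, as done below in the thread. A minimal sketch of writing one (every value here is an assumption that must match the training-time `LoraConfig`, and the directory is a stand-in for the real checkpoint path):

```python
import json
import os
import tempfile

# Assumed values: these must match the LoraConfig used at training time
# (r, lora_alpha, and target_modules in particular), or loading will fail.
adapter_config = {
    "base_model_name_or_path": "bigscience/bloomz-1b7",  # assumed base model
    "peft_type": "LORA",
    "task_type": "CAUSAL_LM",
    "inference_mode": True,
    "r": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
    "target_modules": ["query_key_value"],
    "bias": "none",
}

ckpt_dir = tempfile.mkdtemp()  # stand-in for the real checkpoint directory
with open(os.path.join(ckpt_dir, "adapter_config.json"), "w") as f:
    json.dump(adapter_config, f, indent=2)
```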

Then I created an adapter_config.json manually, and it seems to work. The JSON file is:

{
  "base_model_name_or_path": "/workspace/pretrain_models/bloomz-1b7",
  "bias": "none",
  "enable_lora": [
    true,
    false,
    true
  ],
  "fan_in_fan_out": true,
  "inference_mode": true,
  "lora_alpha": 32, 
  "lora_dropout": 0.0,
  "merge_weights": false,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 8,
  "target_modules": [
    "query_key_value"
  ],
  "task_type": "CAUSAL_LM"
}

But when I run inference, the model seems very slow. The main body of the inference script is like this:

import time
import logging

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    MT5Tokenizer,
    TextGenerationPipeline,
)

logger = logging.getLogger(__name__)

if __name__ == '__main__':
    arg_parser = init_opt()
    args = arg_parser.parse_args()

    modelpath=args.model_path
    infile=args.input_file
    outfile=args.output_file
    prompt_id = args.prompt_id
    batch_size = args.batch_size
    num_beams = args.beam_size

    prompts = load_prompts(args.prompt_file, args.task)
    logger.info(f">> Translate file: {infile}")

    if "bloom" in modelpath:
        tokenizer = AutoTokenizer.from_pretrained(modelpath)
    else:
        tokenizer = MT5Tokenizer.from_pretrained(modelpath, padding_side='left')


    if "lora" in modelpath.lower():
        logger.info(f">> Load Lora model from {modelpath}")
        from peft import PeftModel, PeftConfig
        #config = PeftConfig.from_pretrained(modelpath)
        model = AutoModelForCausalLM.from_pretrained(args.origin_model_path)
        model = PeftModel.from_pretrained(model, modelpath)
        model = model.merge_and_unload()
    elif "bloom" in modelpath:
        logger.info(f">> Load Bloom model from {modelpath}")
        from transformers import BloomForCausalLM
        model = BloomForCausalLM.from_pretrained(modelpath)
    else:
        model = GPT2LMHeadModel.from_pretrained(modelpath)
    model.cuda()

    pipeline = TextGenerationPipeline(model=model, batch_size=batch_size, tokenizer=tokenizer,
                                      return_full_text=False, device="cuda:0", #model.device,
                                      clean_up_tokenization_spaces=True,
                                      handle_long_generation="hole")

    #pipeline.tokenizer.pad_token_id = model.config.eos_token_id

    inf = open(infile, "r", encoding='utf-8')

    sentences = []
    lengths = []
    logger.info(f">> prompts: {prompts}")
    logger.info(f">> prompt_id: {prompt_id}")
    for line in inf:
        line = line.strip()
        new_text = encode_prompt(prompts, line, "one", prompt_id=prompt_id)
        sentences.append(new_text[:])

    logger.info(f">> start generating")
    
    trans_start_time = time.time()
    hypothesis = pipeline(sentences, do_sample=False, num_beams=num_beams, max_length=512, pad_token_id=tokenizer.pad_token_id) #[0]["generated_text"]
    trans_end_time = time.time()

    outf = open(outfile, "w", encoding='utf-8')

    for sent in hypothesis:
        hyp = sent[0]['generated_text']
        hyp = hyp.split("\n")[0]
        outf.write(hyp[:] + "\n")
    logger.info(f">> Translate done. Time used for translate: {trans_end_time - trans_start_time} s")
    inf.close()
    outf.close()

By the way, there are also some warnings saying that PeftModel is not supported by TextGenerationPipeline. I just dismissed them, as I found the pipeline is much faster than batching inputs myself with a tokenizer. I still do not know the reason for that. :-(

SefaZeng avatar Apr 14 '23 07:04 SefaZeng

Now I install peft from the master branch to use

model = model.merge_and_unload()

But it fails to load my PEFT model because the architecture does not match:

Traceback (most recent call last):
  File "/workspace/inference/test_new.py", line 135, in <module>
    model = PeftModel.from_pretrained(model, modelpath)
  File "/opt/anaconda3/lib/python3.8/site-packages/peft/peft_model.py", line 172, in from_pretrained
    model.load_adapter(model_id, adapter_name, **kwargs)
  File "/opt/anaconda3/lib/python3.8/site-packages/peft/peft_model.py", line 361, in load_adapter
    set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
  File "/opt/anaconda3/lib/python3.8/site-packages/peft/utils/save_and_load.py", line 120, in set_peft_model_state_dict
    model.load_state_dict(peft_model_state_dict, strict=False)
  File "/opt/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
	size mismatch for base_model.model.transformer.h.0.self_attention.query_key_value.lora_A.default.weight: copying a param with shape torch.Size([16, 2048]) from checkpoint, the shape in current model is torch.Size([8, 2048]).
	size mismatch for base_model.model.transformer.h.0.self_attention.query_key_value.lora_B.default.weight: copying a param with shape torch.Size([4096, 8, 1]) from checkpoint, the shape in current model is torch.Size([6144, 8]).
	size mismatch for base_model.model.transformer.h.1.self_attention.query_key_value.lora_A.default.weight: copying a param with shape torch.Size([16, 2048]) from checkpoint, the shape in current model is torch.Size([8, 2048]).
	size mismatch for base_model.model.transformer.h.1.self_attention.query_key_value.lora_B.default.weight: copying a param with shape torch.Size([4096, 8, 1]) from checkpoint, the shape in current model is torch.Size([6144, 8]).
	size mismatch for base_model.model.transformer.h.2.self_attention.query_key_value.lora_A.default.weight: copying a param with shape torch.Size([16, 2048]) from checkpoint, the shape in current model is torch.Size([8, 2048]).
	size mismatch for base_model.model.transformer.h.2.self_attention.query_key_value.lora_B.default.weight: copying a param with shape torch.Size([4096, 8, 1]) from checkpoint, the shape in current model is torch.Size([6144, 8]).
	size mismatch for base_model.model.transformer.h.3.self_attention.query_key_value.lora_A.default.weight: copying a param with shape torch.Size([16, 2048]) from checkpoint, the shape in current model is torch.Size([8, 2048]).
...

Still debugging ...

SefaZeng avatar Apr 14 '23 07:04 SefaZeng

> Now I install peft from the master branch to use
>
> model = model.merge_and_unload()
>
> But it fails to load my PEFT model because the architecture does not match:

Please refer to https://github.com/huggingface/peft/issues/276#issuecomment-1500265524
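The size mismatches in the traceback also suggest a rank mismatch on top of the weight-layout change the linked comment addresses: lora_A is stored as (r, in_features), and the checkpoint shows 16 rows where the hand-written config's r=8 expects 8. A small hypothetical helper (not part of peft) can recover the actual rank from the saved shapes, working on plain shape tuples so it needs nothing but the state dict keys:

```python
def infer_lora_ranks(shape_by_name):
    """Recover the LoRA rank per layer from saved weight shapes.

    `shape_by_name` maps parameter names to shape tuples, e.g. built with
    {k: tuple(v.shape) for k, v in torch.load(path).items()}.
    lora_A weights have shape (r, in_features), so r is the first dim.
    """
    return {
        name: shape[0]
        for name, shape in shape_by_name.items()
        if "lora_A" in name
    }

# Example shape taken from the traceback above:
shapes = {
    "base_model.model.transformer.h.0.self_attention"
    ".query_key_value.lora_A.default.weight": (16, 2048),
}
print(infer_lora_ranks(shapes))  # reports rank 16, not the configured 8
```

If the reported rank disagrees with `r` in adapter_config.json, the config was written for a different training run and must be corrected before `PeftModel.from_pretrained` will load.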

pacman100 avatar Apr 14 '23 11:04 pacman100

I found the reason for the slower inference speed: I fine-tuned the Bloomz model for Japanese-to-Chinese machine translation, and with LoRA the model generates repeated tokens, like Today is a nice day day day day day day day day day day day...... That makes generation take much longer. I am unsure whether further training could solve this, but adding no_repeat_ngram_size=3 avoids it.
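no_repeat_ngram_size works by banning any token that would complete an n-gram already present in the generated sequence, which is why it stops the "day day day" loop. A pure-Python sketch of that banning rule (a hypothetical helper mirroring what transformers' NoRepeatNGramLogitsProcessor does internally):

```python
def banned_next_tokens(tokens, ngram_size):
    """Tokens that may not come next, because they would repeat an
    n-gram of length `ngram_size` already present in `tokens`."""
    if len(tokens) < ngram_size:
        return set()
    # The (n-1)-token prefix the next token would extend into an n-gram.
    prefix = tuple(tokens[-(ngram_size - 1):])
    banned = set()
    for i in range(len(tokens) - ngram_size + 1):
        if tuple(tokens[i:i + ngram_size - 1]) == prefix:
            banned.add(tokens[i + ngram_size - 1])
    return banned

# The bigram (5, 7) was already followed by 5 once, so 5 is banned next:
tokens = [1, 5, 7, 5, 7]
print(banned_next_tokens(tokens, 3))  # -> {5}
```

In generation, the logits of the banned tokens are set to -inf before sampling or beam search, forcing the model out of the repetition loop.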

SefaZeng avatar May 05 '23 08:05 SefaZeng

So, may I draw the following conclusions:

  1. It would be better NOT to perform generation during training, even though standard practice is to evaluate the model on the validation set after each epoch, because generation is very slow.
  2. Inference should be done after training is finished.

Am I right?

allanj avatar Oct 17 '23 06:10 allanj