
[Bug] Cannot correctly load a trained adapter into InternVLChatModel

Open · 14H034160212 opened this issue · 0 comments

Checklist

  • [X] 1. I have searched related issues but cannot get the expected help.
  • [X] 2. The bug has not been fixed in the latest version.
  • [X] 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

I can save the adapter, but after I load the trained adapter weights and continue training, the loss does not keep going down. The loss curve looks as if training restarted from scratch instead of continuing from where the first run ended.
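A restarting loss curve is what you would expect if the trained LoRA weights never actually reach the model: PEFT initializes LoRA's `B` matrix to zeros by default, so a freshly created adapter leaves the base model's output unchanged. The pure-Python sketch below illustrates that effect with made-up matrices; it is not PEFT code, only the LoRA update `W x + (lora_alpha / r) * B A x` using the same `lora_alpha = 2 * r` rule as the code in this issue.

```python
# Pure-Python illustration of a LoRA-adapted linear layer (not PEFT code):
#   y = W x + (lora_alpha / r) * B (A x)
# With B at its default zero initialization, the adapter contributes nothing.

def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def lora_forward(W, A, B, x, lora_alpha, r):
    """Frozen base projection plus the scaled low-rank update B @ (A @ x)."""
    scaling = lora_alpha / r
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scaling * d for b, d in zip(base, delta)]

r = 2
lora_alpha = 2 * r                        # alpha = 2 * rank, as in the issue
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]    # frozen base weight (out=2, in=3)
A = [[0.5, -0.5, 0.0], [0.0, 0.5, 0.5]]   # r x in_features (made-up values)
x = [1.0, 2.0, 3.0]

B_fresh = [[0.0, 0.0], [0.0, 0.0]]        # out_features x r, zero-initialized
B_trained = [[0.1, 0.2], [0.3, -0.1]]     # stands in for trained values

print(lora_forward(W, A, B_fresh, x, lora_alpha, r))    # same as W @ x
print(lora_forward(W, A, B_trained, x, lora_alpha, r))  # shifted by the adapter
```

If continued training behaves like the `B_fresh` case, the loaded adapter is likely not the trained one.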

Reproduction

Here is the code I use to add and save the adapter.

```python
from peft import LoraConfig, get_peft_model

if model_args.use_llm_lora:
    # InternVL's helper that applies LoRA (rank = use_llm_lora) to model.language_model
    model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
    model.config.use_llm_lora = model_args.use_llm_lora
    # First run: no existing adapter was given, so create and register a new one
    if model_args.train_llm_adapter and model_args.llm_adapter_name is None:
        adapter_name = training_args.output_dir.split("/")[-1]
        lora_config = LoraConfig(
            lora_alpha=2 * model_args.use_llm_lora,
            r=model_args.use_llm_lora,
            base_model_name_or_path=tokenizer_path,
            target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
            task_type="TOKEN_CLS",
        )
        model.language_model = get_peft_model(model.language_model, lora_config)
        model.add_adapter(lora_config, adapter_name=adapter_name)
        model.set_adapter(adapter_name)
```
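One way to sanity-check the adapter layout is to compare the trainable-parameter count printed by PEFT with a hand estimate: each adapted `Linear(in, out)` gains `A` (r x in) and `B` (out x r), i.e. `r * (in + out)` parameters. The sketch below does that arithmetic for the seven target modules above. The layer dimensions are an assumption (Qwen2-0.5B-style values for the InternVL2-1B language model) and should be checked against the model's `config.json`.

```python
# Hand estimate of LoRA trainable parameters, for comparison with the count
# PEFT prints. Each adapted Linear(in, out) adds r * (in + out) params (no bias).

def lora_param_count(in_features, out_features, r):
    return r * (in_features + out_features)

def llm_lora_params(shapes, num_layers, r):
    """shapes maps target module name -> (in_features, out_features) per layer."""
    per_layer = sum(lora_param_count(i, o, r) for i, o in shapes.values())
    return per_layer * num_layers

# ASSUMPTION: Qwen2-0.5B-style dimensions (hidden 896, intermediate 4864,
# 2 KV heads of dim 64 -> 128, 24 layers). Verify against config.json.
hidden, inter, kv = 896, 4864, 128
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv),
    "v_proj": (hidden, kv),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, inter),
    "up_proj": (hidden, inter),
    "down_proj": (inter, hidden),
}

r = 16  # from --use_llm_lora 16
one_adapter = llm_lora_params(shapes, num_layers=24, r=r)
print(f"expected LoRA params for one adapter: {one_adapter:,}")
```

If the printed count differs noticeably from the estimate, it may be worth inspecting whether LoRA was applied to the language model more than once.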

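```python
import os

from transformers import (
    Trainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR


class SavePeftModelCallback(TrainerCallback):
    """After each checkpoint save, write the model with save_pretrained into
    <checkpoint>/adapter_model and delete the full pytorch_model.bin."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # Only the .bin file is removed; when checkpoints are saved in
        # safetensors format the full weights go to model.safetensors instead.
        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=concat_pad_data_collator,
    callbacks=[SavePeftModelCallback],
)
```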
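The first run names the adapter after the last component of `output_dir`, while the loading code below derives the name from the second-to-last component of `llm_adapter_name`. The sketch below reproduces both string operations (plus the directory layout `SavePeftModelCallback` writes) so the two names can be compared. The example path is hypothetical, and whether a name mismatch affects loading depends on how the checkpoint's keys were written.

```python
import posixpath

# Value of transformers.trainer_utils.PREFIX_CHECKPOINT_DIR
PREFIX_CHECKPOINT_DIR = "checkpoint"

def saved_adapter_path(output_dir, global_step):
    """Directory that SavePeftModelCallback writes the adapter to."""
    checkpoint_folder = posixpath.join(output_dir, f"{PREFIX_CHECKPOINT_DIR}-{global_step}")
    return posixpath.join(checkpoint_folder, "adapter_model")

def training_adapter_name(output_dir):
    # Name passed to model.add_adapter / model.set_adapter in the first run
    return output_dir.split("/")[-1]

def loading_adapter_name(llm_adapter_name):
    # Name passed to PeftModel.from_pretrained in the continued run
    return llm_adapter_name.split("/")[-2]

output_dir = "work_dirs/internvl2_1b_lora"  # hypothetical example path
path = saved_adapter_path(output_dir, 500)
print(path)
print(training_adapter_name(output_dir), "vs", loading_adapter_name(path))
```

Note also that `output_dir.split("/")[-1]` returns an empty string when `output_dir` ends with a slash.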

Here is the code I use to load the adapter and continue training.

```python
import torch
from peft import LoraConfig, PeftModel

from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel

if model_args.model_name_or_path is not None:
    # Continued run: resume from a previously saved adapter
    if model_args.train_llm_adapter and model_args.llm_adapter_name is not None:
        logger.info("Loading adapter for InternVLChatModel...")
        lora_config = LoraConfig.from_pretrained(model_args.llm_adapter_name)
        config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
        model = InternVLChatModel.from_pretrained(
            model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config
        )
        model.language_model = PeftModel.from_pretrained(
            model.language_model,
            model_args.llm_adapter_name,
            is_trainable=True,  # keep the LoRA weights trainable for continued training
            adapter_name=model_args.llm_adapter_name.split("/")[-2],
            # adapter_name="InternVL_LLM",
        )
        model.language_model.enable_input_require_grads()
```
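If the keys in the saved adapter file do not line up with the parameter names the `PeftModel` expects, the LoRA weights may be skipped during loading (PEFT loads adapter state non-strictly), leaving freshly initialized adapters. The sketch below runs that check on plain key lists; in practice the lists could come from `safetensors.safe_open(path, framework="pt").keys()` and `model.language_model.state_dict().keys()`. The key-renaming rule is modeled on PEFT's naming convention (adapter name inserted before `.weight`), and the `language_model.` prefix case is a hypothetical scenario, not something confirmed in this issue.

```python
def to_model_key(saved_key, adapter_name):
    """Insert the adapter name the way PEFT names LoRA weights in a live model,
    e.g. '...q_proj.lora_A.weight' -> '...q_proj.lora_A.<adapter_name>.weight'."""
    if ".lora_" not in saved_key:
        return saved_key
    head, param = saved_key.rsplit(".", 1)
    return f"{head}.{adapter_name}.{param}"

def unmatched_adapter_keys(saved_keys, model_keys, adapter_name):
    """Saved LoRA keys with no counterpart in the model's state_dict."""
    model_keys = set(model_keys)
    return [k for k in saved_keys
            if ".lora_" in k and to_model_key(k, adapter_name) not in model_keys]

adapter_name = "checkpoint-500"  # hypothetical, as derived by split("/")[-2]
model_keys = [
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.checkpoint-500.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.checkpoint-500.weight",
]
saved_ok = [
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
]
# Hypothetical: the same weights saved from the outer InternVLChatModel,
# so every key carries an extra "language_model." prefix.
saved_prefixed = ["language_model." + k for k in saved_ok]

print(unmatched_adapter_keys(saved_ok, model_keys, adapter_name))       # []
print(unmatched_adapter_keys(saved_prefixed, model_keys, adapter_name))  # both keys
```

A non-empty result means those trained LoRA tensors would not be restored into the model.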

Environment

--model_name_or_path "OpenGVLab/InternVL2-1B" \
--freeze_llm True \
--freeze_mlp True \
--use_llm_lora 16 \

Error traceback

No response

14H034160212, Oct 02 '24 22:10