[Feature] How to save only adapter_model when training with LoRA
Motivation
Thank you for your work. I trained InternVL 70B with LoRA applied to the LLM, but every time a checkpoint is saved it contains all of the parameters instead of just the adapter_model. How can I save only the adapter_model?
Related resources
No response
Additional context
No response
Same question... saving everything requires a large amount of disk space.
You can add this flag so that DeepSpeed does not save the full backbone model in every checkpoint:
--save_only_model
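If you configure the run from Python rather than via command-line flags, the same switch is the save_only_model field of transformers.TrainingArguments. A minimal sketch (the output_dir here is just an illustrative placeholder, not a path from this thread):

from transformers import TrainingArguments

# save_only_model=True makes each checkpoint contain only the model weights,
# skipping the optimizer, scheduler and RNG state that dominate DeepSpeed
# checkpoints. Note that such checkpoints cannot be used to resume training.
training_args = TrainingArguments(
    output_dir="work_dirs/internvl_llm_lora",  # illustrative path
    save_only_model=True,
)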
I am still working on loading the trained adapter back into the model, and it does not seem to work yet. What I did was add an adapter to the LLM only. If anyone can provide suggestions, that would be great.
This is how I load the trained adapter model:
logger.info("Loading adapter for InternVLChatModel...")
lora_config = LoraConfig.from_pretrained(model_args.llm_adapter_name)
config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
model = InternVLChatModel.from_pretrained(
model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config
)
model.language_model = PeftModel.from_pretrained(model.language_model,
model_args.llm_adapter_name,
is_trainable=True, #
adapter_name=model_args.llm_adapter_name.split('/')[-2]
# adapter_name="InternVL_LLM"
)
model.language_model.enable_input_require_grads()
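Not part of the original post, but one way to sanity-check that the adapter actually attached before resuming training is to list the trainable parameters right after loading:

# Should report only a small fraction of parameters as trainable (the LoRA weights).
model.language_model.print_trainable_parameters()

lora_params = [n for n, p in model.language_model.named_parameters()
               if "lora_" in n and p.requires_grad]
assert lora_params, "No trainable LoRA parameters found - the adapter did not load."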
This is how I saved the adapter model:
if model_args.use_llm_lora:
    # Wrap the language model with LoRA via InternVL's built-in helper.
    model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
    model.config.use_llm_lora = model_args.use_llm_lora

if model_args.train_llm_adapter and model_args.llm_adapter_name is None:
    # No pretrained adapter given, so create a fresh LoRA config for the LLM.
    lora_config = LoraConfig(
        lora_alpha=2 * model_args.use_llm_lora,
        r=model_args.use_llm_lora,
        base_model_name_or_path=tokenizer_path,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="TOKEN_CLS",
    )
    model.language_model = get_peft_model(model.language_model, lora_config)
    # The adapter is named after the last component of the output directory.
    model.add_adapter(lora_config, adapter_name=training_args.output_dir.split("/")[-1])
    model.set_adapter(training_args.output_dir.split("/")[-1])
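Coming back to the original question: once model.language_model is a PeftModel (as above), PEFT's save_pretrained writes out only the adapter files (adapter_config.json plus the adapter weights), not the full backbone. A minimal sketch, with the subdirectory name being purely illustrative:

import os

adapter_dir = os.path.join(training_args.output_dir, "llm_adapter")  # illustrative subfolder
# Saves only the LoRA adapter weights and config, not the InternVL backbone.
model.language_model.save_pretrained(adapter_dir)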