关于bge-multilingual-gemma2的lora微调的显存问题
作者你好,我在微调bge-multilingual-gemma2模型的时候,对显存的使用觉得比较奇怪。按道理来说,lora微调一个9b的模型显存的使用应该是20到40g左右,但是我在batch_size等都设置为1的时候,还是需要60多g,是不是我的设置有哪些地方不太合理。
配置信息如下:
torchrun --nproc_per_node 1
-m FlagEmbedding.finetune.reranker.decoder_only.base
--model_name_or_path /root/autodl-tmp/model
--use_lora True
--lora_rank 32
--lora_alpha 64
--use_flash_attn False
--target_modules q_proj k_proj v_proj o_proj
--save_merged_lora_model True
--model_type decoder
--cache_dir ./cache/model
--train_data /root/autodl-tmp/output.json
--cache_path ./cache/data
--train_group_size 1
--query_max_len 512
--passage_max_len 512
--pad_to_multiple_of 8
--knowledge_distillation False
--query_instruction_for_rerank 'A: '
--query_instruction_format '{}{}'
--passage_instruction_for_rerank 'B: '
--passage_instruction_format '{}{}'
--output_dir ./test_decoder_only_base_bge-reranker-v2-minicpm-layerwise
--overwrite_output_dir
--learning_rate 1e-5
--bf16
--num_train_epochs 2
--per_device_train_batch_size 1
--per_device_eval_batch_size 1
--gradient_accumulation_steps 1
--dataloader_drop_last True
--warmup_ratio 0.1
--weight_decay 0.01
--logging_steps 1
--save_steps 1000
bge-multilingual-gemma2是embedder模型,需要用embedder的微调代码
@545999961 非常感谢你的提醒,我在使用对应的代码微调时,对lora的显存使用还是觉得有点疑惑。在设置了lora微调时,这里显存的使用像是全量微调,而不是lora。
torchrun --nproc_per_node 1
-m FlagEmbedding.finetune.embedder.decoder_only.base
--model_name_or_path /root/autodl-tmp/bge-multilingual-gemma2
--cache_dir ./cache/model
--use_lora True
--lora_rank 32
--lora_alpha 64
--target_modules q_proj k_proj v_proj o_proj gate_proj down_proj up_proj
--additional_special_tokens '
--save_merged_lora_model True
--train_data /root/autodl-tmp/10neg_New_output.json
--cache_path ./cache/data
--train_group_size 1
--query_max_len 512
--passage_max_len 512
--pad_to_multiple_of 8
--query_instruction_for_retrieval 'Given a query, retrieve passages that are relevant to the query.'
--query_instruction_format '
--knowledge_distillation True
--same_dataset_within_batch True
--small_threshold 0
--drop_threshold 0
--output_dir ./test_decoder_only_base_bge-multilingual-gemma2_sd
--overwrite_output_dir
--learning_rate 1e-4
--bf16 True
--num_train_epochs 1
--per_device_train_batch_size 1
--dataloader_drop_last True
--warmup_ratio 0.1
--logging_steps 1
--save_steps 1000
--negatives_cross_device
--temperature 0.02
--sentence_pooling_method last_token
--normalize_embeddings True \
在embedding任务微淘时,大部分显存开销来源于forward传播,这部分是lora无法避免的 在embedding任务上使用lora可以获得更好的效果,参考 repllama.