CoLLiE
CoLLiE copied to clipboard
Error: llama2 70B LlamaForCausalLM.from_pretrained 开启Zero3,会消耗大量内存导致 OOM
8张 V100 显卡,开启 Zero3,TP=1,PP=1,DP=8,LlamaForCausalLM.from_pretrained llama 70B 模型会出现 OOM (内存不够,不是显存不够),物理内存 512GB。
原因是 dev 分支中,base.py 304行,
state_dict = {}
if not is_zero3_enabled(config) or env.dp_rank == 0
or config.low_cpu_mem_usage or config.quantization_config.load_in_8bit
or getattr(config.quantization_config, "load_in_4bit", False):
state_dict = cls.load_parallel_state_dict(
path=model_path_or_name, config=config,
process_exclusion=process_exclusion, **kwargs
)
会导致 8 个进程 都 加载一次 state_dict,内存消耗很大,导致OOM