Qwen/Qwen1.5-0.5B-Chat cannot be converted after fine-tuning
[2025-04-02 21:35:20,877] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-04-02 21:35:26,187] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Sliding Window Attention is enabled but not implemented for sdpa; unexpected results may be encountered.
Traceback (most recent call last):
File "/root/autodl-tmp/xtuner/xtuner/tools/model_converters/pth_to_hf.py", line 139, in
(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with weights_only=True please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL mmengine.logging.history_buffer.HistoryBuffer was not an allowed global by default. Please use torch.serialization.add_safe_globals([HistoryBuffer]) or the torch.serialization.safe_globals([HistoryBuffer]) context manager to allowlist this global if you trust this class/function.
Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
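For context on why this fails: PyTorch 2.6 flipped `torch.load`'s default to `weights_only=True`, which loads checkpoints through a restricted unpickler that rejects any global not on an allowlist — and the xtuner checkpoint contains mmengine's `HistoryBuffer`, which is not allowlisted. A minimal stdlib sketch of that mechanism (`AllowlistUnpickler` and `HistoryBufferStandIn` are illustrative names, not part of PyTorch):

```python
import io
import pickle

class AllowlistUnpickler(pickle.Unpickler):
    """Rejects any global not explicitly allowlisted -- the same idea
    torch.load(weights_only=True) uses under the hood."""

    def __init__(self, file, allowed=()):
        super().__init__(file)
        self.allowed = {(c.__module__, c.__qualname__): c for c in allowed}

    def find_class(self, module, name):
        try:
            return self.allowed[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name} was not allowlisted")

class HistoryBufferStandIn:  # stand-in for mmengine's HistoryBuffer
    def __init__(self, data):
        self.data = data

payload = pickle.dumps(HistoryBufferStandIn([1, 2, 3]))

# Without the allowlist, loading is rejected (this is the error above)...
try:
    AllowlistUnpickler(io.BytesIO(payload)).load()
except pickle.UnpicklingError as e:
    print(e)

# ...and succeeds once the class is explicitly trusted.
obj = AllowlistUnpickler(io.BytesIO(payload),
                         allowed=(HistoryBufferStandIn,)).load()
print(obj.data)  # [1, 2, 3]
```

So the two options in the error message are: disable the restriction (`weights_only=False`) or extend the allowlist with the trusted classes.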
I solved it: add weights_only=False to the torch.load call in this code, and make the same change for the optimizer-state loading. Details below:
1. Modify the model weight loading:

```python
def parse_model_states(files, dtype=DEFAULT_DTYPE):
    zero_model_states = []
    for file in files:
        # Modify model weight loading
        state_dict = torch.load(file, map_location='cpu', weights_only=False)  # added argument
        # ... subsequent processing ...
```
2. The optimizer-state loading needs the same change:

```python
@torch.no_grad()
def parse_optim_states(files, ds_checkpoint_dir, dtype=DEFAULT_DTYPE):
    zero_stage = None
    world_size = None
    total_files = len(files)
    flat_groups = []
    torch.serialization.add_safe_globals([ConfigDict])
    for f in tqdm(files, desc="Load Checkpoints"):
        state_dict = torch.load(f, map_location=device, weights_only=False)  # added argument
        # ... subsequent processing ...
```