LiZijunApril

Results 1 comments of LiZijunApril

下载到本地的模型也可以,要修改一下代码。 这个错误的直接原因是Mamba类里from_pretrained方法中的load_state_dict_hf函数里,使用的cached_file(model_name, WEIGHTS_NAME),这里的WEIGHTS_NAME通常是'pytorch_model.bin',表示保存和加载模型权重的默认文件名。但是下载到本地的模型提供的权重文件是'model.safetensors',所以要用safetensors库中的torch.load_file加载: ![code](https://github.com/johnma2006/mamba-minimal/assets/46208049/1a25cb1e-cecc-403a-ad72-b00c9cc14182) 加载完了之后,不知道为什么权重跟结构对不上,要稍微修改一下,在return model之前: ![code](https://github.com/johnma2006/mamba-minimal/assets/46208049/2a43e2a3-5723-483d-83c7-bc479b6c5ec9)