运行app.py时会删除下载好的与训练文件并重新下载
为什么我在hugging face上面手动下载好的checkpoint每次都会在运行app.py时都会被删除,并且又再次在线下载,结果就是导致在线下载文件太大,网络不稳定,下载不成功。有没有什么解决方法,感谢!
app.py在最初会运行以下函数进行下载: https://github.com/CyberAgentAILab/TANGO/blob/e24a2f4c64f4e9addb57aad1ebb8cdaf2eeada6c/app.py#L30
检测标准是最后一个所需文件是否下载完成,所以我猜你是下载了部分但不是全部的checkpoint导致每次重新下载。 https://github.com/CyberAgentAILab/TANGO/blob/e24a2f4c64f4e9addb57aad1ebb8cdaf2eeada6c/utils/download_utils.py#L9
如果你确定已经手动下载到所有需要的checkpoint,可以直接注释掉app.py里的这一行跳过下载过程。
原来的下载代码不是很好,网络不好成功率低,可以适当修改TANGO/utils/download_utils.py代码(可以先备份一下),以便支持续传和多线程高速下载,具体代码如下: import os from huggingface_hub import snapshot_download from concurrent.futures import ThreadPoolExecutor
def download_files_from_repo(): # 国内可设置国内镜像地址 #os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# check the last ckpts are downloaded
repo_id = "H-Liu1997/TANGO"
local_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")
last_ckpt_path = os.path.join(local_dir, "SMPLer-X/pretrained_models/smpler_x_s32.pth.tar")
if os.path.exists(last_ckpt_path):
return
# 定义需要下载的文件模式
patterns = [
("frame-interpolation-pytorch/*.pt", local_dir),
("Wav2Lip/checkpoints/*.pth", local_dir),
("datasets/cached_ckpts/*", local_dir),
("datasets/cached_graph/*", local_dir),
("emage/smplx_models/smplx/*", local_dir),
("SMPLer-X", os.path.join(local_dir, "./SMPLer-X"), "caizhongang/SMPLer-X", "pretrained_models/*"),
("SMPLer-X/pretrained_models/*", local_dir)
]
def download_pattern(pattern, local_dir, repo_id=repo_id):
if len(pattern) == 4:
allow_patterns, local_dir, repo_id, ignore_patterns = pattern
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
repo_type="space",
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
force_download=True,
resume_download=True
)
else:
allow_patterns, local_dir = pattern
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
repo_type="space",
allow_patterns=allow_patterns,
force_download=True,
resume_download=True
)
# 使用多线程并行下载
with ThreadPoolExecutor(max_workers=5) as executor:
executor.map(lambda p: download_pattern(p, local_dir), patterns)
print("Downloaded all the necessary files from the repo.")
''' 主要改进:
- 设置国内镜像地址:在函数开始时设置 os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com',以指定国内镜像地址。
- 断点续传:在每个 snapshot_download 调用中添加了 resume_download=True 参数,以支持断点续传。
- 多线程下载:使用 ThreadPoolExecutor 来并行下载不同的文件模式,提高下载速度。 这样,代码将使用国内镜像地址进行下载,并且支持断点续传和多线程下载,以提高下载效率。 '''