Add error msg for local/global mismatch when loading state_dict
Background: https://github.com/Oneflow-Inc/OneCloud/issues/70#issuecomment-1077397584
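Summary: a module's `load_state_dict` produces unfriendly error messages, and it does not check whether `is_global` matches between the checkpoint tensors and the model parameters. As a result, when a user loads a local state_dict into a global model, the error is caught by a later `try`/`except`, which returns an error message that hides the important information.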
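Implementation: in `load_state_dict` in `module.py`, while iterating over the parameters being loaded, add a check that `is_global` matches between the model parameter and the checkpoint tensor. The mismatch is then reported before it can reach the later `try`/`except`, together with a help message telling the user how to change their code to avoid the error.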
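The check described above can be sketched in plain Python as follows. This is a hypothetical illustration, not the code in `module.py`: `FakeTensor`, `check_is_global_match`, `check_state_dict`, and the hint wording are all assumptions, and only the `is_global` attribute of real OneFlow tensors is modeled. The point is that the mismatch raises its own descriptive error per parameter instead of surfacing later as an opaque exception.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; only `is_global` is modeled."""
    is_global: bool


def check_is_global_match(key, param, input_param):
    """Raise a readable error if a model parameter and its checkpoint entry
    disagree on is_global (hypothetical helper, not OneFlow's actual code)."""
    if param.is_global == input_param.is_global:
        return
    if param.is_global:
        # Model is global, checkpoint tensor is local
        hint = (
            "the model parameter is global but the checkpoint tensor is local; "
            "load the state_dict into the local model before calling to_global(), "
            "or convert the checkpoint tensors with to_global() first"
        )
    else:
        # Model is local, checkpoint tensor is global
        hint = (
            "the model parameter is local but the checkpoint tensor is global; "
            "convert the checkpoint tensors with to_local() first"
        )
    raise RuntimeError(
        f"is_global mismatch for '{key}': model is_global={param.is_global}, "
        f"checkpoint is_global={input_param.is_global}; {hint}"
    )


def check_state_dict(model_params, state_dict):
    """Check every key present in both mappings, mirroring the per-parameter
    loop in load_state_dict (sketch only)."""
    for key, param in model_params.items():
        if key in state_dict:
            check_is_global_match(key, param, state_dict[key])
```

Running the check before any copy happens keeps the mismatch from being swallowed by a generic exception handler, which is the problem this PR addresses.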
Reproduction code:
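```python
import tempfile

import oneflow as flow

# A local (non-global) model
model = flow.nn.Sequential(
    flow.nn.Linear(3, 3),
    flow.nn.Linear(3, 3),
)

with tempfile.TemporaryDirectory() as tmpdirname:
    # Save the local state_dict
    flow.save(model.state_dict(), tmpdirname)
    # Build the same model as a global model, broadcast across all CUDA devices
    model = flow.nn.Sequential(
        flow.nn.Linear(3, 3),
        flow.nn.Linear(3, 3),
    ).to_global(sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cuda"))
    # Loading the local state_dict into the global model triggers the is_global mismatch
    model.load_state_dict(flow.load(tmpdirname))
```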