
Add error msg for local/global mismatch when loading state_dict

Open marigoold opened this issue 2 years ago • 0 comments

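Background: https://github.com/Oneflow-Inc/OneCloud/issues/70#issuecomment-1077397584
Summary: The error message from Module.load_state_dict is not very user-friendly, and the method does not check whether the is_global property of each checkpoint tensor matches that of the corresponding model parameter. As a result, when a user loads a local state_dict into a global model, the failure is caught by a later try/except, which returns an error message that hides the key information.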


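Implementation: As load_state_dict in module.py iterates over the parameters, it now checks that is_global matches between each checkpoint tensor and the corresponding model parameter. The mismatch is therefore reported before the later try/except can swallow it. A help message is also added that tells the user how to change their code to avoid the error.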
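To make the idea concrete, here is a minimal plain-Python sketch of such a check. It is not the actual patch to module.py. FakeTensor, check_global_mismatch, load_state_dict_sketch, and the wording of the messages are all hypothetical. The only thing the sketch borrows from OneFlow is that each tensor exposes a boolean is_global. The check runs before any copying, so the user gets a per-parameter explanation instead of whatever the later try/except produces.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a OneFlow tensor; only is_global is modeled."""
    is_global: bool


def _kind(is_global: bool) -> str:
    return "global" if is_global else "local"


def check_global_mismatch(
    model_params: Dict[str, FakeTensor],
    state_dict: Dict[str, FakeTensor],
) -> List[str]:
    """Return one readable message per parameter whose is_global differs
    between the model and the checkpoint (hypothetical helper)."""
    errors = []
    for name, param in model_params.items():
        if name not in state_dict:
            continue  # missing keys would be reported by a separate check
        loaded = state_dict[name]
        if param.is_global != loaded.is_global:
            # Hypothetical hint: suggest converting the checkpoint tensor
            # to match the model parameter before loading.
            hint = (
                "convert the checkpoint tensors with to_global(placement, sbp)"
                if param.is_global
                else "convert the checkpoint tensors with to_local()"
            )
            errors.append(
                f"{name}: model parameter is {_kind(param.is_global)} but the "
                f"checkpoint tensor is {_kind(loaded.is_global)}; {hint} before loading."
            )
    return errors


def load_state_dict_sketch(
    model_params: Dict[str, FakeTensor],
    state_dict: Dict[str, FakeTensor],
) -> None:
    """Run the mismatch check first, then (in real code) the copy loop."""
    errors = check_global_mismatch(model_params, state_dict)
    if errors:
        raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(errors))
    # The real per-parameter copy (wrapped in its try/except) would run here.


if __name__ == "__main__":
    model = {"0.weight": FakeTensor(is_global=True)}
    checkpoint = {"0.weight": FakeTensor(is_global=False)}
    try:
        load_state_dict_sketch(model, checkpoint)
    except RuntimeError as err:
        print(err)
```

Because the check is a separate pass over the parameter names, every mismatched parameter is listed at once, and the message says which direction the conversion should go.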

Reproduction code:

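import tempfile

import oneflow as flow

# Build a local (non-global) model; its state_dict is saved below.
model = flow.nn.Sequential(
    flow.nn.Linear(3, 3),
    flow.nn.Linear(3, 3),
)
with tempfile.TemporaryDirectory() as tmpdirname:
    flow.save(model.state_dict(), tmpdirname)
    # Re-create the same model as a global model, broadcast over all CUDA devices.
    model = flow.nn.Sequential(
        flow.nn.Linear(3, 3),
        flow.nn.Linear(3, 3),
    ).to_global(sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cuda"))
    # Loading the local state_dict into the global model triggers the mismatch.
    model.load_state_dict(flow.load(tmpdirname))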

marigoold, Aug 10 '22 02:08