pytorch-bert-crf-ner icon indicating copy to clipboard operation
pytorch-bert-crf-ner copied to clipboard

학습 모델 로드시 질문

Open robinsongh381 opened this issue 5 years ago • 1 comments

안녕하세요

좋은 자료 공유해주셔서 감사의 말씀을 우선 정합니다.

Inference.py 에서

    convert_keys = {}
    for k, v in checkpoint['model_state_dict'].items():
        new_key_name = k.replace("module.", '')
        if new_key_name not in model_dict:
            print("{} is not int model_dict".format(new_key_name))
            continue
        convert_keys[new_key_name] = v

다음과 같이 convert_keys 를 정의하고 model.load_state_dict(convert_keys)를 하셨는데, 왜 바로 model.load_state_dict(checkpoint['model_state_dict']) 하시지 않았는지 아니면 하면 안되는지 궁금하여 질문을 드립니다

감사합니다

robinsongh381 avatar Jan 14 '20 09:01 robinsongh381

아마 모델을 분산학습 시키셔서 모든 weight들의 이름이 module.~ 이런식으로 되어 있을 겁니다. 따라서 그냥 model.load_state_dict(checkpoint['model_state_dict'])를 하면 key error가 발생합니다.

dave-rtzr avatar Dec 07 '20 02:12 dave-rtzr