
Data2Vec: error when loading own pretrained model

Open · stefan-it opened this issue 1 year ago · 3 comments

Hi,

I'm currently training my own Data2Vec text model with the latest fairseq master. When I load the model checkpoint with:

cd examples/data2vec/models

then:

from data2vec_text import Data2VecTextModel

data2vec = Data2VecTextModel.from_pretrained("/mnt/data2vec-base-turkish-cased", "checkpoint_2_100000.pt")

# or:
# data2vec = Data2VecTextModel.from_pretrained(model_name_or_path="/mnt/data2vec-base-turkish-cased", checkpoint_file="checkpoint_2_100000.pt")

the following error is raised:

AttributeError                            Traceback (most recent call last)
<ipython-input-8-d955c1d25087> in <module>
----> 1 data2vec = Data2VecTextModel.from_pretrained(model_name_or_path="/mnt/data2vec-base-turkish-cased", checkpoint_file="checkpoint_2_100000.pt")

/mnt/turkish-data2vec/fairseq/fairseq/models/fairseq_model.py in from_pretrained(cls, model_name_or_path, checkpoint_file, data_name_or_path, **kwargs)
    265         from fairseq import hub_utils
    266 
--> 267         x = hub_utils.from_pretrained(
    268             model_name_or_path,
    269             checkpoint_file,

/mnt/turkish-data2vec/fairseq/fairseq/hub_utils.py in from_pretrained(model_name_or_path, checkpoint_file, data_name_or_path, archive_map, **kwargs)
     71         utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))
     72 
---> 73     models, args, task = checkpoint_utils.load_model_ensemble_and_task(
     74         [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
     75         arg_overrides=kwargs,

/mnt/turkish-data2vec/fairseq/fairseq/checkpoint_utils.py in load_model_ensemble_and_task(filenames, arg_overrides, task, strict, suffix, num_shards, state)
    480                 ):
    481                     model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
--> 482                 model.load_state_dict(
    483                     state["model"], strict=strict, model_cfg=cfg.model
    484                 )

/mnt/turkish-data2vec/fairseq/fairseq/models/fairseq_model.py in load_state_dict(self, state_dict, strict, model_cfg, args)
    121             model_cfg = convert_namespace_to_omegaconf(args).model
    122 
--> 123         self.upgrade_state_dict(state_dict)
    124 
    125         from fairseq.checkpoint_utils import prune_state_dict

/mnt/turkish-data2vec/fairseq/fairseq/models/fairseq_model.py in upgrade_state_dict(self, state_dict)
    130     def upgrade_state_dict(self, state_dict):
    131         """Upgrade old state dicts to work with newer code."""
--> 132         self.upgrade_state_dict_named(state_dict, "")
    133 
    134     def upgrade_state_dict_named(self, state_dict, name):

/mnt/turkish-data2vec/fairseq/examples/data2vec/models/data2vec_text.py in upgrade_state_dict_named(self, state_dict, name)
    255             self.encoder.lm_head = None
    256 
--> 257         if self.encoder.target_model is None:
    258             for k in list(state_dict.keys()):
    259                 if k.startswith(prefix + "encoder.target_model."):

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
   1129             if name in modules:
   1130                 return modules[name]
-> 1131         raise AttributeError("'{}' object has no attribute '{}'".format(
   1132             type(self).__name__, name))
   1133 

AttributeError: 'Data2VecTextEncoder' object has no attribute 'target_model'
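
For context, the AttributeError itself comes from torch.nn.Module.__getattr__, which only resolves names that were registered as parameters, buffers, or submodules. A fairseq-independent toy example shows the same behavior:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)  # registered submodule, so attribute access works

toy = Toy()
print(toy.linear)        # fine
print(toy.target_model)  # AttributeError: 'Toy' object has no attribute 'target_model'

So any encoder built without a target_model submodule will fail the plain attribute access in upgrade_state_dict_named.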

Note: loading works with the officially released checkpoint nlp_base.pt.
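
(For reference, the same failing code path can be exercised without cd'ing into examples/data2vec/models by registering the user module first, which is what hub_utils.from_pretrained does internally. The snippet below is reconstructed from the traceback, so treat it as an untested sketch; the paths are the same as above.)

import argparse
from fairseq import checkpoint_utils, utils

# register the data2vec task/model definitions, mirroring hub_utils.from_pretrained
utils.import_user_module(
    argparse.Namespace(user_dir="/mnt/turkish-data2vec/fairseq/examples/data2vec")
)

models, args, task = checkpoint_utils.load_model_ensemble_and_task(
    ["/mnt/data2vec-base-turkish-cased/checkpoint_2_100000.pt"]
)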

stefan-it · Jul 04 '22 14:07

@alexeib do you have any hint on how to solve this? I would really like to convert the checkpoint with the Transformers library and test it on downstream tasks. Any help is highly appreciated :hugs:

stefan-it · Jul 04 '22 14:07

I did some debugging and one minor fix would be:

if hasattr(self.encoder, "target_model") and self.encoder.target_model is None:

instead of:

if self.encoder.target_model is None:

What do you think? I could prepare a PR for that!
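
For context, here is roughly how the guarded check would sit in upgrade_state_dict_named in examples/data2vec/models/data2vec_text.py. The prefix handling and the loop body are reconstructed from the traceback above, so treat this as a sketch rather than the exact patch:

def upgrade_state_dict_named(self, state_dict, name):
    prefix = name + "." if name != "" else ""  # assumed; standard fairseq pattern
    # ... preceding upgrade logic unchanged ...
    # only touch target_model keys when the encoder actually has that attribute
    if hasattr(self.encoder, "target_model") and self.encoder.target_model is None:
        for k in list(state_dict.keys()):
            if k.startswith(prefix + "encoder.target_model."):
                # the traceback truncates here; dropping the stale teacher
                # weights from the checkpoint is the likely intent
                del state_dict[k]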

stefan-it · Jul 04 '22 17:07

I think this looks reasonable, please submit a PR!

alexeib · Jul 04 '22 18:07