SeqGenSQL
requirements
Hi,
I have been trying to reproduce the results in this repo, but I seem to be hitting version issues with some packages, such as pytorch_lightning. Would it be possible to add a requirements file listing the package versions that work together?
Thank you very much!
We had the same issue. We cannot replicate the results because we keep running into PyTorch problems and do not know which versions were used.
We got it working with pytorch_lightning==0.10.0 and torch==1.7.1 (torch 1.7.1 in our case because of CUDA 10.1 constraints). You also need to install sentencepiece.
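In case it helps others, here is a minimal sketch for checking an environment against the versions above. The helper names (`installed_version`, `find_mismatches`) are made up for illustration, and the distribution name `pytorch-lightning` is an assumption about how the package is registered with pip. A pin of `None` just means "any version, but it must be installed".

```python
from importlib import metadata

# Versions reported in this thread; other combinations may work as well.
# None means no specific version was reported, only that the package is needed.
PINNED = {
    "pytorch-lightning": "0.10.0",
    "torch": "1.7.1",
    "sentencepiece": None,
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Return {package: (expected, found)} for every package that is off its pin."""
    mismatches = {}
    for package, expected in pinned.items():
        found = lookup(package)
        if found is None:
            # Package is not installed at all.
            mismatches[package] = (expected, found)
        elif expected is not None and found.split("+")[0] != expected:
            # Local build tags such as "1.7.1+cu101" still count as a match.
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for package, (expected, found) in find_mismatches(PINNED).items():
        print(f"{package}: expected {expected}, found {found}")
```

Running it prints one line per package that is missing or on a different version, and nothing if the environment matches.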
In train.py you have to delete one of the parser.add_argument("--early_stop_callback", default=False) lines, because this argument is defined twice.
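For anyone wondering why the duplicate matters: argparse refuses to register the same option string twice. This standalone sketch reproduces the behaviour outside the repo; only the argument name is taken from train.py, the rest is illustrative.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--early_stop_callback", default=False)

duplicate_error = None
try:
    # Registering the same option string a second time is rejected by argparse.
    parser.add_argument("--early_stop_callback", default=False)
except argparse.ArgumentError as err:
    duplicate_error = str(err)
    print(duplicate_error)
```

The second `add_argument` call raises `argparse.ArgumentError` with a "conflicting option string" message, so the script fails before training starts.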
Then in module.py change line 96 so that the parameter is called labels instead of lm_labels thus: labels=lm_labels,
Line 79 of train.py mentions pytorch_lightning 0.8.4. Combined with torch==1.7.1, that gives us another working combination.
I have also renamed lm_labels to labels in module.py.
But now I have two problems: 1. After training for 25 epochs, the success rate on the dev dataset is still below 10%. 2. When I load the trained ckpt (from https://onebigdatabag.blob.core.windows.net/shared/base_gated_e09_0.02626.ckpt), the following error appears, which looks like a pytorch_lightning issue.
`---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
~/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, tags_csv, *args, **kwargs)
    167         checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    168
--> 169         model = cls._load_model_state(checkpoint, *args, **kwargs)
    170         return model
    171
~/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, *cls_args, **cls_kwargs)
    205         model = cls(*cls_args, **cls_kwargs)
    206         # load the state_dict on the model automatically
--> 207         model.load_state_dict(checkpoint['state_dict'])
    208
    209         # give model a chance to load something
/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1049
   1050         if len(error_msgs) > 0:
-> 1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for SeqGenSQL:
	Unexpected key(s) in state_dict: "model.decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight". `
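One possible workaround, assuming the extra key is simply a weight that the installed model class no longer defines (I have not confirmed this against the repo), is to drop it from the checkpoint's state dict before loading. The `drop_keys` helper below is hypothetical, not part of the repo or of PyTorch.

```python
# Key reported as unexpected in the traceback above.
UNEXPECTED = [
    "model.decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
]


def drop_keys(state_dict, unexpected_keys):
    """Return a copy of state_dict without the listed keys (input is not modified)."""
    unexpected = set(unexpected_keys)
    return {key: value for key, value in state_dict.items() if key not in unexpected}


# Usage sketch, not executed here. It assumes torch is installed and that `model`
# is an already constructed SeqGenSQL instance (constructor arguments not shown):
#
#   checkpoint = torch.load(ckpt_path, map_location="cpu")
#   model.load_state_dict(drop_keys(checkpoint["state_dict"], UNEXPECTED))
```

Calling `model.load_state_dict(checkpoint["state_dict"], strict=False)` should have a similar effect, but it also silences any other missing or unexpected keys, so filtering the one known key is the more conservative choice. Either way, this only addresses the loading error, not the low dev accuracy.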