GPT2-Chinese
Error when training with train.py
Traceback (most recent call last):
  File "train.py", line 236, in <module>
    trainer.fit(net)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 844, in run_train
    self.run_sanity_check(self.lightning_module)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 1112, in run_sanity_check
    self.run_evaluation()
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 967, in run_evaluation
    output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 174, in evaluation_step
    output = self.trainer.accelerator.validation_step(args)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py", line 226, in validation_step
    return self.training_type_plugin.validation_step(*args)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/dp.py", line 101, in validation_step
    return self.model(*args, **kwargs)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/overrides/data_parallel.py", line 77, in forward
    output = super().forward(*inputs, **kwargs)
  File "/home/gluo/anaconda3/envs/gan/lib/python3.6/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "train.py", line 117, in validation_step
    loss = self.forward(batch["input_ids"], batch["attention_mask"])
KeyError: 'attention_mask'
How should I handle this error?
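The last frame shows `validation_step` in train.py reading `batch["attention_mask"]`, so the batch dict delivered by the validation dataloader (during Lightning's sanity check) apparently has no such key. We can't see how train.py builds its batches, so the following is only a minimal sketch, assuming batches are plain dicts of token-id lists. `collate_with_mask` and `pad_id=0` are hypothetical names and values, not part of GPT2-Chinese. The sketch reproduces the `KeyError` and shows one way a collate function could emit an `attention_mask` alongside `input_ids`.

```python
from typing import Dict, List


def collate_with_mask(samples: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical collate_fn (not from GPT2-Chinese): right-pad token-id
    lists to the same length and build the matching attention_mask,
    where 1 marks a real token and 0 marks padding."""
    max_len = max(len(s) for s in samples)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in samples]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in samples]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# A batch built without an "attention_mask" key reproduces the error above.
bad_batch = {"input_ids": [[101, 102]]}
try:
    bad_batch["attention_mask"]
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'attention_mask'

# A batch built by the hypothetical collate function has both keys.
good_batch = collate_with_mask([[5, 6, 7], [8]])
print(good_batch["attention_mask"])  # [[1, 1, 1], [1, 0, 0]]
```

Printing `batch.keys()` at the top of `validation_step` is a quick way to confirm which keys the validation batches actually carry before changing the data pipeline.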