
NotImplementedError

Open · xiang-xiang-zhu opened this issue 3 years ago · 1 comment

When I run the inference code:

```bash
python3 main.py \
  --train_data_file ${ROOT_PATH}/data/${DATA_TYPE}/train \
  --dev_data_file ${ROOT_PATH}/data/${DATA_TYPE}/dev \
  --test_data_file ${ROOT_PATH}/data/${DATA_TYPE}/test \
  --graph_path 2hops_100_directed_triple_filter.json \
  --output_dir ${ROOT_PATH}/models/${DATA_TYPE}/grf-${DATA_TYPE} \
  --source_length 32 \
  --target_length 32 \
  --model_type gpt2 \
  --model_name_or_path ${ROOT_PATH}/models/gpt2-small \
  --do_eval \
  --per_gpu_train_batch_size 16 \
  --per_gpu_eval_batch_size 16 \
  --workers 7 \
  --seed 42 \
  --evaluate_metrics bleu \
  --overwrite_output_dir \
  --aggregate_method max \
  --gamma 0.5 \
```

I get the following error:

```
[04/09/2022 21:49:57 - WARNING - root] Process rank: -1, device: cuda, n_gpu: 1, distributed training: False, 16-bits training: False
Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up.
Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.
source length: 32
[04/09/2022 21:49:58 - INFO - modeling_gpt2] Tie weights in head!!!!!
[04/09/2022 21:49:58 - INFO - modeling_gpt2] Tie weights in head!!!!!
Traceback (most recent call last):
  File "main.py", line 620, in <module>
    main()
  File "main.py", line 588, in main
    model.resize_token_embeddings(len(tokenizer))
  File "/home/hx/anaconda3/envs/torch1.4/lib/python3.7/site-packages/transformers/modeling_utils.py", line 724, in resize_token_embeddings
    model_embeds = self._resize_token_embeddings(new_num_tokens)
  File "/home/hx/anaconda3/envs/torch1.4/lib/python3.7/site-packages/transformers/modeling_utils.py", line 738, in _resize_token_embeddings
    old_embeddings = self.get_input_embeddings()
  File "/home/hx/anaconda3/envs/torch1.4/lib/python3.7/site-packages/transformers/modeling_utils.py", line 563, in get_input_embeddings
    return base_model.get_input_embeddings()
  File "/home/hx/anaconda3/envs/torch1.4/lib/python3.7/site-packages/transformers/modeling_utils.py", line 565, in get_input_embeddings
    raise NotImplementedError
NotImplementedError
```
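
For reference, the traceback bottoms out in transformers' base-class fallback: `resize_token_embeddings` asks the model for its input-embedding module via `get_input_embeddings`, the LM-head model delegates that lookup to its base `GPT2Model`, and because the project's custom `GPT2Model` in `modeling_gpt2.py` does not override the accessor, the library-level fallback raises `NotImplementedError`. A minimal sketch of that lookup, with hypothetical `*Sketch` class names standing in for the real classes (not the library's exact code, which varies by version):

```python
# Minimal self-contained sketch of the lookup that fails above; an illustration
# of transformers' fallback behaviour, not the library's exact code.
class PreTrainedModelSketch:
    base_model_prefix = ""

    def get_input_embeddings(self):
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            # the LM-head wrapper delegates to its base model (modeling_utils.py:563)
            return base_model.get_input_embeddings()
        # the base model provides no override of its own, so the lookup dead-ends
        # (modeling_utils.py:565)
        raise NotImplementedError


class GPT2ModelSketch(PreTrainedModelSketch):
    # the repo's modeling_gpt2.GPT2Model defines self.wte but, before the fix
    # below, no get_input_embeddings / set_input_embeddings accessors
    base_model_prefix = "transformer"


class GPT2LMHeadModelSketch(PreTrainedModelSketch):
    base_model_prefix = "transformer"

    def __init__(self):
        self.transformer = GPT2ModelSketch()


# GPT2LMHeadModelSketch().get_input_embeddings() raises NotImplementedError,
# which is exactly what resize_token_embeddings() runs into from main.py.
```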

— xiang-xiang-zhu, Apr 09 '22

Add these two methods to the `GPT2Model` class in `modeling_gpt2.py`:

```python
def get_input_embeddings(self):
    return self.wte

def set_input_embeddings(self, new_embeddings):
    self.wte = new_embeddings
```
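
With those two accessors defined, `resize_token_embeddings` can fetch the old token-embedding matrix through `get_input_embeddings`, build an enlarged copy, and install it through `set_input_embeddings`, so the call that failed in `main.py` should go through. A quick check, assuming `model` and `tokenizer` are the objects `main.py` already constructs:

```python
# After patching modeling_gpt2.py, the call from main.py (line 588 in the
# traceback) should no longer raise NotImplementedError.
model.resize_token_embeddings(len(tokenizer))
assert model.get_input_embeddings().num_embeddings == len(tokenizer)
```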

— MoemenGaafar, Nov 25 '22