I get an error when I run run_event.py

Open lauht opened this issue 3 years ago • 1 comments

    Traceback (most recent call last):
      File "run_event.py", line 489, in <module>
        main()
      File "run_event.py", line 463, in main
        outputs = model(input_ids, attention_mask=attention_mask, seq_labels=seq_labels, ner_labels=ner_labels)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
        output.reraise()
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/_utils.py", line 429, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
        output = module(*input, **kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/data1/lht/TradeTheEvent/utils/model.py", line 170, in forward
        seq_logits = self.final_classifier1(seq_logits)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 94, in forward
        return F.linear(input, self.weight, self.bias)
      File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/functional.py", line 1753, in linear
        return torch._C._nn.linear(input, weight, bias)
    RuntimeError: mat1 dim 1 must match mat2 dim 0

Is the error caused by the network, or is my version of torch incompatible?

lauht, Jan 02 '22 14:01

run_event.py seems to be untested.
Among other issues, config.num_labels is never set from the command-line arguments, so it keeps the value from BERT's default config. Insert the two marked lines in run_event.py at roughly line 397:

    logger.info(
        'Total training batch size: {}'.format(args.per_gpu_batch_size * args.gradient_accumulation_steps * args.n_gpu))

    config = BertConfig.from_pretrained(args.model_type)
    # config.num_labels = 12
    config.num_labels = args.num_labels          # insert this line
    config.max_seq_length = args.max_seq_length  # insert this line
    model = MODEL_CLASS.from_pretrained(args.model_type, config=config)
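
To see the override pattern in isolation, here is a minimal sketch that runs without torch or transformers. `StubConfig` and `build_config` are hypothetical names used only for illustration, the stub's default of 2 is an assumption standing in for whatever BERT's default config carries, and the argument defaults (12 and 256) are placeholders rather than values taken from run_event.py.

```python
import argparse
from dataclasses import dataclass


@dataclass
class StubConfig:
    """Hypothetical stand-in for BertConfig, used only for this sketch."""
    num_labels: int = 2  # assumed default; the real default comes from the BERT config


def build_config(argv):
    """Parse CLI-style arguments and copy them onto the config object."""
    parser = argparse.ArgumentParser()
    # Placeholder defaults, not taken from run_event.py.
    parser.add_argument("--num_labels", type=int, default=12)
    parser.add_argument("--max_seq_length", type=int, default=256)
    args = parser.parse_args(argv)

    config = StubConfig()
    # Without these two assignments the config silently keeps its own defaults,
    # and anything sized from config.num_labels ignores the CLI value.
    config.num_labels = args.num_labels
    config.max_seq_length = args.max_seq_length
    return config


if __name__ == "__main__":
    cfg = build_config(["--num_labels", "12", "--max_seq_length", "128"])
    print(cfg.num_labels, cfg.max_seq_length)
```

The point is only that the config has to be updated before the model is constructed, since `from_pretrained(..., config=config)` receives whatever values are on the config at that moment.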

jeremytanjianle, Jan 07 '22 04:01