TradeTheEvent
I got this error when I ran run_event.py:
Traceback (most recent call last):
File "run_event.py", line 489, in self.bias)
File "/data1/lht/anaconda3/envs/tte/lib/python3.7/site-packages/torch/nn/functional.py", line 1753, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 dim 1 must match mat2 dim 0
Is this error caused by the network definition, or is my version of torch incompatible?
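For context, the message comes from the matrix multiply inside a linear layer. The input's last dimension has to equal the layer's in_features, and here it does not. The sketch below reproduces that shape rule in plain Python, assuming the weight is stored as (out_features, in_features) the way torch.nn.Linear stores it. `linear_output_shape` is a hypothetical helper for illustration, not a PyTorch function.

```python
from typing import Tuple


def linear_output_shape(input_shape: Tuple[int, ...],
                        weight_shape: Tuple[int, int]) -> Tuple[int, ...]:
    """Mimic the shape check of y = x @ W.T + b for W of shape (out_features, in_features)."""
    out_features, in_features = weight_shape
    if input_shape[-1] != in_features:
        # Older PyTorch releases (like the one in the traceback) report this case as
        # "mat1 dim 1 must match mat2 dim 0"; newer releases word it differently.
        raise RuntimeError(
            f"input last dim {input_shape[-1]} != in_features {in_features}"
        )
    # All leading (batch) dimensions are kept; only the last one changes.
    return input_shape[:-1] + (out_features,)


print(linear_output_shape((8, 768), (12, 768)))  # (8, 12)
try:
    linear_output_shape((8, 512), (12, 768))  # last dim 512 does not match 768
except RuntimeError as err:
    print(err)
```

So the question to ask is which tensor reaching the failing layer has a last dimension that differs from what that layer was built with.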
run_event.py seems untested. Among other issues, config.num_labels is not set from the argparse arguments; it keeps the value from BERT's default config.
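A stdlib-only sketch of why the command-line value never reaches the model: unless the parsed argument is copied onto the config, the config keeps whatever default it was loaded with. `Config` is a hypothetical stand-in for BertConfig, its default of 2 labels is assumed to mirror the library default, and the flag defaults (12 and 256) are illustrative only.

```python
import argparse
from dataclasses import dataclass
from typing import List


@dataclass
class Config:
    """Hypothetical stand-in for a config loaded via from_pretrained."""
    num_labels: int = 2  # assumed to mirror the library default


def build_config(argv: List[str], apply_args: bool) -> Config:
    parser = argparse.ArgumentParser()
    # Flag names follow args.num_labels / args.max_seq_length; defaults are illustrative.
    parser.add_argument("--num_labels", type=int, default=12)
    parser.add_argument("--max_seq_length", type=int, default=256)
    args = parser.parse_args(argv)

    config = Config()  # the "pretrained" defaults
    if apply_args:
        # The fix: copy the parsed values onto the config before building the model.
        config.num_labels = args.num_labels
        config.max_seq_length = args.max_seq_length
    return config


print(build_config(["--num_labels", "12"], apply_args=False).num_labels)  # 2
print(build_config(["--num_labels", "12"], apply_args=True).num_labels)   # 12
```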
Insert the following code in run_event.py, roughly around line 397:
logger.info(
'Total training batch size: {}'.format(args.per_gpu_batch_size * args.gradient_accumulation_steps * args.n_gpu))
config = BertConfig.from_pretrained(args.model_type)
# config.num_labels = 12
config.num_labels = args.num_labels # insert this line
config.max_seq_length = args.max_seq_length # insert this line
model = MODEL_CLASS.from_pretrained(args.model_type, config=config)
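To catch this kind of mismatch early instead of deep inside a forward pass, one option is to compare the config against the parsed arguments right after the model is built. The helper `config_mismatches` below is hypothetical (it is not part of run_event.py or transformers), and SimpleNamespace objects stand in for the real config and args.

```python
from types import SimpleNamespace
from typing import Iterable, List


def config_mismatches(config, args,
                      keys: Iterable[str] = ("num_labels", "max_seq_length")) -> List[str]:
    """Describe every key whose config value differs from the parsed argument."""
    problems = []
    for key in keys:
        expected = getattr(args, key)
        actual = getattr(config, key, None)  # None if the config lacks the attribute
        if actual != expected:
            problems.append(f"{key}: config={actual!r}, args={expected!r}")
    return problems


args = SimpleNamespace(num_labels=12, max_seq_length=256)
stale = SimpleNamespace(num_labels=2)  # config left at its defaults
print(config_mismatches(stale, args))
# ['num_labels: config=2, args=12', 'max_seq_length: config=None, args=256']
```

In run_event.py this could be called on `model.config` and `args`, raising if the returned list is non-empty.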