```
Traceback (most recent call last):
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/TaskForChineseNER.py", line 315, in <module>
    train(config)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/TaskForChineseNER.py", line 132, in train
    loss, logits = model(input_ids=token_ids, # [src_len, batch_size]
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/DownstreamTasks/BertForTokenClassification.py", line 32, in forward
    _, all_encoder_outputs = self.bert(input_ids=input_ids,
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/BasicBert/Bert.py", line 290, in forward
    all_encoder_outputs = self.bert_encoder(embedding_output,
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/BasicBert/Bert.py", line 190, in forward
    layer_output = layer_module(layer_output,
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/BasicBert/Bert.py", line 162, in forward
    attention_output = self.bert_attention(hidden_states, attention_mask)
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/BasicBert/Bert.py", line 93, in forward
    self_outputs = self.self(hidden_states,
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/BasicBert/Bert.py", line 56, in forward
    return self.multi_head_attention(query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
  File "/home/yons/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/BasicBert/MyTransformer.py", line 296, in forward
    return multi_head_attention_forward(query, key, value, self.num_heads,
  File "/home/yons/workfiles/codes/opencodes/BertWithPretrained/Tasks/../model/BasicBert/MyTransformer.py", line 360, in multi_head_attention_forward
    attn_output_weights = attn_output_weights.masked_fill(
RuntimeError: The size of tensor a (367) must match the size of tensor b (184) at non-singleton dimension 3
```
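For context on the shape mismatch: `nn.DataParallel` scatters every input tensor along dim 0 by default, and the call site passes `token_ids` laid out as `[src_len, batch_size]` (see the comment in the traceback). Each replica therefore receives a sequence chopped roughly in half (184 is exactly half of 367, rounded up), while the padding mask, which is batch-first in the usual `MultiheadAttention` convention, gets split along the batch axis and keeps the full sequence length, so the two shapes no longer line up inside `masked_fill`. Below is a minimal sketch of one workaround: feed batch-first tensors through a small adapter so that DataParallel splits along the true batch dimension. `BatchFirstWrapper` and the `attention_mask`/`labels` parameter names are illustrative assumptions, not code from this repository.

```python
import torch
import torch.nn as nn

class BatchFirstWrapper(nn.Module):
    """Hypothetical adapter: exposes a batch-first interface so that
    nn.DataParallel's dim-0 scatter splits the true batch dimension,
    then restores the sequence-first layout the inner model expects."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None, labels=None):
        # input_ids arrives as [batch_size, src_len]; the wrapped model
        # expects [src_len, batch_size].
        return self.model(input_ids=input_ids.transpose(0, 1).contiguous(),
                          attention_mask=attention_mask,
                          labels=labels)

# Illustrative usage at the call site in train():
#   model = nn.DataParallel(BatchFirstWrapper(model))
#   loss, logits = model(input_ids=token_ids.transpose(0, 1),  # [batch_size, src_len]
#                        attention_mask=padding_mask, labels=labels)
```

If transposing at the boundary is too invasive, running on a single GPU (e.g. `CUDA_VISIBLE_DEVICES=0`, or not wrapping the model in DataParallel at all) avoids the scatter entirely.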
https://tokudayo.github.io/multiprocessing-in-python-and-torch/
https://tokudayo.github.io/distributed-communication-in-pytorch/
https://tokudayo.github.io/torch-ddp/
https://zhuanlan.zhihu.com/p/350301395?utm_source=wechat_session&utm_medium=social&utm_oi=783331762284687360
You can refer to these articles for now; I'll add this part of the code when I get a chance.
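In the meantime, here is a minimal single-node DDP skeleton in the spirit of the linked articles, offered as a sketch rather than this repository's eventual implementation; `build_model`, `build_dataset`, and the hyperparameters are placeholders. Note that unlike DataParallel, DDP does not scatter tensors along dim 0 behind your back: each process loads its own batches, so the sequence-first layout above is left untouched.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def run(rank, world_size):
    # One process per GPU; rank is injected by mp.spawn.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = build_model().to(rank)      # placeholder, not from this repo
    model = DDP(model, device_ids=[rank])

    dataset = build_dataset()           # placeholder, not from this repo
    # The sampler partitions the dataset across processes by example,
    # i.e. along the batch dimension, never along the sequence dimension.
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    for epoch in range(3):
        sampler.set_epoch(epoch)        # reshuffle consistently per epoch
        for batch in loader:
            ...                         # forward/backward/step as on one GPU

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```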
Has this issue been resolved? I ran into the same problem when using DP, and trying DDP failed as well.