
BART long-text training issue

Open YoungChanYY opened this issue 2 years ago • 5 comments

I am training with the BART code. Each training example has an input text of about 1,000 characters and an output text of 30,000-50,000 characters. After a few epochs, training crashes with the error below. However, if I limit the input and output lengths, e.g. to around 100 characters each, training runs normally with no error.

My question: does the BART model place any requirements on input/output length? This looks like an out-of-range index in an internal embedding. Thanks.

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)

YoungChanYY avatar Apr 11 '23 12:04 YoungChanYY
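For context (an editorial note, not part of the original thread): BART uses learned positional embeddings, so each sequence is capped at config.max_position_embeddings tokens (1024 for facebook/bart-base and bart-large; some checkpoints use 512). A 30,000-50,000 character target tokenizes to far more than that, and out-of-range token or position indices often surface on GPU as an opaque CUBLAS error like the one above. A hedged sketch of how to check and enforce the limit, assuming a Hugging Face BART checkpoint (the model name and example strings are placeholders):

    # A hedged sketch (not from this thread): check the checkpoint's position
    # limit and truncate both source and target to it before training.
    from transformers import BartForConditionalGeneration, BartTokenizer

    model_name = "facebook/bart-base"  # placeholder checkpoint
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)

    # BART uses learned positional embeddings; sequences longer than this limit
    # produce out-of-range indices (the indexSelectSmallIndex assertion below).
    max_len = model.config.max_position_embeddings  # 1024 for facebook/bart-base

    source = tokenizer(["input text of roughly 1,000 characters"],
                       max_length=max_len, truncation=True,
                       padding="max_length", return_tensors="pt")
    target = tokenizer(["target text of tens of thousands of characters"],
                       max_length=max_len, truncation=True,
                       padding="max_length", return_tensors="pt")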

The error seems to occur in predict. When I disable evaluation during training, training runs normally.

Traceback (most recent call last):
  File "train_bart_text2abc.py", line 180, in <module>
    main()
  File "train_bart_text2abc.py", line 163, in main
    model.train_model(train_df, eval_data=eval_df, split_on_space=True, matches=count_matches)
  File "textgen/seq2seq/bart_seq2seq_model.py", line 452, in train_model
    **kwargs,
  File "textgen/seq2seq/bart_seq2seq_model.py", line 983, in train
    **kwargs,
  File "textgen/seq2seq/bart_seq2seq_model.py", line 1153, in eval_model
    preds = self.predict(to_predict, split_on_space=split_on_space)
  File "textgen/seq2seq/bart_seq2seq_model.py", line 1310, in predict
    num_return_sequences=self.args.num_return_sequences,
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/generation/utils.py", line 1400, in generate
    **model_kwargs,
  File "/usr/local/lib/python3.7/dist-packages/transformers/generation/utils.py", line 2183, in greedy_search
    output_hidden_states=output_hidden_states,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1389, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1268, in forward
    return_dict=return_dict,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 1124, in forward
    use_cache=use_cache,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 431, in forward
    output_attentions=output_attentions,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py", line 275, in forward
    attn_output = torch.bmm(attn_probs, value_states)
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling cublasSgemmStridedBatched( handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches)
../aten/src/ATen/native/cuda/Indexing.cu:650: indexSelectSmallIndex: block: [0,0,0], thread: [0,0,0] Assertion srcIndex < srcSelectDimSize failed.

YoungChanYY avatar Apr 11 '23 13:04 YoungChanYY
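A debugging note (an editorial addition, not from the thread): CUDA kernels launch asynchronously, so the CUBLAS failure is often reported far from the operation that actually went wrong; the indexSelectSmallIndex assertion at the end is the more telling symptom. Forcing synchronous kernel launches, or re-running the failing batch on CPU, usually surfaces the real index-out-of-range error from the embedding lookup:

    # Hedged sketch: make CUDA report errors at the true call site.
    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before any CUDA work starts

Alternatively, moving the model to CPU and repeating the failing predict call typically raises a readable "index out of range in self" error from torch.nn.Embedding instead of the opaque CUBLAS message.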

Let me take a look at the evaluate logic.

shibing624 avatar Apr 11 '23 14:04 shibing624

Thanks a lot.

I found another spot that looks problematic: in the preprocess_data_bart(data) function in textgen/seq2seq/bart_seq2seq_utils.py, the handling of target_ids seems wrong. My suggestion is below; please check whether it is correct. Thanks!

def preprocess_data_bart(data):
    input_text, target_text, tokenizer, args = data
    ......
    target_ids = tokenizer.batch_encode_plus(
        [target_text],
        # max_length=args.max_seq_length,  # original code
        max_length=args.max_length,        # suggested change
        padding="max_length",
        return_tensors="pt",
        truncation=True,
    )

YoungChanYY avatar Apr 12 '23 06:04 YoungChanYY
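For reference, a hedged sketch (illustrative only, not the library's code) of why the two length settings matter: the source is truncated to the encoder-side limit (args.max_seq_length), while the target needs its own decoder-side limit (args.max_length), which is what the suggested change enforces. Pad positions in the labels are commonly masked with -100 so the loss ignores them.

    # Illustrative only; argument names mirror the textgen args discussed above.
    def encode_example(input_text, target_text, tokenizer, max_seq_length, max_length):
        source = tokenizer.batch_encode_plus(
            [input_text],
            max_length=max_seq_length,   # encoder-side limit
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        target = tokenizer.batch_encode_plus(
            [target_text],
            max_length=max_length,       # decoder-side limit (the suggested fix)
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        labels = target["input_ids"].clone()
        labels[labels == tokenizer.pad_token_id] = -100  # ignore padding in the loss
        return source["input_ids"], source["attention_mask"], labels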

You're right, fixed: https://github.com/shibing624/textgen/commit/7a0be5931234262165d32fc0f915af822e0b1665

shibing624 avatar Apr 12 '23 10:04 shibing624

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. (Closed automatically by the bot due to inactivity; feel free to ask again if needed.)

stale[bot] avatar Dec 27 '23 07:12 stale[bot]