使用自定义数据集微调 internlm2_5_7b_chat 时,注意力 shape 报错
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/xmyu/anaconda3/envs/xtuner-env/lib/python3.10/site-packages/xtuner/tools/train.py", line 360, in
打印了关键张量的 shape 后发现,在第一轮 sample(batch_size=32)之后继续训练时,seq_length 变为 1,导致上述 attention 计算异常。是否是版本不匹配?torch=2.5.1,transformers=4.49.0
02/27 03:06:07 - mmengine - INFO - before_train in EvaluateChatHook. hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape 
torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) 
shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 40, 4096]) torch.Size([1, 1, 40, 40]) shape torch.Size([1, 1, 40, 40]) torch.Size([1, 32, 40, 40]) torch.Size([1, 1, 40, 40]) hidden s, am torch.Size([1, 1, 4096]) torch.Size([1, 1, 1, 40]) shape torch.Size([1, 1, 1, 40]) torch.Size([1, 32, 1, 41]) torch.Size([1, 1, 1, 40])