SpeechGPT
stage 2: dimension mismatch
Hello,
I get the following error when running cm_sft.py (run here as cm_sft_modified.py):
0: File "SpeechGPT/speechgpt/src/train/cm_sft_modified.py", line 460, in <module>
0: train()
0: File "SpeechGPT/speechgpt/src/train/cm_sft_modified.py", line 428, in train
0: train_result = trainer.train(resume_from_checkpoint=checkpoint)
0: File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1539, in train
0: return inner_training_loop(
0: File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1836, in _inner_training_loop
0: for step, inputs in enumerate(epoch_iterator):
0: File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 560, in __iter__
0: next_batch, next_batch_info = self._fetch_batches(main_iterator)
0: File "/usr/local/lib/python3.10/dist-packages/accelerate/data_loader.py", line 523, in _fetch_batches
0: batches.append(next(iterator))
0: File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
0: data = self._next_data()
0: File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py", line 674, in _next_data
0: data = self._dataset_fetcher.fetch(index) # may raise StopIteration
0: File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
0: data.append(next(self.dataset_iter))
0: File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 1384, in __iter__
0: for key, example in ex_iterable:
0: File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 679, in __iter__
0: yield from self._iter()
0: File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 694, in _iter
0: for key, example in iterator:
0: File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 679, in __iter__
0: yield from self._iter()
0: File "/usr/local/lib/python3.10/dist-packages/datasets/iterable_dataset.py", line 731, in _iter
0: raise ValueError(
0: ValueError: Column lengths mismatch: columns ['input_ids', 'attention_mask'] have length [512, 512] while prefix has length 1000.
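The error comes from the streaming (`IterableDataset`) data pipeline. A batched `map` returned 512 rows of `input_ids` and `attention_mask`, but the untouched input column `prefix` still has the 1000 rows of the incoming batch (1000 is the default `batch_size` of `map`). This assumes the script tokenizes with a batched `map` that changes the row count, for example by packing or grouping tokens. The pure-Python sketch below only illustrates that invariant and is not the `datasets` source. `merge_batch` and `pack_tokens` are hypothetical helpers. In `datasets` itself, the usual remedy is to list the original columns (here at least `prefix`) in the `remove_columns` argument of `map`.

```python
# Hypothetical re-creation of the row-count check, NOT the datasets source code.


def merge_batch(batch, processed, remove_columns=()):
    """Combine an input batch with a batched map's output.

    Input columns not listed in remove_columns are carried over unchanged,
    so every column in the result must have the same number of rows.
    """
    merged = {k: v for k, v in batch.items() if k not in remove_columns}
    merged.update(processed)
    lengths = {k: len(v) for k, v in merged.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Column lengths mismatch: {lengths}")
    return merged


def pack_tokens(batch, block_size=4):
    """Hypothetical batched function that packs whitespace tokens into
    fixed-size blocks, so the output row count differs from the input's."""
    tokens = [tok for text in batch["prefix"] for tok in text.split()]
    blocks = [
        tokens[i:i + block_size]
        for i in range(0, len(tokens) - block_size + 1, block_size)
    ]
    return {
        "input_ids": blocks,
        "attention_mask": [[1] * len(b) for b in blocks],
    }


batch = {"prefix": ["a b c", "d e f", "g h i", "j k l"]}  # 4 input rows
out = pack_tokens(batch)  # 12 tokens -> 3 blocks of 4

try:
    merge_batch(batch, out)  # "prefix" (4 rows) vs 3 packed rows
except ValueError as err:
    print(err)

merged = merge_batch(batch, out, remove_columns=["prefix"])
print(len(merged["input_ids"]))  # 3 rows, no mismatch
```

Dropping the original text columns is the usual choice once they have been turned into token IDs, because the packed rows no longer correspond one-to-one with the raw examples.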