Training on GTX2080
Process 2 terminated with the following error:
Traceback (most recent call last):
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
fn(i, *args)
File "/data3/liuhaogeng/test/vits-main/train.py", line 120, in run
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, [train_loader, None], None, None)
File "/data3/liuhaogeng/test/vits-main/train.py", line 138, in train_and_evaluate
for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths) in enumerate(train_loader):
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 363, in next
data = self._next_data()
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 989, in _next_data
return self._process_data(data)
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1014, in _process_data
data.reraise()
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/_utils.py", line 395, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 2.
Original Traceback (most recent call last):
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
data = fetcher.fetch(index)
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/data3/liuhaogeng/anaconda3/envs/vits/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in
I set the batch size to 16
I met the same error. This error is caused because that function TextAudioLoader.get_audio saved empty spectrograms file(.spec.pt). (Maybe caused by multi-processing?) Torch will throw this error when load empty file. You can compute and save spectrograms before your training. The following is spectrogram compute code separated from vits. Change base to your wavs's folder path.
from scipy.io.wavfile import read
import torch
import numpy as np
import os
from multiprocessing import Pool
from tqdm import tqdm
# Change here
base="Your wavs's folder path"
hann_window = {}
def load_wav_to_torch(full_path):
sampling_rate, data = read(full_path)
# data, sampling_rate = librosa.load(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.:
print('min value is ', torch.min(y))
if torch.max(y) > 1.:
print('max value is ', torch.max(y))
global hann_window
dtype_device = str(y.dtype) + '_' + str(y.device)
wnsize_dtype_device = str(win_size) + '_' + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
y = y.squeeze(1)
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
center=center, pad_mode='reflect', normalized=False, onesided=True)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
return spec
def get_audio(filename):
max_wave_length = 32768.0
filter_length = 1024
hop_length = 256
win_length = 1024
audio, sampling_rate = load_wav_to_torch(filename)
audio_norm = audio / max_wave_length
audio_norm = audio_norm.unsqueeze(0)
spec_filename = filename.replace(".wav", ".spec.pt")
spec = spectrogram_torch(audio_norm, filter_length,
sampling_rate, hop_length, win_length,
center=False)
spec = torch.squeeze(spec, 0)
torch.save(spec, spec_filename)
if __name__=="__main__":
waves = []
for wav_name in os.listdir(base):
wav_path = os.path.join(base, wav_name)
if wav_path.endswith(".wav"):
waves.append(wav_path)
with Pool(16) as p:
print(list((tqdm(p.imap(get_audio,waves),total=len(waves)))))
@lpdink, Thanks, I had this error too, and your advice helped me a lot!