TSTNN
TSTNN network training is very slow
I am trying to reproduce the TSTNN network. My dataset is 50 hours long, and each epoch runs 12,000 iterations. Training is very slow: one epoch takes about 8 hours. Is it the same when you train?
Hi, thank you for reaching out. In my implementation, I set the batch size to 2, and each epoch runs about 5,100 steps. One epoch takes about 50 minutes on two GTX 1080 Ti GPUs. Are you using the Voice_Bank dataset with 28 speakers, and how many GPUs did you use for training?
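Comparing time per step instead of time per epoch shows whether the gap comes only from the extra iterations or also from slower iterations. Below is a quick check using only the numbers quoted above; `seconds_per_step` is a hypothetical helper, and the comparison assumes comparable batch sizes and utterance lengths in both setups, which the thread does not confirm.

```python
def seconds_per_step(epoch_minutes: float, steps_per_epoch: int) -> float:
    """Convert wall-clock minutes per epoch into average seconds per training step."""
    return epoch_minutes * 60.0 / steps_per_epoch

reporter = seconds_per_step(8 * 60, 12000)  # DNS-based setup: about 8 h for 12,000 steps
author = seconds_per_step(50, 5100)         # Voice_Bank setup: about 50 min for about 5,100 steps
print(f"{reporter:.2f} s/step vs {author:.2f} s/step")
```

This gives about 2.4 s per step versus about 0.59 s per step, so under these assumptions the 12,000-step epoch is roughly four times slower per step, not just longer.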
I also use GTX 1080 Ti GPUs. The dataset is a noisy dataset I synthesized from the DNS Challenge data, not the Voice_Bank dataset with 28 speakers. I haven't changed your network, and I haven't found the reason yet. Also, there is a mistake in metric.py:

```python
def get_pesq(ref, deg, sr):
    score = pesq(sr, ref, deg, 'wb')
    return score
```

The loss function is also not the same as the one described in your paper: the code uses `loss = 0.4 * loss_time + 0.6 * loss_freq`. Maybe you uploaded the wrong version of the file? I just want to reproduce the paper.
I found that a few lines in train.py take a lot of time, even though we don't really need the output they produce:

```python
if (index + 1) % eval_steps == 0:
    ave_train_loss = total_train_loss / count

    # validation
    avg_eval_loss = validate(model, validation_loader)
    model.train()
    print('Epoch [%d/%d], Iter [%d/%d], ( TrainLoss: %.4f | EvalLoss: %.4f )' % (
        epoch + 1, max_epochs, index + 1, len(train_loader), ave_train_loss, avg_eval_loss))
    count = 0
    total_train_loss = 0.0
```
If I delete this code, each epoch takes about 50 minutes.
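The cost of that block scales with how often it fires: every `eval_steps` iterations it runs one full pass over the validation set, so an epoch triggers `steps_per_epoch // eval_steps` validation passes. The plain-Python sketch below mirrors the loop's logging logic to make that count visible. `run_epoch` and `fake_validate` are hypothetical stand-ins (the real loop calls `validate(model, validation_loader)`), the placement of `count += 1` is an assumption since the thread doesn't show it, and the actual `eval_steps` value in train.py is not given here.

```python
from typing import Callable, Iterable, List, Tuple

def run_epoch(train_losses: Iterable[float], eval_steps: int,
              validate: Callable[[], float]) -> List[Tuple[int, float, float]]:
    """Hypothetical mirror of the train.py logging logic.

    Every eval_steps iterations: average the accumulated training loss,
    run one validation pass, record the result, and reset the accumulators.
    Returns (iteration, avg_train_loss, avg_eval_loss) tuples.
    """
    logs = []
    total_train_loss, count = 0.0, 0
    for index, loss in enumerate(train_losses):
        total_train_loss += loss
        count += 1
        if (index + 1) % eval_steps == 0:
            ave_train_loss = total_train_loss / count
            avg_eval_loss = validate()  # stands in for a full pass over the validation set
            logs.append((index + 1, ave_train_loss, avg_eval_loss))
            total_train_loss, count = 0.0, 0
    return logs

# Count validation passes in a 12,000-step epoch for a few eval_steps values.
calls = []
def fake_validate() -> float:
    calls.append(1)
    return 0.0

for eval_steps in (100, 1000, 6000):
    calls.clear()
    run_epoch([1.0] * 12000, eval_steps, fake_validate)
    print(eval_steps, len(calls))
```

With 12,000 steps this prints 120, 12 and 2 validation passes respectively, so raising `eval_steps` keeps the periodic log while cutting most of the validation time.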
Please change `loss = 0.4 * loss_time + 0.6 * loss_freq` to match the weights given in the paper. I experimented with different weightings between the time loss and the frequency loss afterwards, which is why the code differs. Sorry for the confusion.
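One way to avoid this kind of mismatch is to expose the weighting as a single parameter instead of two hard-coded constants. The sketch below uses a hypothetical `combined_loss` helper and assumes, as in the repository line above, that the total loss is a convex combination of the time loss and the frequency loss. Set `alpha` to the weight stated in the paper; `alpha=0.4` reproduces the current repository line. Because it only uses `*` and `+`, it should work the same on Python floats or PyTorch scalar tensors.

```python
def combined_loss(loss_time, loss_freq, alpha: float):
    """Hypothetical helper: weighted sum of a time-domain and a frequency-domain loss.

    alpha weights the time loss; the frequency loss gets weight 1 - alpha.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    return alpha * loss_time + (1.0 - alpha) * loss_freq

# alpha = 0.4 reproduces `loss = 0.4 * loss_time + 0.6 * loss_freq`.
print(combined_loss(1.0, 2.0, alpha=0.4))
```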