
TSTNN network training is very slow

Open Fangbo0506 opened this issue 3 years ago • 4 comments

I am trying to reproduce the TSTNN network. The dataset is 50 hours, and each epoch runs 12,000 iterations. Training is very slow: one epoch takes about 8 hours. Was it the same when you trained?

Fangbo0506 avatar Sep 22 '21 07:09 Fangbo0506

Hi, thank you for reaching out. In my implementation, I set the batch size to 2, and each epoch runs about 5100 steps. Each epoch takes about 50 minutes on two GTX 1080 Ti GPUs. May I ask whether you use the Voice_Bank dataset with 28 speakers, and how many GPUs you used for training?

key2miao avatar Sep 22 '21 18:09 key2miao

I also use GTX 1080 Ti GPUs. The dataset is a noisy dataset I synthesized from the DNS Challenge, not the Voice_Bank dataset with 28 speakers. I haven't changed your network, and I haven't found the reason yet.

Also, there is a mistake in metric.py; the PESQ call should be:

    def get_pesq(ref, deg, sr):
        score = pesq(sr, ref, deg, 'wb')
        return score

And the loss function is not the same as you describe in your paper: loss = 0.4 * loss_time + 0.6 * loss_freq. Maybe you uploaded the wrong file version? I just want to reproduce this paper.

Fangbo0506 avatar Sep 24 '21 02:09 Fangbo0506

I found a few lines of code in train.py that take a lot of time, but we don't really need this data during training:

    if (index + 1) % eval_steps == 0:
        ave_train_loss = total_train_loss / count

        # validation
        avg_eval_loss = validate(model, validation_loader)
        model.train()

        print('Epoch [%d/%d], Iter [%d/%d],  ( TrainLoss: %.4f | EvalLoss: %.4f )' % (
            epoch + 1, max_epochs, index + 1, len(train_loader), ave_train_loss, avg_eval_loss))

        count = 0
        total_train_loss = 0.0

If I delete this code, each epoch takes about 50 minutes.
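The slowdown comes from running a full validation pass every `eval_steps` iterations inside the training loop. A minimal sketch (not the repository's exact code; `train_step` and `validate` are stand-in callables) of keeping the periodic loss logging but calling validation only once per epoch:

```python
# Sketch: log running training loss every `log_steps` iterations,
# but run the expensive validation pass only once per epoch.
def train_one_epoch(train_loader, train_step, validate, log_steps=500):
    total_train_loss, count = 0.0, 0
    for index, batch in enumerate(train_loader):
        total_train_loss += train_step(batch)
        count += 1
        if (index + 1) % log_steps == 0:
            print('Iter [%d/%d], TrainLoss: %.4f'
                  % (index + 1, len(train_loader), total_train_loss / count))
            total_train_loss, count = 0.0, 0
    # single validation pass per epoch, outside the inner loop
    return validate()
```

This keeps the progress printout while removing the repeated mid-epoch sweeps over the validation set.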

Fangbo0506 avatar Sep 24 '21 07:09 Fangbo0506


Please change 'loss = 0.4 * loss_time + 0.6 * loss_freq' to match the paper. The uploaded weights are left over from later experiments I ran with different weightings between the time loss and the frequency loss. Sorry for the confusion.
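Since the 0.4/0.6 split was experimental, one way to avoid this confusion is to expose the weight as a hyperparameter. A hedged sketch (the names `loss_time`, `loss_freq`, and `alpha` are illustrative; substitute the weighting reported in the paper):

```python
def combined_loss(loss_time, loss_freq, alpha):
    """Weighted sum of the time-domain and frequency-domain losses.

    `alpha` is the weight on the time-domain term; the frequency-domain
    term gets (1 - alpha), so the uploaded code corresponds to alpha=0.4.
    """
    return alpha * loss_time + (1.0 - alpha) * loss_freq
```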

key2miao avatar Jan 01 '22 22:01 key2miao