attention-is-all-you-need-pytorch
How to speed up the training process
I ran this version of the transformer to train my NMT system on four GPUs. The data is about 3 million Chinese-Japanese parallel sentence pairs, but training is very slow: about 14 hours per epoch on four K80s. I added code like nn.DataParallel() to spread the data across the GPUs, but cuda:0 easily runs out of memory when batch_size is larger than 80. How can I speed up training or make the batch size larger? Can anyone help?
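One way to get a larger effective batch without more GPU memory is gradient accumulation: run several small forward/backward passes and step the optimizer once. Below is a minimal sketch of the pattern; the tiny model, data, and the `micro_batch`/`accum_steps` names are illustrative stand-ins, not code from this repository.

```python
import torch
import torch.nn as nn

# Stand-in model and optimizer; in practice this would be the transformer.
model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

micro_batch, accum_steps = 8, 4  # effective batch = 8 * 4 = 32

optimizer.zero_grad()
for step in range(accum_steps):
    x = torch.randn(micro_batch, 16)           # one micro-batch of inputs
    y = torch.randint(0, 4, (micro_batch,))    # matching labels
    loss = criterion(model(x), y) / accum_steps  # scale so grads average
    loss.backward()                            # gradients accumulate in .grad
optimizer.step()                               # one update per effective batch
optimizer.zero_grad()
```

Only one micro-batch of activations lives in memory at a time, so this avoids the out-of-memory error while matching the gradient statistics of the larger batch (up to batch-norm-style layers, which the transformer does not use).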
Hi,
I also want to train the model in parallel. How do I implement nn.DataParallel()? Once I set transformer = nn.DataParallel(transformer).to(device) in train.py, the training process stayed at 0% and could not be interrupted.
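For reference, here is a minimal sketch of the usual nn.DataParallel wrapping order (note the capital "P" in DataParallel). The Linear model is a stand-in; `transformer` in this repo would be wrapped the same way. Also worth knowing: DataParallel gathers all outputs back onto cuda:0, which is why that device runs out of memory first; torch.nn.parallel.DistributedDataParallel balances memory and is generally faster.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(16, 4).to(device)  # move parameters to the primary device first
model = nn.DataParallel(model)       # then wrap; the batch is split across GPUs
                                     # (on a CPU-only machine this is a no-op)

x = torch.randn(32, 16).to(device)   # the input batch also goes to the primary device
out = model(x)                       # scatter -> parallel forward -> gather on cuda:0
print(out.shape)                     # torch.Size([32, 4])
```

If training hangs at 0%, a common cause is one worker waiting on a GPU that is busy or out of memory; checking `nvidia-smi` while it hangs usually shows which device is stuck.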