char-rnn.pytorch
Ak/fix train feeding
- Fix index out of range error.
- Speed up training by about 8x (on my setup) by feeding the whole sample sequence through cuDNN at once instead of char-by-char.
- Inference: GPU memory use no longer grows with the generated sequence length, because Variable history is dropped during inference.
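For readers who want to see the two ideas in isolation, here is a minimal sketch. It is not the code from this PR: the module names (`embed`, `gru`, `decoder`), the sizes, and the `generate` helper are assumptions made for illustration, and `torch.no_grad()` stands in for whatever mechanism the patch uses to drop history. It shows that feeding one character per call and feeding the whole sequence in a single call should produce the same logits, and that generation under `no_grad` keeps no autograd graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only; not the repository's defaults.
vocab_size, hidden_size, seq_len, batch_size = 50, 32, 20, 4

# Hypothetical stand-ins for the model's layers.
embed = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size, num_layers=1)  # input: (seq, batch, feature)
decoder = nn.Linear(hidden_size, vocab_size)

chars = torch.randint(0, vocab_size, (seq_len, batch_size))
h0 = torch.zeros(1, batch_size, hidden_size)


def forward_char_by_char(chars, hidden):
    """One RNN call per time step (the slower pattern on GPU)."""
    outputs = []
    for t in range(chars.size(0)):
        emb = embed(chars[t]).unsqueeze(0)  # (1, batch, hidden)
        out, hidden = gru(emb, hidden)
        outputs.append(decoder(out.squeeze(0)))
    return torch.stack(outputs), hidden


def forward_whole_sequence(chars, hidden):
    """A single RNN call over the full sequence."""
    emb = embed(chars)  # (seq, batch, hidden)
    out, hidden = gru(emb, hidden)
    return decoder(out), hidden


logits_step, _ = forward_char_by_char(chars, h0)
logits_seq, _ = forward_whole_sequence(chars, h0)
max_diff = (logits_step - logits_seq).abs().max().item()
print(f"max abs difference between the two paths: {max_diff:.2e}")


def generate(start_char, length):
    """Sample characters without recording autograd history."""
    hidden = torch.zeros(1, 1, hidden_size)
    inp = torch.tensor([start_char])
    out_chars = [start_char]
    with torch.no_grad():  # no graph is built, so memory stays flat
        for _ in range(length):
            emb = embed(inp).unsqueeze(0)  # (1, 1, hidden)
            out, hidden = gru(emb, hidden)
            probs = torch.softmax(decoder(out.squeeze(0)), dim=-1)
            inp = torch.multinomial(probs, 1).squeeze(1)
            out_chars.append(inp.item())
    return out_chars, hidden
```

Both forward paths compute the same recurrence, so the difference should only be floating-point noise; the single-call version simply lets the backend process all time steps in one kernel launch sequence.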
For what it's worth, with the pre-built PyTorch 1.3 on macOS (where CUDA is not available), on an early-2015 MacBook Pro, this patch doesn't seem to improve performance on CPU. Instead, it makes training more than twice as slow. [Edited to add: Tested on CPU on Linux as well, with similar results.]
After 100 iterations using the default hyperparameters:
- master: ~1.09s/iteration
- fix-train-feeding: ~3.71s/iteration
Not sure whether this is an issue with this specific CPU, an inefficiency in PyTorch itself, or something that could be further improved in the char-rnn code, but it's definitely worth pointing out either way.
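To help narrow this down, the same timing harness could be run on both branches and on different machines. The following is only a sketch using the standard library: `seconds_per_iteration` and `dummy_step` are hypothetical names, and `dummy_step` is a placeholder that would need to be replaced by one real training iteration from each branch.

```python
import statistics
import time


def seconds_per_iteration(step_fn, n_iters=100, warmup=5):
    """Return (mean, median) wall-clock seconds per call of step_fn."""
    # Warm-up calls are excluded so one-time setup cost does not skew results.
    for _ in range(warmup):
        step_fn()
    timings = []
    for _ in range(n_iters):
        start = time.perf_counter()
        step_fn()
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings), statistics.median(timings)


def dummy_step():
    """Placeholder workload; swap in one training iteration to benchmark."""
    sum(i * i for i in range(10_000))


mean_s, median_s = seconds_per_iteration(dummy_step, n_iters=20, warmup=2)
print(f"{mean_s:.6f}s/iteration (median {median_s:.6f}s)")
```

Reporting the median alongside the mean makes it easier to tell a consistent slowdown from a few slow outlier iterations.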