skorch
Mixed precision training
Adding support for mixed precision training on Volta GPUs would be nice to have. There are several tricks to get this to work, which are outlined here: https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
I haven't tried mixed precision training yet. Is there anything in skorch that prevents it? AFAIR, we do not convert dtypes anywhere.
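A quick, untested check along these lines should show whether skorch itself gets in the way (the `TinyModule` and the random data are placeholders, and it assumes a CUDA device, since many half-precision ops are not implemented on CPU):

```python
import numpy as np
import torch
from torch import nn
from skorch import NeuralNetClassifier


class TinyModule(nn.Module):
    """Placeholder classifier used only for the dtype check."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)

    def forward(self, X):
        # NeuralNetClassifier's default criterion is NLLLoss, so return log-probs.
        return torch.log_softmax(self.dense(X), dim=-1)


net = NeuralNetClassifier(TinyModule, max_epochs=2, lr=0.01, device='cuda')
net.initialize()
net.module_.half()  # cast the parameters to FP16 by hand

# FP16 features; skorch passes dtypes through as-is.
X = np.random.randn(256, 20).astype(np.float16)
y = np.random.randint(0, 2, size=256).astype(np.int64)

# partial_fit instead of fit, so the module is not re-initialized
# (re-initializing would undo the .half() cast).
net.partial_fit(X, y)
```

If that runs without dtype errors, nothing in skorch blocks FP16 end to end.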
I think most of the tricks can be implemented with a callback. If I had a Volta GPU to take advantage of mixed precision, I would take on this issue.
For me, the first thing to check would be whether mixed precision training works at all or whether some parts of skorch prevent it. But I don't have access to a Volta card to test this either.
Regarding the tricks, it seems to me that they should be handled by the user and not by skorch. We would need to verify that skorch does not prevent the user from applying them, though.
I think with the addition of on_grad_computed, this feature can be implemented with a skorch callback. I can put together a proof of concept of this feature. It will work on a non-Volta card, but will provide little speed benefit there.
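Roughly, I'm thinking of something along these lines: just a sketch of static loss scaling (scale the loss up before backward, then unscale the gradients in on_grad_computed before the optimizer step). The class names and the loss_scale of 128 are made up, and dynamic loss scaling / FP32 master weights are left out:

```python
from skorch import NeuralNetClassifier
from skorch.callbacks import Callback


class UnscaleGradients(Callback):
    """Undo the loss scaling on the gradients before the optimizer step."""
    def __init__(self, loss_scale=128.0):
        self.loss_scale = loss_scale

    def on_grad_computed(self, net, named_parameters, **kwargs):
        for _, param in named_parameters:
            if param.grad is not None:
                param.grad.div_(self.loss_scale)


class ScaledLossNet(NeuralNetClassifier):
    """Scale the loss up so that small FP16 gradients do not underflow."""
    loss_scale = 128.0

    def get_loss(self, y_pred, y_true, *args, **kwargs):
        loss = super().get_loss(y_pred, y_true, *args, **kwargs)
        # Note: this also scales the loss that gets recorded in the history.
        return loss * self.loss_scale


# MyModule stands for whatever PyTorch module you want to train in FP16:
# net = ScaledLossNet(MyModule, callbacks=[UnscaleGradients(loss_scale=128.0)],
#                     device='cuda')
```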
The main selling point of mixed precision is that it allows roughly double the batch size, lowering training time without increasing the loss.
If you have a PoC I can test it on a V100 AWS instance and report back / fix obvious things :)