
Mixed precision training

Open thomasjpfan opened this issue 6 years ago • 5 comments

Adding support for mixed precision training on Volta GPUs would be nice to have. There are several tricks to get this to work, which are outlined here: https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
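One of the central tricks in that guide is loss scaling: small gradient values that are representable in float32 underflow to zero in float16, so the loss is multiplied by a large constant before backprop and the gradients are divided back in full precision. A minimal sketch of the underlying numerics using NumPy's `float16` (illustrative only, not skorch or PyTorch code):

```python
import numpy as np

# A small gradient value that float32 represents fine.
grad_fp32 = np.float32(1e-8)

# Cast directly to float16: the value underflows to zero, because
# float16's smallest subnormal is about 6e-8.
print(np.float16(grad_fp32))  # 0.0

# Loss scaling: multiply by a large constant before the cast,
# then divide it back out in float32.
scale = 1024.0
scaled_fp16 = np.float16(grad_fp32 * scale)
recovered = np.float32(scaled_fp16) / scale
print(recovered)  # ~1e-8, the gradient survives the round trip
```

In real mixed precision training the scaling is applied to the loss so that backprop produces already-scaled gradients; the division happens on the float32 master weights' gradients.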

thomasjpfan avatar Jan 28 '19 22:01 thomasjpfan

I haven't tried mixed precision training yet. Is there anything in skorch that prevents mixed precision training? AFAIR, we do not convert dtypes anywhere.

BenjaminBossan avatar Jan 29 '19 11:01 BenjaminBossan

I think most of the tricks can be implemented with a callback. If I had a Volta GPU to take advantage of mixed precision, I would take on this issue.

thomasjpfan avatar Jan 29 '19 19:01 thomasjpfan

For me the first thing to test would be whether mixed precision training works at all or whether we have some parts in skorch that prevent that. But I don't have access to a Volta card either to test that.

Regarding the tricks, it seems to me that they should be handled by the user rather than by skorch. We would need to verify that skorch does not prevent the user from applying them, though.

BenjaminBossan avatar Jan 29 '19 20:01 BenjaminBossan

I think with the addition of on_grad_computed, this feature can be implemented with a skorch callback. I can put together a proof of concept. It will run on a non-Volta card, but with little speed benefit there.

The main selling point of mixed precision is that it allows roughly double the batch size, lowering training time without increasing the loss.
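The dynamic loss-scaling scheme from the NVIDIA guide can be sketched independently of any framework: scale the loss up before backprop, skip the optimizer step when the scaled gradients overflow (halving the scale), and periodically try growing the scale again. The class and method names below are hypothetical illustrations, not skorch API; a real callback would hook this logic into on_grad_computed.

```python
import math

class DynamicLossScaler:
    """Sketch of dynamic loss scaling (illustrative, not skorch API)."""

    def __init__(self, init_scale=2.0 ** 15, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss):
        # Backprop on the scaled loss yields scaled gradients.
        return loss * self.scale

    def step(self, grads):
        # Overflow in the scaled gradients: halve the scale and
        # skip this update entirely.
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            self.scale /= 2.0
            self._good_steps = 0
            return None
        # No overflow: unscale the gradients in full precision.
        unscaled = [g / self.scale for g in grads]
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            # Stable for a while: try a larger scale again.
            self.scale *= 2.0
            self._good_steps = 0
        return unscaled

scaler = DynamicLossScaler(init_scale=8.0)
print(scaler.step([8.0, 16.0]))     # [1.0, 2.0] -- update proceeds
print(scaler.step([float("inf")]))  # None -- update skipped
print(scaler.scale)                 # 4.0 -- scale halved
```

Skipping the occasional overflowed step is cheap compared to the throughput gained from the larger batches.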

thomasjpfan avatar Jan 29 '19 21:01 thomasjpfan

If you have a PoC I can test it on a V100 AWS instance and report back / fix obvious things :)

ottonemo avatar Mar 23 '19 13:03 ottonemo