tabnet
Switch to half precision with torch 1.6
Feature request
It's now possible to use mixed precision training (https://arxiv.org/pdf/1710.03740.pdf) directly through PyTorch. Tabular values rarely change at 1e-10 scales and tabnet is not very deep, so this would likely not hurt performance. It could also make training faster, potentially by an order of magnitude.
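The precision argument is easier to judge with the actual numbers. The short NumPy check below shows how coarse float16 is. Its machine epsilon is 2**-10, so a relative change of 1e-4 on a value of 1.0 is lost. Values around 1e-10 underflow to zero. This underflow of small gradients is why the linked paper keeps FP32 master weights and scales the loss. The snippet only illustrates float16 resolution and is not tabnet code.

```python
import numpy as np

# float16 machine epsilon: the gap between 1.0 and the next representable value.
eps16 = np.finfo(np.float16).eps

# Adding 1e-4 to 1.0 is below half the float16 spacing, so the sum rounds back to 1.0 ...
lost_in_fp16 = bool(np.float16(1.0) + np.float16(1e-4) == np.float16(1.0))
# ... while float32 still resolves the difference.
lost_in_fp32 = bool(np.float32(1.0) + np.float32(1e-4) == np.float32(1.0))

# 1e-10 is far below the smallest float16 subnormal (2**-24, about 6e-8), so it becomes 0.
tiny_fp16 = float(np.float16(1e-10))

print(lost_in_fp16)  # True
print(lost_in_fp32)  # False
print(tiny_fp16)     # 0.0
```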
What is the expected behavior?
Since some GPUs don't support mixed precision, we should probably add a parameter to `fit` so that users can choose between FP32 and FP16.
What is the motivation or use case for adding/changing the behavior? Let's speed things up!
How should this be implemented in your opinion? It seems quite simple once we upgrade to torch 1.6; the documentation looks straightforward: https://pytorch.org/docs/master/notes/amp_examples.html
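Below is a minimal sketch of the linked AMP recipe inside a training loop. It uses a toy `nn.Sequential` regressor as a stand-in for TabNet. `use_amp` is a hypothetical flag standing in for the proposed FP32/FP16 parameter of `fit`. Here it is simply switched on when CUDA is available, so the same code falls back to plain FP32 on CPU. `autocast` and `GradScaler` are the `torch.cuda.amp` utilities that ship with torch 1.6.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical flag mirroring the proposed fit() option; AMP only on CUDA devices.
use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")

# Toy tabular regressor (not TabNet itself).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

# GradScaler multiplies the loss before backward so small FP16 gradients don't
# underflow; with enabled=False it passes everything through unchanged.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(x, y):
    optimizer.zero_grad()
    # autocast runs eligible ops (e.g. Linear) in FP16 and keeps the rest in FP32.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()  # backward pass on the (possibly) scaled loss
    scaler.step(optimizer)         # unscales grads; skips the step if inf/NaN appear
    scaler.update()                # adjusts the scale factor for the next iteration
    return loss.item()


x = torch.randn(256, 16, device=device)
y = x.sum(dim=1, keepdim=True)
losses = [train_step(x, y) for _ in range(50)]
print(losses[0], losses[-1])
```

With `enabled=False`, both `autocast` and `GradScaler` become pass-throughs. A single boolean is therefore enough to switch the same loop between FP32 and mixed precision. Whether this actually speeds things up depends on the GPU.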
Are you willing to work on this yourself? I'll have a look and run some tests in the coming days, but will probably wait for torch 1.6 to be released.
I tried something in #350, but it's still a WIP and shows no improvement on my GPU.