
Switch to half precision with torch 1.6

Open Optimox opened this issue 5 years ago • 1 comments

Feature request

It's now possible to use mixed precision training (https://arxiv.org/pdf/1710.03740.pdf) directly through PyTorch. Since tabular values rarely change at 1e-10 scales and TabNet is not very deep, this would probably not hurt model performance, and it could make training faster, potentially by an order of magnitude.

What is the expected behavior?

As some GPUs don't support mixed precision, we should probably add a parameter to fit so that users can choose whether to use FP32 or FP16.

What is the motivation or use case for adding/changing the behavior? Let's speed things up!

How should this be implemented in your opinion? It seems quite simple once we upgrade to torch 1.6, and the documentation looks straightforward: https://pytorch.org/docs/master/notes/amp_examples.html
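
For reference, here is a minimal sketch of what the opt-in could look like, following the pattern in the AMP notes linked above (`torch.cuda.amp.autocast` plus `torch.cuda.amp.GradScaler`). This is not tabnet code: the `fit` function, the `use_amp` parameter name and the small MLP standing in for the network are all assumptions made for illustration. AMP is silently disabled when the model is not on a CUDA device, so the same loop runs in plain FP32 on CPU.

```python
import torch
from torch import nn


def fit(model, batches, optimizer, loss_fn, epochs=1, use_amp=True):
    """Train `model`, optionally under automatic mixed precision.

    `use_amp` is a hypothetical parameter name, not an existing tabnet option.
    """
    device = next(model.parameters()).device
    # torch.cuda.amp only applies on CUDA; otherwise we stay in FP32.
    amp_enabled = use_amp and device.type == "cuda"
    # With enabled=False the scaler is a pass-through around optimizer.step().
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    history = []
    model.train()
    for _ in range(epochs):
        running, seen = 0.0, 0
        for X, y in batches:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            # Forward pass and loss run in mixed precision when enabled.
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                loss = loss_fn(model(X), y)
            # Scale the loss to limit FP16 gradient underflow, then step.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running += loss.item() * X.size(0)
            seen += X.size(0)
        history.append(running / seen)  # mean loss per sample for the epoch
    return history


# Synthetic data and a toy model, only to exercise the loop.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
batches = [(torch.randn(32, 8), torch.randn(32, 1)) for _ in range(4)]
history = fit(model, batches, optimizer, nn.MSELoss(), epochs=3, use_amp=True)
print(history)
```

Keeping the switch as a single boolean means users on GPUs without useful FP16 support can pass `use_amp=False` and get the current FP32 behavior unchanged.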

Are you willing to work on this yourself? I'll have a look and run some tests in the coming days, but will probably wait for torch 1.6 to be out.

Optimox avatar Jun 24 '20 13:06 Optimox

Tried something in #350, but it's still WIP and shows no improvement on my GPU.

Optimox avatar Dec 27 '21 16:12 Optimox