tabnet
feat: add attentive embeddings
Any change needs to be discussed before proceeding. Failure to do so may result in the rejection of the pull request.
What kind of change does this PR introduce?
This PR aims to improve attention when embeddings are used. As the embedding size grows, the attention module becomes less effective, because it does not know which columns come from the same embedding.
To solve this problem, this PR adds a mask post-processing step that (for now) takes the max of the attention given to any column of an embedding (the mean could be tried as well).
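As a rough illustration only (not the exact implementation in this PR), the per-embedding max pooling could look like the sketch below. The names `pool_mask_per_embedding` and `group_index` are hypothetical; `group_index` is assumed to map each post-embedding column back to the original feature it came from.

```python
import torch
from torch_scatter import scatter_max

def pool_mask_per_embedding(mask: torch.Tensor, group_index: torch.Tensor) -> torch.Tensor:
    """Give every column of an embedding the max attention received by that embedding.

    mask:        (batch_size, n_post_embedding_columns) attention mask
    group_index: (n_post_embedding_columns,) id of the original feature each column came from
    """
    # Max-pool the mask per original feature...
    pooled, _ = scatter_max(mask, group_index.expand_as(mask), dim=1)
    # ...then broadcast the pooled value back to every column of that feature.
    return pooled[:, group_index]

# Two numerical features (ids 0 and 1) and one categorical feature embedded
# into three columns (id 2):
mask = torch.tensor([[0.1, 0.2, 0.5, 0.4, 0.3]])
group_index = torch.tensor([0, 1, 2, 2, 2])
print(pool_mask_per_embedding(mask, group_index))
# tensor([[0.1000, 0.2000, 0.5000, 0.5000, 0.5000]])
```

Using the mean instead of the max would only change the `scatter_max` call to a mean reduction.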
Does this PR introduce a breaking change? No, everything is internal and invisible to end users.
What needs to be documented once your changes are merged?
Closing issues
Put closes #XXXX in your comment to auto-close the issue that your PR fixes (if any).
Just to point out, the torch-scatter dependency introduced here is making it a bit tricky for me to reproduce the Forest Cover Type tests from #217 because:
- We seem to need to explicitly select the correct install version for your PyTorch & CUDA combination (so things aren't as simple as `pyproject.toml` makes it look), and
- I was testing on a PyTorch v1.4 env with CUDA v10.2, which isn't a supported combination for torch-scatter 😞