saint
Unofficial PyTorch implementation of SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training (https://arxiv.org/abs/2106.01342).
Intersample attention now inherits from nn.MultiheadAttention, which includes optimizations such as flash attention.
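The repository's exact module isn't reproduced here; as a rough sketch (the class and argument names `IntersampleAttention`, `n_features`, `d_model` are illustrative), intersample attention can sit on top of nn.MultiheadAttention by flattening each sample into a single token and attending across the batch axis:

```python
import torch
import torch.nn as nn

class IntersampleAttention(nn.MultiheadAttention):
    """Sketch of intersample attention built on nn.MultiheadAttention.

    Input x has shape (b, n, d): batch, features/tokens, embedding dim.
    The batch is flattened to (1, b, n*d) so attention runs *across
    samples*, then reshaped back.
    """

    def __init__(self, n_features: int, d_model: int, n_heads: int = 8):
        # embed_dim = n_features * d_model must be divisible by n_heads
        super().__init__(embed_dim=n_features * d_model,
                         num_heads=n_heads,
                         batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        # (b, n, d) -> (1, b, n*d): each sample becomes one "token"
        flat = x.reshape(1, b, n * d)
        out, _ = super().forward(flat, flat, flat, need_weights=False)
        # (1, b, n*d) -> (b, n, d)
        return out.reshape(b, n, d)
```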
When reshaping the queries, keys, and values in the intersample() function, shouldn't they be reshaped to (1, h, b, n*d)? The code has a structure in which 8 heads per batch...
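For reference, the reshape the question proposes would look like this under standard assumptions (q already split into h heads of per-head width d; einops used for clarity; all sizes are illustrative):

```python
import torch
from einops import rearrange

b, h, n, d = 4, 8, 10, 16  # illustrative: batch, heads, tokens, per-head dim

q = torch.randn(b, h, n, d)  # queries after the head split

# Keep heads as their own axis and flatten each sample's tokens,
# so attention is computed per head *across the batch axis*.
q_intersample = rearrange(q, 'b h n d -> 1 h b (n d)')  # (1, h, b, n*d)

# Inverse reshape to recover the per-head layout afterwards.
q_back = rearrange(q_intersample, '1 h b (n d) -> b h n d', n=n)
assert torch.equal(q, q_back)
```

Folding the head axis into the batch axis instead (e.g. a `(b h) n d` layout) would mix heads and samples together when attending across the first axis, which appears to be what the question is pointing at.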
Also, `task` is not specified when the `Metric` class is instantiated.
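Assuming the metrics come from torchmetrics (an assumption; the repo's `Metric` usage isn't shown), recent torchmetrics versions require the `task` argument at construction, e.g.:

```python
import torch
from torchmetrics import Accuracy

# torchmetrics >= 0.11 requires `task` to be set explicitly;
# "multiclass" and num_classes=2 here are illustrative values.
metric = Accuracy(task="multiclass", num_classes=2)

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
print(metric(preds, target))  # tensor(0.7500)
```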