Arbitrary masks
Congratulations on shipping FNA backward! Looking forward to using it.
On another note: would it be possible to support arbitrary masking?
MaskDiT outperformed regular DiT, with a ~70% reduction in training time, by masking out 50% of patches (and predicting those via an auxiliary reconstruction loss + an encoder-decoder architecture):
https://arxiv.org/abs/2306.09305v2
Perhaps there's also an opportunity to exploit sparsity (some regions may not require computation), but even without that optimization, arbitrary masking would still be useful because it enables new training objectives.
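To make the MaskDiT-style objective concrete, here is a minimal sketch of the patch-masking step: randomly dropping a fraction of patch indices so the encoder only sees the kept subset, with the masked subset reserved for the reconstruction loss. The function name and signature are illustrative, not from the paper or from NATTEN.

```python
import random

def sample_patch_mask(num_patches, mask_ratio=0.5, seed=None):
    """Randomly split patch indices into (kept, masked) lists.

    A `mask_ratio` fraction of patches is masked out (to be predicted
    by an auxiliary reconstruction loss); the rest is kept and fed to
    the encoder.
    """
    rng = random.Random(seed)
    idx = list(range(num_patches))
    rng.shuffle(idx)
    n_mask = int(num_patches * mask_ratio)
    return sorted(idx[n_mask:]), sorted(idx[:n_mask])

# e.g. a 16x16 grid of patches, half masked out
kept, masked = sample_patch_mask(256, mask_ratio=0.5, seed=0)
```

The attention kernel would then need to operate on the irregular `kept` subset (or equivalently, honor an arbitrary boolean mask), which is exactly what fixed neighborhood patterns alone can't express.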
Note: arbitrary masks/biases are something the PyTorch team is attempting with templated attention:
https://github.com/pytorch/pytorch/pull/121845
(hence the questions in https://github.com/SHI-Labs/NATTEN/issues/120 about whether a generic method could achieve the same effect as NATTEN performantly).
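For a concrete picture of what such a generic mechanism would have to express, here is a hedged sketch of a 1-D neighborhood-attention mask predicate in the spirit of templated attention, which evaluates a function per (query, key) pair instead of materializing a dense mask. The function name and signature are illustrative only; this is not NATTEN's or PyTorch's actual API, and it mimics NATTEN's border behavior by shifting the window's center rather than shrinking the window.

```python
def neighborhood_mask(i, j, seq_len, window=7):
    """Return True if query i may attend to key j under 1-D
    neighborhood attention with a fixed window size.

    Near the borders the window is shifted (its center clamped)
    rather than shrunk, so every query attends to exactly `window`
    keys when seq_len >= window.
    """
    r = window // 2
    center = min(max(i, r), seq_len - 1 - r)
    return abs(j - center) <= r

# Query 0 in a length-10 sequence with window 3 attends to keys {0, 1, 2}
allowed = [j for j in range(10) if neighborhood_mask(0, j, 10, window=3)]
```

If a templated/compiled attention kernel can evaluate a predicate like this at NATTEN-level speed, that would answer the question raised in issue 120; the open part is the "performantly" qualifier.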