
[Feature Request] Support for (efficient) blockwise-diagonal attention

Open andersonbcdefg opened this issue 1 year ago • 0 comments

Block-diagonal attention lets keys/queries attend only to positions within their own sequence, even when multiple sequences are packed/concatenated together. This is very useful for training BERTs or GPTs with packed sequences, where padding is removed to avoid wasting computation. xformers supports this (see image below) via a special attention mask passed to its memory-efficient attention. The naive approach would be to construct the mask yourself and pass it into nn.MultiHeadAttention, but I suspect this wastes a lot of computation, since large regions of the mask are -inf. MLX is new, so I don't have great intuitions about what would be fast or slow here, but I imagine a specialized kernel could do this in a quick, memory-efficient way.

[Screenshot: xformers block-diagonal attention mask, where each packed sequence attends only within its own block]
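
For reference, here is a minimal sketch of the naive masked approach described above: build a block-diagonal additive mask for the packed sequences and pass it to nn.MultiHeadAttention. The helper name `block_diagonal_mask` and the example lengths are illustrative, not part of MLX; an efficient kernel would avoid materializing the full mask entirely.

```python
import mlx.core as mx
import mlx.nn as nn

def block_diagonal_mask(seq_lens, dtype=mx.float32):
    """Additive mask: 0 within a packed sequence, -inf across sequences."""
    total = sum(seq_lens)
    # Label each position with the index of the sequence it belongs to.
    ids = mx.concatenate(
        [mx.full((n,), i, dtype=mx.int32) for i, n in enumerate(seq_lens)]
    )
    same_seq = ids[:, None] == ids[None, :]  # (total, total) boolean
    neg_inf = mx.full((total, total), float("-inf"), dtype=dtype)
    return mx.where(same_seq, mx.zeros((total, total), dtype=dtype), neg_inf)

# Three sequences of lengths 3, 2, and 4 packed into one length-9 "batch".
seq_lens = [3, 2, 4]
mask = block_diagonal_mask(seq_lens)

attn = nn.MultiHeadAttention(dims=64, num_heads=4)
x = mx.random.normal((1, sum(seq_lens), 64))
out = attn(x, x, x, mask=mask)  # the (9, 9) mask broadcasts over batch and heads
print(out.shape)                # (1, 9, 64)
```

The wasted work is visible here: the full (total, total) score matrix is computed and then masked, even though only the diagonal blocks contribute to the output.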

— andersonbcdefg, Dec 23 '23