
[Proposal] Unified Interface/Implementation for Sparse Attention

Open ZiyueHuang opened this issue 4 years ago • 4 comments

Currently, several schemes of sparse attention (e.g., block-sparse, sliding window) rely on handcrafted kernels, and it takes plenty of effort to implement new schemes (for research or other purposes). We may consider adopting a unified interface and implementation based on SpMM.

We can require an attention mask (defined by the user, and possibly learned/generated dynamically at runtime) as the input. The attention mask is a CSR/COO matrix (or another sparse format) of shape (n, n), where n denotes the sequence length; the i-th token should attend to the j-th token if atten_mask[i, j] = 1.
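For concreteness, a tiny hypothetical 4-token mask where each token attends to itself and its immediate neighbours, stored in SciPy COO form (the variable name is just for illustration):

```python
import numpy as np
import scipy.sparse as sp

# Token i may attend to token j iff atten_mask[i, j] = 1.
atten_mask = sp.coo_matrix(np.array([[1, 1, 0, 0],
                                     [1, 1, 1, 0],
                                     [0, 1, 1, 1],
                                     [0, 0, 1, 1]]))
# atten_mask.row and atten_mask.col enumerate the allowed (i, j) pairs.
```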

  • Given query and key, to compute attention_score, we just need to compute a dot product (between the i-th token and the j-th token) for each non-zero coordinate in the attention mask. The output, attention_score, is also in CSR/COO format, same as atten_mask.
  • Given attention_score, to compute attention_weight, we need a softmax kernel over sparse tensors; maybe we can reuse the kernel here https://github.com/google-research/sputnik/blob/master/sputnik/softmax/sparse_softmax.h. The output, attention_weight, is also in CSR/COO format, same as attention_score.
  • Given attention_weight and value, to compute context, we need a matrix multiplication kernel between a CSR/COO tensor and a dense tensor, which may already exist in cuSPARSE, or we can reuse the kernels we developed in the MXNet Sparse module years ago. (A sketch of all three steps follows this list.)
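For reference, here is a minimal NumPy sketch of the three steps above (sampled dot products, sparse softmax, SpMM), assuming the mask is given as COO index arrays (rows, cols) and that every token attends to at least one position. The function name and signature are illustrative; a real implementation would dispatch to dedicated kernels such as Sputnik's or cuSPARSE.

```python
import numpy as np

def sparse_attention(query, key, value, rows, cols):
    """query/key/value: (n, d) arrays; (rows, cols): COO coordinates of atten_mask."""
    n, d = query.shape
    # Step 1 (SDDMM): scaled dot products only at the non-zero coordinates of the mask.
    scores = np.einsum('ij,ij->i', query[rows], key[cols]) / np.sqrt(d)
    # Step 2: softmax over the non-zeros of each row of the sparse score matrix.
    weights = np.empty_like(scores)
    for i in range(n):
        idx = np.nonzero(rows == i)[0]
        e = np.exp(scores[idx] - scores[idx].max())
        weights[idx] = e / e.sum()
    # Step 3 (SpMM): sparse attention_weight matrix times the dense value matrix.
    context = np.zeros((n, d))
    np.add.at(context, rows, weights[:, None] * value[cols])
    return context
```

With an all-ones mask this reproduces ordinary dense softmax attention, so it can also serve as a correctness reference when profiling the actual kernels.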

This paper could be useful: https://arxiv.org/pdf/2006.10901.pdf, along with its source code: https://github.com/google-research/sputnik.

Then, to provide sliding-window attention, block-sparse attention, or other schemes, we just need to generate the corresponding attention mask, which should be quite easy (see the sketch below).
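As a sketch of what such mask generation could look like (the helper names and the SciPy CSR choice are assumptions, not an existing API):

```python
import numpy as np
import scipy.sparse as sp

def sliding_window_mask(n, window):
    """Each token attends to tokens at most `window` positions away (assumes window < n)."""
    offsets = list(range(-window, window + 1))
    return sp.diags([np.ones(n - abs(k)) for k in offsets], offsets, format='csr')

def block_sparse_mask(n, block):
    """Tokens attend to every token inside their own block of size `block`."""
    assert n % block == 0
    return sp.block_diag([np.ones((block, block))] * (n // block), format='csr')
```

The resulting CSR matrices can be converted to COO indices and fed into a kernel like the sketch above, so adding a new scheme only means writing a new mask generator.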

cc @szhengac @sxjscience

ZiyueHuang avatar Oct 27 '20 12:10 ZiyueHuang

I think the reason to use BlockSparse rather than general sparse is that it can usually be faster, since we deal with a block of elements at a time. Nevertheless, we should start profiling the different solutions. Since this only involves kernels, one way is to use a PyTorch custom op to prototype quickly.

sxjscience avatar Oct 27 '20 15:10 sxjscience

Also, the DGL team has been profiling the performance of sparse kernels for some time. So I also ping @yzh119 and @zheng-da here.

sxjscience avatar Oct 27 '20 16:10 sxjscience

This is not a blocking problem, as the backend can switch to a faster handcrafted kernel whenever the attention pattern has an optimized one.

Block-sparse attention does not seem very appealing to me, as introducing a block partition breaks the sequence and the local context (at least for the tokens on the block boundaries). Maybe it would be useful when the sequence has some known structure to exploit. It also lacks benchmarks on NLP tasks (the benchmark in the DeepSpeed blog seems very limited, and the configuration is missing). Actually, to implement block-sparse attention we could just split the input over the sequence dimension...

I think for research purposes, the easy-to-try property should be the most appealing. For example, someone may want to improve sliding-window attention by trying a variable window length for each token, which seems natural since some tokens may only need a much shorter/longer window. Or they may want to build/learn an attention mask efficiently by some method (like the MIPS approach in Reformer). It would be great for them to have a package that helps them try such ideas quickly. Meanwhile, it would also save effort for the developers.

ZiyueHuang avatar Oct 27 '20 16:10 ZiyueHuang

I just checked the source code of https://github.com/google-research/sputnik; we may try to implement our own version based on this package. Also cc @leezu @szha @barry-jin @MoisesHer

sxjscience avatar Oct 27 '20 16:10 sxjscience