[WIP] Improve LongformerAttention performance
Description:
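This work is in progress. It makes three changes: (1) reduce the memory of the sequence index from B x S to S, (2) merge the bias add and the transpose into one kernel, and (3) reduce the computation of the global_q GEMM by running it only over the maximum number of global tokens instead of the full sequence.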
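Change (2) fuses the bias add and the transpose into a single pass, so each element is read once and written once. The C++ code below is a CPU reference sketch of that idea, not the CUDA kernel from this PR. The tensor layouts are assumptions chosen for illustration: the input is the output of one GEMM over merged Q/K/V weights with shape [B, S, 3, N, H], and the output holds separate Q, K and V tensors in [3, B, N, S, H] order. `AddBiasTransposeQKV` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference sketch of a fused "add bias + transpose" (hypothetical helper).
// Assumed layouts (illustrative only):
//   input : merged Q/K/V GEMM output, shape [B, S, 3, N, H]
//   bias  : shape [3, N, H]
//   output: separate Q, K, V in BNSH order, shape [3, B, N, S, H]
std::vector<float> AddBiasTransposeQKV(const std::vector<float>& input,
                                       const std::vector<float>& bias,
                                       int B, int S, int N, int H) {
  assert(input.size() == static_cast<std::size_t>(B) * S * 3 * N * H);
  assert(bias.size() == static_cast<std::size_t>(3) * N * H);
  std::vector<float> output(input.size());
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int m = 0; m < 3; ++m)  // m: 0 = Q, 1 = K, 2 = V
        for (int n = 0; n < N; ++n)
          for (int h = 0; h < H; ++h) {
            const std::size_t in =
                (((static_cast<std::size_t>(b) * S + s) * 3 + m) * N + n) * H + h;
            const std::size_t out =
                (((static_cast<std::size_t>(m) * B + b) * N + n) * S + s) * H + h;
            // One read, one add, one write: the bias add and the transpose share a pass.
            output[out] = input[in] + bias[(static_cast<std::size_t>(m) * N + n) * H + h];
          }
  return output;
}
```

Without fusion, the same work needs one pass to add the bias and a second pass to transpose. The second pass re-reads and rewrites the whole B x S x 3 x N x H buffer.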
TODO: build the global index once per model instead of once per node. This requires updating the interface to add an optional global index input/output.
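The sketch below shows one way to build a per-batch global index from a single S-length sequence index. This is the kind of computation the TODO proposes to do once per model. The semantics are assumptions for illustration:

- The mask is [B, S], with 1 marking a global token.
- Each output row lists global-token positions first and then the remaining positions, in order (a stable partition).
- `count[b]` records how many tokens in row b are global.

The sequence index 0..S-1 is identical for every batch row, so one S-length buffer can be shared instead of a B x S copy. `BuildGlobalIndex` and `GlobalIndex` are hypothetical names.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical result type: per-batch global index and global-token count.
struct GlobalIndex {
  std::vector<int> index;  // [B, S]: global positions first, then local ones
  std::vector<int> count;  // [B]: number of global tokens per batch row
};

// Sketch: build the global index from a [B, S] global attention mask.
GlobalIndex BuildGlobalIndex(const std::vector<int>& mask, int B, int S) {
  assert(static_cast<int>(mask.size()) == B * S);
  // One S-length sequence index shared by all batch rows (not B x S).
  std::vector<int> sequence_index(S);
  std::iota(sequence_index.begin(), sequence_index.end(), 0);

  GlobalIndex result{std::vector<int>(B * S), std::vector<int>(B, 0)};
  for (int b = 0; b < B; ++b) {
    const int* flags = mask.data() + b * S;
    int* out = result.index.data() + b * S;
    int front = 0;
    for (int s : sequence_index)  // global tokens first
      if (flags[s] != 0) out[front++] = s;
    result.count[b] = front;
    for (int s : sequence_index)  // then the local tokens
      if (flags[s] == 0) out[front++] = s;
  }
  return result;
}
```

For example, the mask rows `{0, 1, 0, 1}` and `{1, 0, 0, 0}` give index rows `{1, 3, 0, 2}` and `{0, 1, 2, 3}`, with counts 2 and 1.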
Experiments:
Experiments were run on longformer-base-4096 with batch_size=1, global_tokens=16 and float16 on a T4 GPU. Results are average latency in ms. Column legend:
- MS: use a separate stream for memory copy.
- ABT1: AddBiasTranspose on half.
- ABT2: AddBiasTranspose with half2 for half input, or float2 for float input.
- ABT4: AddBiasTranspose with half4 for half input, or float4 for float input.
- F0: use format 0, where global_bias has bias only for global_q. The default is format 1, where global_bias has bias for global_q, global_k and global_v.
- CGQ: compact global_q.

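To make "compact global_q" (CGQ) concrete, the sketch below gathers only the rows of global tokens into a fixed [max_global, D] buffer and projects those rows. The GEMM then costs about 2 x max_global x D x D multiply-adds instead of 2 x S x D x D. With S = 4096 and 16 global tokens, that is 1/256 of the rows. The shapes are assumptions for illustration: one batch row, hidden states [S, D], projection weight [D, D]. Padding rows beyond the actual global count are left as zero. `CompactGlobalQ` is a hypothetical name and a plain triple loop stands in for the GEMM.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of a compact global_q projection (hypothetical helper, one batch row).
// hidden: [S, D], weight: [D, D], global_positions: positions of global tokens.
// Returns [max_global, D]; rows past the actual global count stay zero.
std::vector<float> CompactGlobalQ(const std::vector<float>& hidden,
                                  const std::vector<float>& weight,
                                  const std::vector<int>& global_positions,
                                  int D, int max_global) {
  assert(static_cast<int>(global_positions.size()) <= max_global);
  // Gather global-token rows into a [max_global, D] buffer instead of using all S rows.
  std::vector<float> compact(static_cast<std::size_t>(max_global) * D, 0.0f);
  for (std::size_t g = 0; g < global_positions.size(); ++g)
    for (int d = 0; d < D; ++d)
      compact[g * D + d] = hidden[static_cast<std::size_t>(global_positions[g]) * D + d];

  // Naive GEMM on max_global rows: [max_global, D] x [D, D].
  std::vector<float> q(static_cast<std::size_t>(max_global) * D, 0.0f);
  for (int r = 0; r < max_global; ++r)
    for (int k = 0; k < D; ++k)
      for (int c = 0; c < D; ++c)
        q[r * D + c] += compact[r * D + k] * weight[k * D + c];
  return q;
}
```

As conclusion (f) notes, the savings from the smaller GEMM were offset in these measurements by the extra format it requires.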
| Sequence length | baseline | MS | ABT1 | ABT1+F0 | ABT1+F0+CGQ | ABT2+F0+CGQ | ABT2 | ABT4 |
|---|---|---|---|---|---|---|---|---|
| 512 | 10.16 | 10.33 | 9.72 | 10.45 | 10.08 | 9.86 | 9.3 | 9.59 |
| 1024 | 21.93 | 22.12 | 21.73 | 22.32 | 21.82 | 21.78 | 20.85 | 21.57 |
| 2048 | 43.43 | 43.74 | 48.24 | 50.28 | 48.39 | 43.36 | 42.21 | 43.35 |
| 4096 | 88.99 | 89.17 | 89.25 | 91.74 | 88.98 | 86.35 | 85.79 | 89.35 |
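The relative gain of ABT2 over the baseline can be computed directly from the table: (baseline - variant) / baseline x 100. The short program below does this arithmetic for the baseline and ABT2 columns only, using the values from the table above. `ReductionPercent` is a hypothetical helper name.

```cpp
#include <array>
#include <cassert>
#include <cstdio>

// Baseline and ABT2 latencies (ms) copied from the table.
struct Row { int sequence_length; double baseline_ms; double abt2_ms; };

constexpr std::array<Row, 4> kRows = {{
    {512, 10.16, 9.3}, {1024, 21.93, 20.85}, {2048, 43.43, 42.21}, {4096, 88.99, 85.79}}};

// Percentage latency reduction of a variant relative to the baseline.
double ReductionPercent(double baseline, double variant) {
  return (baseline - variant) / baseline * 100.0;
}

// Prints one line per sequence length.
void PrintReductions() {
  for (const Row& r : kRows)
    std::printf("S=%d: %.2f%% lower latency with ABT2\n", r.sequence_length,
                ReductionPercent(r.baseline_ms, r.abt2_ms));
}
```

The reduction is about 8.5% at 512 tokens, 4.9% at 1024, 2.8% at 2048 and 3.6% at 4096.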
Conclusions:
- (a) Using a separate stream to copy data slightly increased latency, possibly because the data is small and a separate stream needs extra synchronization. This change was reverted.
- (b) AddBiasTranspose on half helps at shorter sequence lengths (512 and 1024; based on 'baseline' vs 'ABT1'). AddBiasTranspose on half2 helps at all sequence lengths (based on 'baseline' vs 'ABT2').
- (c) Merging the q, k and v weights is better than splitting them: running one GEMM on the merged matrix is faster than running three GEMMs one after another.
- (d) Applying AddBiasTranspose to 3 (Q, K, V) + 3 (Global_Q, Global_K, Global_V) full matrices is better than applying it to 5 (Q, K, V, Global_K, Global_V) + 1 (Global_Q) full matrices (based on 'ABT1' vs 'ABT1+F0').
- (f) Using a compact global_q reduces latency, but it requires a new format, which increases latency. The overall impact is neutral (based on 'ABT1' vs 'ABT1+F0+CGQ').
- (g) Using half2 can significantly reduce latency, especially at sequence lengths 2048 and 4096 (based on 'ABT1+F0+CGQ' vs 'ABT2+F0+CGQ'). This is because GPU hardware provides load instructions for 32-bit, 64-bit and 128-bit data, which map to the half2, float2 and float4 data types. TODO: test a float model to see the impact of float2 and float4.
- (h) Using half4 is slightly slower than using half2 (based on 'ABT2' vs 'ABT4').
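Conclusion (g) rests on access width: with half2, one 32-bit load brings in two fp16 values, so each thread handles two adjacent elements per memory instruction. The C++ sketch below only mirrors that access pattern on the CPU, using a two-float struct and `std::memcpy` for the "wide" load and store. It does not model fp16 or GPU memory transactions. It assumes the innermost dimension H is even, which a half2-based kernel also needs. `AddBiasPaired` and `Pair` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Two adjacent elements handled together, standing in for half2/float2.
struct Pair { float x, y; };

// Adds a length-H bias to every row of `data`, two elements at a time.
// Assumes H is even and data holds a whole number of rows.
void AddBiasPaired(std::vector<float>& data, const std::vector<float>& bias, int H) {
  assert(H % 2 == 0 && data.size() % H == 0 && bias.size() == static_cast<std::size_t>(H));
  const std::size_t rows = data.size() / H;
  for (std::size_t r = 0; r < rows; ++r)
    for (int h = 0; h < H; h += 2) {
      Pair v, b;
      std::memcpy(&v, &data[r * H + h], sizeof(Pair));  // one "wide" load of 2 values
      std::memcpy(&b, &bias[h], sizeof(Pair));
      v.x += b.x;
      v.y += b.y;
      std::memcpy(&data[r * H + h], &v, sizeof(Pair));  // one "wide" store
    }
}
```

The same reasoning extends to 4-wide vectors. Conclusion (h) shows that, in these measurements, half4 did not improve on half2.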
This pull request fixes 1 alert when merging 136c7a7f67eb86bce6bcc901a12a1beb69718621 into e810480403ebdd5f3438431ef2fc060c70a52b66.
fixed alerts:
- 1 for Unused local variable
This pull request fixes 11 alerts when merging 432c4f06500dcfdf8a958e0bd476fbf2910adcaf into 616677104a0b16461c3958e3952017b423cc97fb.
fixed alerts:
- 7 for Unused import
- 2 for Unused local variable
- 1 for Module is imported more than once
- 1 for Module is imported with 'import' and 'import from'
This pull request introduces 1 alert and fixes 11 when merging 0f7010e8351f151e49f20231df673b2d32ad740e into eb6aa861cfa7295ee9f7145db44aaec708e8ce1c.
new alerts:
- 1 for Except block handles 'BaseException'
fixed alerts:
- 7 for Unused import
- 2 for Unused local variable
- 1 for Module is imported more than once
- 1 for Module is imported with 'import' and 'import from'
This pull request introduces 1 alert and fixes 17 when merging ad0c09c1a58b8632ce9fe494a19cacdaac37e25c into eb6aa861cfa7295ee9f7145db44aaec708e8ce1c.
new alerts:
- 1 for Except block handles 'BaseException'
fixed alerts:
- 13 for Unused import
- 2 for Unused local variable
- 1 for Module is imported more than once
- 1 for Module is imported with 'import' and 'import from'