xformers
Add Triton Flash Attention
What does this PR do?
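Adds a Flash Attention implementation written in Triton. The benchmarks below time its forward and backward passes, with no attention bias and with a causal LowerTriangularMask bias, against vanilla attention and against MemoryEfficientAttentionCutlassFwdFlashBwOp.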
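Flash Attention computes softmax(QKᵀ/√K)·V by streaming over blocks of keys while keeping a running maximum and a running sum of exponentials (the "online softmax"), so the full M×M score matrix is never stored. The pure-Python sketch below illustrates only that idea for a single head and checks it against a naive reference. It is not the Triton kernel in this PR: for clarity it handles one query row at a time, whereas a GPU kernel also tiles over queries. The function names and the block size are hypothetical.

```python
import math
import random


def naive_attention(q, k, v, causal=False):
    """Reference: build each full score row, softmax it, then weight the V rows."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # With a causal (lower-triangular) mask, query i only sees keys 0..i.
        n_keys = i + 1 if causal else len(k)
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(n_keys)]
        mx = max(scores)
        weights = [math.exp(s - mx) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block=4, causal=False):
    """Flash-Attention-style: stream over key blocks using an online softmax."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for i, qi in enumerate(q):
        n_keys = i + 1 if causal else len(k)
        run_max = -math.inf  # largest score seen so far for this query
        run_sum = 0.0        # sum of exp(score - run_max) seen so far
        acc = [0.0] * dv     # unnormalized weighted sum of V rows
        for start in range(0, n_keys, block):
            stop = min(start + block, n_keys)
            scores = [scale * sum(a * b for a, b in zip(qi, k[j]))
                      for j in range(start, stop)]
            new_max = max(run_max, max(scores))
            # Rescale the partial sums that were computed against the old maximum.
            correction = math.exp(run_max - new_max)
            run_sum *= correction
            acc = [x * correction for x in acc]
            for j, s in zip(range(start, stop), scores):
                p = math.exp(s - new_max)
                run_sum += p
                for c in range(dv):
                    acc[c] += p * v[j][c]
            run_max = new_max
        # Normalize once at the end instead of after every block.
        out.append([x / run_sum for x in acc])
    return out


def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    random.seed(0)
    M, K = 10, 8  # sequence length and head dimension for one head
    q, k, v = random_matrix(M, K), random_matrix(M, K), random_matrix(M, K)
    for causal in (False, True):
        ref = naive_attention(q, k, v, causal=causal)
        got = tiled_attention(q, k, v, block=3, causal=causal)
        err = max(abs(a - b) for ra, rb in zip(ref, got) for a, b in zip(ra, rb))
        print(f"causal={causal}: max abs difference {err:.2e}")
```

The two functions agree up to floating-point rounding for any block size; the block size only changes how much of a score row is held at once.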
Performance Compared to Vanilla Attention
[--------- attention (attn_bias=<class 'NoneType'>) --------]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 1260.4 | 372.9
b16 B=384, M=197, H=1, K=88 | 1269.9 | 375.3
f16 B=384, M=197, H=1, K=80 | 146.7 | 344.5
b16 B=384, M=197, H=1, K=80 | 149.2 | 346.6
f16 B=384, M=197, H=1, K=64 | 92.0 | 293.2
b16 B=384, M=197, H=1, K=64 | 94.2 | 295.1
f16 B=1024, M=197, H=1, K=88 | 3242.3 | 938.8
b16 B=1024, M=197, H=1, K=88 | 3264.9 | 945.1
f16 B=1024, M=197, H=1, K=80 | 348.7 | 864.5
b16 B=1024, M=197, H=1, K=80 | 354.3 | 871.5
f16 B=1024, M=197, H=1, K=64 | 213.7 | 729.9
b16 B=1024, M=197, H=1, K=64 | 222.1 | 735.5
f16 B=512, M=197, H=1, K=80 | 185.4 | 448.4
b16 B=512, M=197, H=1, K=80 | 188.5 | 451.6
f16 B=32, M=197, H=16, K=80 | 193.8 | 547.9
b16 B=32, M=197, H=16, K=80 | 193.7 | 550.6
f16 B=32, M=197, H=16, K=64 | 114.4 | 467.1
b16 B=32, M=197, H=16, K=64 | 120.8 | 469.2
f16 B=32, M=197, H=16, K=128 | 193.8 | 717.4
b16 B=32, M=197, H=16, K=128 | 195.1 | 720.3
f16 B=256, M=197, H=1, K=88 | 868.4 | 262.1
b16 B=256, M=197, H=1, K=88 | 869.9 | 262.2
f16 B=16, M=197, H=16, K=88 | 869.8 | 317.7
b16 B=16, M=197, H=16, K=88 | 872.4 | 319.0
f16 B=16, M=197, H=16, K=64 | 88.6 | 257.7
b16 B=16, M=197, H=16, K=64 | 89.6 | 259.8
f16 B=16, M=197, H=16, K=128 | 101.5 | 385.8
b16 B=16, M=197, H=16, K=128 | 102.2 | 387.1
f16 B=1, M=4096, H=160, K=128 | 9807.3 | 20343.8
b16 B=1, M=4096, H=160, K=128 | 10034.8 | 21406.5
f16 B=2, M=4096, H=160, K=128 | 19298.5 | 42595.1
b16 B=2, M=4096, H=160, K=128 | 19841.7 | 44399.3
f16 B=1, M=8192, H=160, K=128 | 37266.9 | 88426.6
b16 B=1, M=8192, H=160, K=128 | 38570.7 | 87103.1
f16 B=2, M=8192, H=160, K=128 | 74084.0 |
b16 B=2, M=8192, H=160, K=128 | 76617.8 |
f16 B=1024, M=82, H=8, K=64 | 461.1 | 1767.3
b16 B=1024, M=82, H=8, K=64 | 479.9 | 1862.5
f16 B=150, M=256, H=16, K=64 | 382.5 | 1725.6
b16 B=150, M=256, H=16, K=64 | 420.9 | 1760.7
f16 B=64, M=256, H=12, K=64 | 131.7 | 587.0
b16 B=64, M=256, H=12, K=64 | 145.7 | 597.4
f16 B=1, M=4096, H=16, K=40 | 30705.0 | 1937.8
b16 B=1, M=4096, H=16, K=40 | 30832.3 | 1990.1
f16 B=1, M=16384, H=16, K=40 | 423123.6 | 28845.9
b16 B=1, M=16384, H=16, K=40 | 420611.6 | 30167.0
f16 B=256, M=4096, H=16, K=64 | 118530.4 |
b16 B=256, M=4096, H=16, K=64 | 132843.9 |
f16 B=8, M=2048, H=20, K=128 | 2692.1 | 5704.5
b16 B=8, M=2048, H=20, K=128 | 2743.9 | 6044.7
f16 B=16, M=128, H=16, K=16 | 90.2 | 135.5
b16 B=16, M=128, H=16, K=16 | 87.0 | 136.7
f16 B=16, M=128, H=16, K=32 | 88.4 | 137.5
b16 B=16, M=128, H=16, K=32 | 91.0 | 138.1
f16 B=16, M=128, H=16, K=64 | 89.1 | 137.3
b16 B=16, M=128, H=16, K=64 | 90.3 | 137.7
f16 B=16, M=128, H=16, K=128 | 91.6 | 139.6
b16 B=16, M=128, H=16, K=128 | 90.6 | 140.7
f16 B=16, M=512, H=16, K=16 | 89.1 | 461.4
b16 B=16, M=512, H=16, K=16 | 99.3 | 553.7
f16 B=16, M=512, H=16, K=32 | 105.6 | 512.1
b16 B=16, M=512, H=16, K=32 | 123.2 | 594.4
f16 B=16, M=512, H=16, K=64 | 152.1 | 595.9
b16 B=16, M=512, H=16, K=64 | 169.4 | 613.0
f16 B=16, M=512, H=16, K=128 | 319.0 | 786.8
b16 B=16, M=512, H=16, K=128 | 328.0 | 804.9
f16 B=16, M=1024, H=16, K=16 | 292.1 | 1642.8
b16 B=16, M=1024, H=16, K=16 | 385.7 | 2025.9
f16 B=16, M=1024, H=16, K=32 | 369.5 | 1732.0
b16 B=16, M=1024, H=16, K=32 | 428.2 | 2127.2
f16 B=16, M=1024, H=16, K=64 | 518.4 | 2033.7
b16 B=16, M=1024, H=16, K=64 | 579.9 | 2064.6
f16 B=16, M=1024, H=16, K=128 | 1097.7 | 2410.3
b16 B=16, M=1024, H=16, K=128 | 1135.3 | 2490.3
f16 B=64, M=128, H=16, K=16 | 87.5 | 183.0
b16 B=64, M=128, H=16, K=16 | 89.6 | 184.9
f16 B=64, M=128, H=16, K=32 | 89.2 | 224.6
b16 B=64, M=128, H=16, K=32 | 88.2 | 225.5
f16 B=64, M=128, H=16, K=64 | 87.5 | 321.5
b16 B=64, M=128, H=16, K=64 | 87.2 | 322.9
f16 B=64, M=128, H=16, K=128 | 127.0 | 484.7
b16 B=64, M=128, H=16, K=128 | 129.4 | 486.1
f16 B=64, M=512, H=16, K=16 | 307.5 | 1727.1
b16 B=64, M=512, H=16, K=16 | 403.5 | 2094.0
f16 B=64, M=512, H=16, K=32 | 387.2 | 1888.1
b16 B=64, M=512, H=16, K=32 | 450.6 | 2250.3
f16 B=64, M=512, H=16, K=64 | 557.9 | 2233.9
b16 B=64, M=512, H=16, K=64 | 620.4 | 2294.9
f16 B=64, M=512, H=16, K=128 | 1205.0 | 3008.3
b16 B=64, M=512, H=16, K=128 | 1235.5 | 3072.8
f16 B=64, M=1024, H=16, K=16 | 1101.8 | 6422.7
b16 B=64, M=1024, H=16, K=16 | 1473.2 | 8024.6
f16 B=64, M=1024, H=16, K=32 | 1409.8 | 6744.7
b16 B=64, M=1024, H=16, K=32 | 1667.6 | 8387.5
f16 B=64, M=1024, H=16, K=64 | 2022.1 | 7966.1
b16 B=64, M=1024, H=16, K=64 | 2265.4 | 8127.2
f16 B=64, M=1024, H=16, K=128 | 4282.5 | 9524.9
b16 B=64, M=1024, H=16, K=128 | 4413.0 | 9842.0
Times are in microseconds (us).
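Each row reads as `label | optimized | baseline`, so dividing the baseline time by the optimized time gives a speedup, where values above 1 favor the Triton kernel. A minimal sketch, assuming rows keep exactly the layout shown in these tables and that an empty cell means no time was reported for that column; the helper names are hypothetical.

```python
def parse_row(line):
    """Split 'f16 B=384, M=197, H=1, K=88 | 1260.4 | 372.9' into its parts."""
    label, *cells = [part.strip() for part in line.split("|")]
    dtype, shape = label.split(" ", 1)
    dims = {}
    for item in shape.split(","):
        key, value = item.strip().split("=")
        dims[key] = int(value)
    # An empty cell (no time reported for that column) becomes None.
    times = [float(cell) if cell else None for cell in cells]
    return dtype, dims, times


def speedup(line):
    """Baseline time / optimized time; values above 1 favor the optimized op."""
    _, _, times = parse_row(line)
    optimized = times[0] if times else None
    baseline = times[1] if len(times) > 1 else None
    if optimized is None or baseline is None:
        return None
    return baseline / optimized


if __name__ == "__main__":
    rows = [
        "f16 B=384, M=197, H=1, K=88 | 1260.4 | 372.9",
        "f16 B=384, M=197, H=1, K=64 | 92.0 | 293.2",
        "f16 B=2, M=8192, H=160, K=128 | 74084.0 |",
    ]
    for row in rows:
        ratio = speedup(row)
        label = row.split("|")[0].strip()
        print(label, "n/a" if ratio is None else f"{ratio:.2f}x")
```

For example, the f16 row with B=384, K=64 gives 293.2 / 92.0 ≈ 3.19, while the f16 row with B=384, K=88 gives 372.9 / 1260.4 ≈ 0.30, meaning the optimized kernel is slower than vanilla for that shape.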
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
| optimized | vanilla
1 threads: ---------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 3701.7 | 447.0
b16 B=384, M=197, H=1, K=88 | 1178.3 | 452.6
f16 B=384, M=197, H=1, K=80 | 120.2 | 421.6
b16 B=384, M=197, H=1, K=80 | 122.5 | 427.5
f16 B=384, M=197, H=1, K=64 | 87.3 | 370.6
b16 B=384, M=197, H=1, K=64 | 90.1 | 376.4
f16 B=1024, M=197, H=1, K=88 | 9586.4 | 1124.4
b16 B=1024, M=197, H=1, K=88 | 2981.6 | 1139.9
f16 B=1024, M=197, H=1, K=80 | 294.0 | 1059.9
b16 B=1024, M=197, H=1, K=80 | 300.3 | 1075.3
f16 B=1024, M=197, H=1, K=64 | 187.5 | 925.9
b16 B=1024, M=197, H=1, K=64 | 197.2 | 940.5
f16 B=512, M=197, H=1, K=80 | 153.5 | 549.2
b16 B=512, M=197, H=1, K=80 | 156.6 | 557.4
f16 B=32, M=197, H=16, K=80 | 157.4 | 647.5
b16 B=32, M=197, H=16, K=80 | 160.6 | 654.6
f16 B=32, M=197, H=16, K=64 | 103.8 | 565.6
b16 B=32, M=197, H=16, K=64 | 109.2 | 573.3
f16 B=32, M=197, H=16, K=128 | 157.7 | 810.9
b16 B=32, M=197, H=16, K=128 | 159.7 | 818.4
f16 B=256, M=197, H=1, K=88 | 2551.1 | 310.8
b16 B=256, M=197, H=1, K=88 | 813.3 | 314.9
f16 B=16, M=197, H=16, K=88 | 2554.5 | 367.9
b16 B=16, M=197, H=16, K=88 | 806.5 | 372.0
f16 B=16, M=197, H=16, K=64 | 90.3 | 307.1
b16 B=16, M=197, H=16, K=64 | 89.6 | 310.9
f16 B=16, M=197, H=16, K=128 | 87.8 | 435.4
b16 B=16, M=197, H=16, K=128 | 91.6 | 438.6
f16 B=1, M=4096, H=160, K=128 | 5176.5 | 37308.5
b16 B=1, M=4096, H=160, K=128 | 5330.5 | 37818.1
f16 B=2, M=4096, H=160, K=128 | 10225.6 | 76322.3
b16 B=2, M=4096, H=160, K=128 | 10547.6 | 77302.1
f16 B=1, M=8192, H=160, K=128 | 19289.3 | 152433.6
b16 B=1, M=8192, H=160, K=128 | 19922.2 | 148602.0
f16 B=2, M=8192, H=160, K=128 | 38313.9 |
b16 B=2, M=8192, H=160, K=128 | 39610.6 |
f16 B=1024, M=82, H=8, K=64 | 488.4 | 1979.9
b16 B=1024, M=82, H=8, K=64 | 515.8 | 2085.1
f16 B=150, M=256, H=16, K=64 | 333.7 | 2401.0
b16 B=150, M=256, H=16, K=64 | 349.7 | 2446.2
f16 B=64, M=256, H=12, K=64 | 118.8 | 805.3
b16 B=64, M=256, H=12, K=64 | 124.8 | 819.5
f16 B=1, M=4096, H=16, K=40 | 15004.4 | 3428.4
b16 B=1, M=4096, H=16, K=40 | 14890.7 | 3465.7
f16 B=1, M=16384, H=16, K=40 | 221390.1 | 55383.6
b16 B=1, M=16384, H=16, K=40 | 218101.2 | 55427.0
f16 B=256, M=4096, H=16, K=64 | 67757.8 |
b16 B=256, M=4096, H=16, K=64 | 74047.4 |
f16 B=8, M=2048, H=20, K=128 | 1487.8 | 9352.9
b16 B=8, M=2048, H=20, K=128 | 1534.3 | 9539.5
f16 B=16, M=128, H=16, K=16 | 92.2 | 146.0
b16 B=16, M=128, H=16, K=16 | 87.0 | 143.8
f16 B=16, M=128, H=16, K=32 | 87.2 | 144.5
b16 B=16, M=128, H=16, K=32 | 89.3 | 145.7
f16 B=16, M=128, H=16, K=64 | 90.6 | 144.4
b16 B=16, M=128, H=16, K=64 | 90.3 | 141.9
f16 B=16, M=128, H=16, K=128 | 87.2 | 160.0
b16 B=16, M=128, H=16, K=128 | 87.2 | 161.7
f16 B=16, M=512, H=16, K=16 | 90.4 | 715.1
b16 B=16, M=512, H=16, K=16 | 91.0 | 792.1
f16 B=16, M=512, H=16, K=32 | 88.5 | 766.6
b16 B=16, M=512, H=16, K=32 | 98.1 | 832.2
f16 B=16, M=512, H=16, K=64 | 120.3 | 881.1
b16 B=16, M=512, H=16, K=64 | 129.0 | 910.1
f16 B=16, M=512, H=16, K=128 | 225.5 | 1066.3
b16 B=16, M=512, H=16, K=128 | 231.9 | 1095.5
f16 B=16, M=1024, H=16, K=16 | 196.5 | 2658.0
b16 B=16, M=1024, H=16, K=16 | 253.2 | 3042.2
f16 B=16, M=1024, H=16, K=32 | 244.3 | 2749.6
b16 B=16, M=1024, H=16, K=32 | 291.8 | 3121.5
f16 B=16, M=1024, H=16, K=64 | 355.3 | 2963.7
b16 B=16, M=1024, H=16, K=64 | 384.7 | 3287.5
f16 B=16, M=1024, H=16, K=128 | 683.0 | 3534.2
b16 B=16, M=1024, H=16, K=128 | 705.2 | 3711.2
f16 B=64, M=128, H=16, K=16 | 87.6 | 245.8
b16 B=64, M=128, H=16, K=16 | 89.9 | 251.7
f16 B=64, M=128, H=16, K=32 | 88.4 | 294.2
b16 B=64, M=128, H=16, K=32 | 90.7 | 298.7
f16 B=64, M=128, H=16, K=64 | 92.0 | 391.1
b16 B=64, M=128, H=16, K=64 | 88.6 | 395.9
f16 B=64, M=128, H=16, K=128 | 131.0 | 557.0
b16 B=64, M=128, H=16, K=128 | 133.2 | 562.1
f16 B=64, M=512, H=16, K=16 | 227.6 | 2704.8
b16 B=64, M=512, H=16, K=16 | 289.3 | 3019.3
f16 B=64, M=512, H=16, K=32 | 285.1 | 2891.7
b16 B=64, M=512, H=16, K=32 | 336.8 | 3163.4
f16 B=64, M=512, H=16, K=64 | 414.3 | 3351.7
b16 B=64, M=512, H=16, K=64 | 444.2 | 3475.0
f16 B=64, M=512, H=16, K=128 | 829.5 | 4097.1
b16 B=64, M=512, H=16, K=128 | 854.6 | 4208.8
f16 B=64, M=1024, H=16, K=16 | 722.1 | 10479.3
b16 B=64, M=1024, H=16, K=16 | 931.6 | 12030.5
f16 B=64, M=1024, H=16, K=32 | 900.9 | 10856.5
b16 B=64, M=1024, H=16, K=32 | 1075.2 | 12361.2
f16 B=64, M=1024, H=16, K=64 | 1313.1 | 11684.0
b16 B=64, M=1024, H=16, K=64 | 1424.4 | 13002.0
f16 B=64, M=1024, H=16, K=128 | 2610.0 | 13969.2
b16 B=64, M=1024, H=16, K=128 | 2697.7 | 14687.4
Times are in microseconds (us).
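Raw times are easier to compare across shapes when converted to throughput. The sketch below assumes B, M, H and K are batch size, sequence length, number of heads and head dimension (the [B, M, H, K] layout used by xformers' memory_efficient_attention), and uses a common counting convention that this PR does not itself state: the forward pass is two matmuls, QKᵀ and PV, each costing 2·M²·K FLOPs per batch element and head, for about 4·B·H·M²·K FLOPs in total. With a LowerTriangularMask only about half of the score matrix is needed, so the count is halved. The function names are hypothetical.

```python
def attention_fwd_flops(b, m, h, k, causal=False):
    """Approximate forward FLOPs: QK^T plus PV, each 2*M*M*K per (batch, head)."""
    flops = 4 * b * h * m * m * k
    # A lower-triangular mask leaves roughly half of the score matrix to compute.
    return flops / 2 if causal else flops


def tflops(b, m, h, k, time_us, causal=False):
    """Convert a benchmark time in microseconds into achieved TFLOP/s."""
    return attention_fwd_flops(b, m, h, k, causal) / (time_us * 1e-6) / 1e12


if __name__ == "__main__":
    # Optimized forward times for f16 B=1, M=4096, H=160, K=128 from the tables above.
    print(f"no bias: {tflops(1, 4096, 160, 128, 9807.3):.1f} TFLOP/s")
    print(f"causal:  {tflops(1, 4096, 160, 128, 5176.5, causal=True):.1f} TFLOP/s")
```

Under this convention the dense row (9807.3 us) works out to roughly 140 TFLOP/s. Because the causal count is halved, the near-2x drop in time with LowerTriangularMask for that shape shows up as similar throughput rather than as a faster kernel.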
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 3702.8 | 820.9
b16 B=384, M=197, H=1, K=88 | 3701.9 | 823.4
f16 B=384, M=197, H=1, K=80 | 761.9 | 768.3
b16 B=384, M=197, H=1, K=80 | 762.5 | 769.8
f16 B=384, M=197, H=1, K=64 | 592.9 | 651.8
b16 B=384, M=197, H=1, K=64 | 372.7 | 652.9
f16 B=1024, M=197, H=1, K=88 | 9107.4 | 2103.4
b16 B=1024, M=197, H=1, K=88 | 9115.3 | 2105.4
f16 B=1024, M=197, H=1, K=80 | 1882.7 | 1958.4
b16 B=1024, M=197, H=1, K=80 | 1891.6 | 1958.3
f16 B=1024, M=197, H=1, K=64 | 916.6 | 1649.8
b16 B=1024, M=197, H=1, K=64 | 924.4 | 1649.1
f16 B=512, M=197, H=1, K=80 | 987.1 | 994.2
b16 B=512, M=197, H=1, K=80 | 990.8 | 997.0
f16 B=32, M=197, H=16, K=80 | 1037.3 | 1034.2
b16 B=32, M=197, H=16, K=80 | 1040.1 | 1035.8
f16 B=32, M=197, H=16, K=64 | 486.2 | 885.3
b16 B=32, M=197, H=16, K=64 | 488.7 | 887.6
f16 B=32, M=197, H=16, K=128 | 1165.0 | 1343.0
b16 B=32, M=197, H=16, K=128 | 1168.3 | 1345.8
f16 B=256, M=197, H=1, K=88 | 2360.9 | 576.3
b16 B=256, M=197, H=1, K=88 | 2361.7 | 576.9
f16 B=16, M=197, H=16, K=88 | 2377.4 | 599.7
b16 B=16, M=197, H=16, K=88 | 2376.8 | 600.9
f16 B=16, M=197, H=16, K=64 | 303.0 | 488.1
b16 B=16, M=197, H=16, K=64 | 330.5 | 488.6
f16 B=16, M=197, H=16, K=128 | 617.9 | 713.8
b16 B=16, M=197, H=16, K=128 | 619.7 | 716.1
f16 B=1, M=4096, H=160, K=128 | 41860.6 | 38927.7
b16 B=1, M=4096, H=160, K=128 | 41971.4 | 39833.9
f16 B=2, M=4096, H=160, K=128 | 83623.8 | 78708.1
b16 B=2, M=4096, H=160, K=128 | 84009.4 | 80592.3
f16 B=1, M=8192, H=160, K=128 | 160725.6 |
b16 B=1, M=8192, H=160, K=128 | 161001.2 |
f16 B=2, M=8192, H=160, K=128 | 321049.6 |
b16 B=2, M=8192, H=160, K=128 | 321800.6 |
f16 B=1024, M=82, H=8, K=64 | 2608.7 | 3609.7
b16 B=1024, M=82, H=8, K=64 | 2525.2 | 3783.9
f16 B=150, M=256, H=16, K=64 | 2552.5 | 3809.6
b16 B=150, M=256, H=16, K=64 | 2464.5 | 3838.1
f16 B=64, M=256, H=12, K=64 | 827.8 | 1257.5
b16 B=64, M=256, H=12, K=64 | 835.1 | 1269.2
f16 B=1, M=4096, H=16, K=40 | 43085.7 | 3556.9
b16 B=1, M=4096, H=16, K=40 | 42898.9 | 3595.8
f16 B=1, M=16384, H=16, K=40 | 664283.3 | 53870.2
b16 B=1, M=16384, H=16, K=40 | 664702.6 | 54308.0
f16 B=256, M=4096, H=16, K=64 | 503458.9 |
b16 B=256, M=4096, H=16, K=64 | 503461.7 |
f16 B=8, M=2048, H=20, K=128 | 11442.9 | 10536.5
b16 B=8, M=2048, H=20, K=128 | 11362.9 | 10790.8
f16 B=16, M=128, H=16, K=16 | 349.5 | 311.2
b16 B=16, M=128, H=16, K=16 | 303.8 | 314.0
f16 B=16, M=128, H=16, K=32 | 316.6 | 340.6
b16 B=16, M=128, H=16, K=32 | 328.7 | 337.9
f16 B=16, M=128, H=16, K=64 | 318.3 | 333.9
b16 B=16, M=128, H=16, K=64 | 332.7 | 335.4
f16 B=16, M=128, H=16, K=128 | 452.5 | 331.2
b16 B=16, M=128, H=16, K=128 | 329.1 | 335.9
f16 B=16, M=512, H=16, K=16 | 527.1 | 982.9
b16 B=16, M=512, H=16, K=16 | 322.8 | 1078.9
f16 B=16, M=512, H=16, K=32 | 680.9 | 1090.2
b16 B=16, M=512, H=16, K=32 | 487.4 | 1179.3
f16 B=16, M=512, H=16, K=64 | 1016.7 | 1276.3
b16 B=16, M=512, H=16, K=64 | 845.2 | 1295.1
f16 B=16, M=512, H=16, K=128 | 1891.7 | 1712.6
b16 B=16, M=512, H=16, K=128 | 1766.5 | 1744.0
f16 B=16, M=1024, H=16, K=16 | 1257.4 | 3532.6
b16 B=16, M=1024, H=16, K=16 | 1066.4 | 3953.2
f16 B=16, M=1024, H=16, K=32 | 1770.3 | 3731.3
b16 B=16, M=1024, H=16, K=32 | 1615.7 | 4158.6
f16 B=16, M=1024, H=16, K=64 | 2686.2 | 4281.1
b16 B=16, M=1024, H=16, K=64 | 2545.2 | 4348.9
f16 B=16, M=1024, H=16, K=128 | 5407.5 | 5109.9
b16 B=16, M=1024, H=16, K=128 | 5312.2 | 5248.8
f16 B=64, M=128, H=16, K=16 | 304.3 | 365.2
b16 B=64, M=128, H=16, K=16 | 325.1 | 372.8
f16 B=64, M=128, H=16, K=32 | 304.3 | 465.8
b16 B=64, M=128, H=16, K=32 | 301.5 | 472.4
f16 B=64, M=128, H=16, K=64 | 464.3 | 679.8
b16 B=64, M=128, H=16, K=64 | 469.2 | 681.4
f16 B=64, M=128, H=16, K=128 | 980.1 | 1076.0
b16 B=64, M=128, H=16, K=128 | 987.5 | 1079.2
f16 B=64, M=512, H=16, K=16 | 1109.7 | 3705.7
b16 B=64, M=512, H=16, K=16 | 1111.9 | 4078.7
f16 B=64, M=512, H=16, K=32 | 1678.8 | 4135.2
b16 B=64, M=512, H=16, K=32 | 1684.5 | 4504.2
f16 B=64, M=512, H=16, K=64 | 2982.7 | 4906.5
b16 B=64, M=512, H=16, K=64 | 3001.9 | 4992.0
f16 B=64, M=512, H=16, K=128 | 6641.2 | 6644.6
b16 B=64, M=512, H=16, K=128 | 6697.3 | 6756.1
f16 B=64, M=1024, H=16, K=16 | 3610.9 | 13882.0
b16 B=64, M=1024, H=16, K=16 | 3614.8 | 15555.3
f16 B=64, M=1024, H=16, K=32 | 6138.0 | 14767.7
b16 B=64, M=1024, H=16, K=32 | 6147.9 | 16450.9
f16 B=64, M=1024, H=16, K=64 | 9813.8 | 16866.6
b16 B=64, M=1024, H=16, K=64 | 9848.6 | 17160.6
f16 B=64, M=1024, H=16, K=128 | 20556.5 | 20371.7
b16 B=64, M=1024, H=16, K=128 | 20674.3 | 20916.7
Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 9480.6 | 822.5
b16 B=384, M=197, H=1, K=88 | 9355.1 | 823.8
f16 B=384, M=197, H=1, K=80 | 669.6 | 768.6
b16 B=384, M=197, H=1, K=80 | 674.6 | 770.2
f16 B=384, M=197, H=1, K=64 | 542.3 | 651.9
b16 B=384, M=197, H=1, K=64 | 351.1 | 653.8
f16 B=1024, M=197, H=1, K=88 | 23565.1 | 2105.1
b16 B=1024, M=197, H=1, K=88 | 23556.3 | 2105.4
f16 B=1024, M=197, H=1, K=80 | 1646.1 | 1961.7
b16 B=1024, M=197, H=1, K=80 | 1651.9 | 1958.3
f16 B=1024, M=197, H=1, K=64 | 861.8 | 1649.0
b16 B=1024, M=197, H=1, K=64 | 869.0 | 1648.1
f16 B=512, M=197, H=1, K=80 | 863.6 | 995.4
b16 B=512, M=197, H=1, K=80 | 866.2 | 997.1
f16 B=32, M=197, H=16, K=80 | 867.8 | 1034.4
b16 B=32, M=197, H=16, K=80 | 869.5 | 1036.2
f16 B=32, M=197, H=16, K=64 | 456.3 | 883.7
b16 B=32, M=197, H=16, K=64 | 458.7 | 885.6
f16 B=32, M=197, H=16, K=128 | 1042.0 | 1340.5
b16 B=32, M=197, H=16, K=128 | 1045.1 | 1343.6
f16 B=256, M=197, H=1, K=88 | 6342.1 | 577.0
b16 B=256, M=197, H=1, K=88 | 6344.3 | 578.3
f16 B=16, M=197, H=16, K=88 | 6335.6 | 600.3
b16 B=16, M=197, H=16, K=88 | 6341.5 | 601.6
f16 B=16, M=197, H=16, K=64 | 330.3 | 488.4
b16 B=16, M=197, H=16, K=64 | 325.5 | 490.0
f16 B=16, M=197, H=16, K=128 | 557.4 | 715.2
b16 B=16, M=197, H=16, K=128 | 557.8 | 717.4
f16 B=1, M=4096, H=160, K=128 | 26152.2 | 38914.3
b16 B=1, M=4096, H=160, K=128 | 26166.4 | 39924.4
f16 B=2, M=4096, H=160, K=128 | 51711.2 | 78726.9
b16 B=2, M=4096, H=160, K=128 | 51925.3 | 80711.7
f16 B=1, M=8192, H=160, K=128 | 92696.3 |
b16 B=1, M=8192, H=160, K=128 | 92960.2 |
f16 B=2, M=8192, H=160, K=128 | 184624.7 |
b16 B=2, M=8192, H=160, K=128 | 185330.2 |
f16 B=1024, M=82, H=8, K=64 | 2724.8 | 3608.8
b16 B=1024, M=82, H=8, K=64 | 2642.1 | 3784.4
f16 B=150, M=256, H=16, K=64 | 2379.5 | 3803.7
b16 B=150, M=256, H=16, K=64 | 2287.7 | 3832.2
f16 B=64, M=256, H=12, K=64 | 774.1 | 1255.6
b16 B=64, M=256, H=12, K=64 | 781.0 | 1270.5
f16 B=1, M=4096, H=16, K=40 | 6799.9 | 3563.7
b16 B=1, M=4096, H=16, K=40 | 6624.0 | 3598.3
f16 B=1, M=16384, H=16, K=40 | 94594.6 | 53902.5
b16 B=1, M=16384, H=16, K=40 | 94692.4 | 54354.3
f16 B=256, M=4096, H=16, K=64 | 274333.4 |
b16 B=256, M=4096, H=16, K=64 | 274964.5 |
f16 B=8, M=2048, H=20, K=128 | 7844.4 | 10545.5
b16 B=8, M=2048, H=20, K=128 | 7764.3 | 10795.4
f16 B=16, M=128, H=16, K=16 | 334.9 | 293.7
b16 B=16, M=128, H=16, K=16 | 326.3 | 314.7
f16 B=16, M=128, H=16, K=32 | 354.5 | 319.1
b16 B=16, M=128, H=16, K=32 | 304.6 | 291.4
f16 B=16, M=128, H=16, K=64 | 350.7 | 312.6
b16 B=16, M=128, H=16, K=64 | 331.6 | 301.8
f16 B=16, M=128, H=16, K=128 | 486.3 | 313.5
b16 B=16, M=128, H=16, K=128 | 326.8 | 314.4
f16 B=16, M=512, H=16, K=16 | 456.9 | 986.5
b16 B=16, M=512, H=16, K=16 | 306.0 | 1080.0
f16 B=16, M=512, H=16, K=32 | 566.6 | 1091.4
b16 B=16, M=512, H=16, K=32 | 393.9 | 1179.2
f16 B=16, M=512, H=16, K=64 | 841.5 | 1276.5
b16 B=16, M=512, H=16, K=64 | 669.4 | 1294.9
f16 B=16, M=512, H=16, K=128 | 1590.8 | 1713.6
b16 B=16, M=512, H=16, K=128 | 1486.2 | 1740.0
f16 B=16, M=1024, H=16, K=16 | 858.6 | 3529.7
b16 B=16, M=1024, H=16, K=16 | 686.6 | 3951.6
f16 B=16, M=1024, H=16, K=32 | 1233.9 | 3733.7
b16 B=16, M=1024, H=16, K=32 | 1059.3 | 4154.4
f16 B=16, M=1024, H=16, K=64 | 1875.2 | 4281.8
b16 B=16, M=1024, H=16, K=64 | 1754.2 | 4347.6
f16 B=16, M=1024, H=16, K=128 | 4062.9 | 5105.1
b16 B=16, M=1024, H=16, K=128 | 3983.0 | 5247.9
f16 B=64, M=128, H=16, K=16 | 304.4 | 368.6
b16 B=64, M=128, H=16, K=16 | 301.9 | 376.4
f16 B=64, M=128, H=16, K=32 | 303.9 | 466.2
b16 B=64, M=128, H=16, K=32 | 327.3 | 471.8
f16 B=64, M=128, H=16, K=64 | 475.8 | 682.3
b16 B=64, M=128, H=16, K=64 | 479.3 | 679.5
f16 B=64, M=128, H=16, K=128 | 1041.7 | 1076.2
b16 B=64, M=128, H=16, K=128 | 1047.8 | 1077.4
f16 B=64, M=512, H=16, K=16 | 912.1 | 3704.3
b16 B=64, M=512, H=16, K=16 | 917.2 | 4082.8
f16 B=64, M=512, H=16, K=32 | 1458.2 | 4133.9
b16 B=64, M=512, H=16, K=32 | 1461.2 | 4501.9
f16 B=64, M=512, H=16, K=64 | 2466.8 | 4903.1
b16 B=64, M=512, H=16, K=64 | 2492.7 | 4993.7
f16 B=64, M=512, H=16, K=128 | 5545.0 | 6642.9
b16 B=64, M=512, H=16, K=128 | 5595.6 | 6754.2
f16 B=64, M=1024, H=16, K=16 | 2608.6 | 13864.4
b16 B=64, M=1024, H=16, K=16 | 2610.5 | 15560.3
f16 B=64, M=1024, H=16, K=32 | 4025.8 | 14755.8
b16 B=64, M=1024, H=16, K=32 | 4032.1 | 16463.7
f16 B=64, M=1024, H=16, K=64 | 6658.9 | 16880.7
b16 B=64, M=1024, H=16, K=64 | 6722.0 | 17150.6
f16 B=64, M=1024, H=16, K=128 | 15345.5 | 20371.0
b16 B=64, M=1024, H=16, K=128 | 15438.5 | 20920.1
Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
| optimized | fctls_bflsh
1 threads: ------------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 542.1 | 241.0
b16 B=384, M=197, H=1, K=64 | 344.0 | 236.7
f16 B=1024, M=197, H=1, K=64 | 852.4 | 539.2
b16 B=1024, M=197, H=1, K=64 | 857.6 | 544.6
f16 B=32, M=197, H=16, K=64 | 449.0 | 280.6
b16 B=32, M=197, H=16, K=64 | 451.6 | 283.2
f16 B=32, M=197, H=16, K=128 | 1161.4 | 517.8
b16 B=32, M=197, H=16, K=128 | 1030.0 | 518.8
f16 B=16, M=197, H=16, K=64 | 312.1 | 214.0
b16 B=16, M=197, H=16, K=64 | 309.8 | 213.7
f16 B=16, M=197, H=16, K=128 | 547.4 | 296.6
b16 B=16, M=197, H=16, K=128 | 548.1 | 298.1
f16 B=1, M=4096, H=160, K=128 | 25831.9 | 31365.0
b16 B=1, M=4096, H=160, K=128 | 25811.2 | 31382.5
f16 B=2, M=4096, H=160, K=128 | 51084.6 | 48413.3
b16 B=2, M=4096, H=160, K=128 | 51265.2 | 48371.8
f16 B=1, M=8192, H=160, K=128 | 91356.3 | 121976.0
b16 B=1, M=8192, H=160, K=128 | 91520.0 | 122016.8
f16 B=2, M=8192, H=160, K=128 | 181984.4 | 187070.0
b16 B=2, M=8192, H=160, K=128 | 182646.4 | 187302.9
f16 B=1024, M=82, H=8, K=64 | 2703.5 | 1502.2
b16 B=1024, M=82, H=8, K=64 | 2615.9 | 1511.7
f16 B=150, M=256, H=16, K=64 | 2380.2 | 1505.7
b16 B=150, M=256, H=16, K=64 | 2264.4 | 1521.6
f16 B=64, M=256, H=12, K=64 | 768.4 | 526.4
b16 B=64, M=256, H=12, K=64 | 774.2 | 530.1
f16 B=8, M=2048, H=20, K=128 | 7798.2 | 8299.4
b16 B=8, M=2048, H=20, K=128 | 7678.0 | 8323.0
f16 B=16, M=128, H=16, K=16 | 322.9 | 212.5
b16 B=16, M=128, H=16, K=16 | 311.7 | 192.5
f16 B=16, M=128, H=16, K=32 | 359.6 | 212.4
b16 B=16, M=128, H=16, K=32 | 333.6 | 194.5
f16 B=16, M=128, H=16, K=64 | 330.3 | 196.0
b16 B=16, M=128, H=16, K=64 | 314.8 | 196.9
f16 B=16, M=128, H=16, K=128 | 489.5 | 215.7
b16 B=16, M=128, H=16, K=128 | 343.9 | 215.5
f16 B=16, M=512, H=16, K=16 | 463.5 | 261.1
b16 B=16, M=512, H=16, K=16 | 318.1 | 263.3
f16 B=16, M=512, H=16, K=32 | 573.5 | 337.3
b16 B=16, M=512, H=16, K=32 | 391.6 | 339.6
f16 B=16, M=512, H=16, K=64 | 829.6 | 517.0
b16 B=16, M=512, H=16, K=64 | 666.4 | 521.2
f16 B=16, M=512, H=16, K=128 | 1608.1 | 1034.4
b16 B=16, M=512, H=16, K=128 | 1472.1 | 1032.4
f16 B=16, M=1024, H=16, K=16 | 866.3 | 792.3
b16 B=16, M=1024, H=16, K=16 | 684.4 | 793.2
f16 B=16, M=1024, H=16, K=32 | 1269.0 | 1022.8
b16 B=16, M=1024, H=16, K=32 | 1056.2 | 1023.4
f16 B=16, M=1024, H=16, K=64 | 1885.2 | 1563.7
b16 B=16, M=1024, H=16, K=64 | 1749.4 | 1564.6
f16 B=16, M=1024, H=16, K=128 | 4052.5 | 3428.0
b16 B=16, M=1024, H=16, K=128 | 3929.8 | 3432.8
f16 B=64, M=128, H=16, K=16 | 313.7 | 218.0
b16 B=64, M=128, H=16, K=16 | 336.9 | 217.6
f16 B=64, M=128, H=16, K=32 | 317.4 | 212.5
b16 B=64, M=128, H=16, K=32 | 318.1 | 208.0
f16 B=64, M=128, H=16, K=64 | 470.3 | 284.8
b16 B=64, M=128, H=16, K=64 | 474.6 | 286.9
f16 B=64, M=128, H=16, K=128 | 1024.8 | 504.8
b16 B=64, M=128, H=16, K=128 | 1030.2 | 507.4
f16 B=64, M=512, H=16, K=16 | 909.5 | 924.7
b16 B=64, M=512, H=16, K=16 | 913.6 | 930.5
f16 B=64, M=512, H=16, K=32 | 1454.8 | 1176.4
b16 B=64, M=512, H=16, K=32 | 1459.1 | 1181.8
f16 B=64, M=512, H=16, K=64 | 2460.3 | 1752.1
b16 B=64, M=512, H=16, K=64 | 2485.5 | 1773.0
f16 B=64, M=512, H=16, K=128 | 5503.4 | 3564.6
b16 B=64, M=512, H=16, K=128 | 5557.4 | 3592.5
f16 B=64, M=1024, H=16, K=16 | 2599.1 | 2866.5
b16 B=64, M=1024, H=16, K=16 | 2605.9 | 2868.9
f16 B=64, M=1024, H=16, K=32 | 4017.3 | 3577.7
b16 B=64, M=1024, H=16, K=32 | 4022.4 | 3592.0
f16 B=64, M=1024, H=16, K=64 | 6648.5 | 5349.7
b16 B=64, M=1024, H=16, K=64 | 6716.9 | 5374.7
f16 B=64, M=1024, H=16, K=128 | 15206.5 | 11777.5
b16 B=64, M=1024, H=16, K=128 | 15313.3 | 11814.0
Times are in microseconds (us).
Performance Compared to MemoryEfficientAttentionCutlassFwdFlashBwOp (column fctls_bflsh)
[----------- attention (attn_bias=<class 'NoneType'>) ----------]
| optimized | fctls_bflsh
1 threads: ------------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 89.9 | 88.9
b16 B=384, M=197, H=1, K=64 | 94.2 | 88.8
f16 B=1024, M=197, H=1, K=64 | 213.5 | 215.1
b16 B=1024, M=197, H=1, K=64 | 221.9 | 215.1
f16 B=32, M=197, H=16, K=64 | 114.0 | 113.8
b16 B=32, M=197, H=16, K=64 | 120.0 | 113.8
f16 B=32, M=197, H=16, K=128 | 193.3 | 168.6
b16 B=32, M=197, H=16, K=128 | 194.6 | 166.9
f16 B=16, M=197, H=16, K=64 | 85.2 | 60.7
b16 B=16, M=197, H=16, K=64 | 87.2 | 60.7
f16 B=16, M=197, H=16, K=128 | 101.6 | 88.6
b16 B=16, M=197, H=16, K=128 | 102.0 | 88.0
f16 B=1, M=4096, H=160, K=128 | 9816.7 | 15435.3
b16 B=1, M=4096, H=160, K=128 | 10046.1 | 15280.9
f16 B=2, M=4096, H=160, K=128 | 19357.5 | 30760.0
b16 B=2, M=4096, H=160, K=128 | 19869.8 | 30546.3
f16 B=1, M=8192, H=160, K=128 | 37248.5 | 61367.7
b16 B=1, M=8192, H=160, K=128 | 38471.5 | 61015.5
f16 B=2, M=8192, H=160, K=128 | 73817.0 | 123123.0
b16 B=2, M=8192, H=160, K=128 | 76761.6 | 121918.5
f16 B=1024, M=82, H=8, K=64 | 459.6 | 455.0
b16 B=1024, M=82, H=8, K=64 | 477.1 | 455.1
f16 B=150, M=256, H=16, K=64 | 383.4 | 521.7
b16 B=150, M=256, H=16, K=64 | 419.8 | 518.6
f16 B=64, M=256, H=12, K=64 | 131.4 | 176.2
b16 B=64, M=256, H=12, K=64 | 145.1 | 173.2
f16 B=256, M=4096, H=16, K=64 | 118450.0 | 191728.6
b16 B=256, M=4096, H=16, K=64 | 131960.6 | 190391.4
f16 B=8, M=2048, H=20, K=128 | 2694.7 | 3380.7
b16 B=8, M=2048, H=20, K=128 | 2738.6 | 3332.9
f16 B=16, M=128, H=16, K=16 | 90.9 | 31.2
b16 B=16, M=128, H=16, K=16 | 89.4 | 30.2
f16 B=16, M=128, H=16, K=32 | 87.6 | 30.7
b16 B=16, M=128, H=16, K=32 | 90.0 | 30.7
f16 B=16, M=128, H=16, K=64 | 88.7 | 30.2
b16 B=16, M=128, H=16, K=64 | 89.4 | 30.2
f16 B=16, M=128, H=16, K=128 | 87.2 | 35.8
b16 B=16, M=128, H=16, K=128 | 89.3 | 35.8
f16 B=16, M=512, H=16, K=16 | 89.2 | 180.0
b16 B=16, M=512, H=16, K=16 | 99.6 | 179.9
f16 B=16, M=512, H=16, K=32 | 105.3 | 186.2
b16 B=16, M=512, H=16, K=32 | 123.0 | 186.1
f16 B=16, M=512, H=16, K=64 | 152.1 | 213.7
b16 B=16, M=512, H=16, K=64 | 169.3 | 213.7
f16 B=16, M=512, H=16, K=128 | 318.6 | 367.5
b16 B=16, M=512, H=16, K=128 | 327.5 | 364.1
f16 B=16, M=1024, H=16, K=16 | 291.8 | 686.6
b16 B=16, M=1024, H=16, K=16 | 384.4 | 689.4
f16 B=16, M=1024, H=16, K=32 | 369.7 | 693.3
b16 B=16, M=1024, H=16, K=32 | 427.3 | 695.3
f16 B=16, M=1024, H=16, K=64 | 518.2 | 802.4
b16 B=16, M=1024, H=16, K=64 | 579.3 | 794.1
f16 B=16, M=1024, H=16, K=128 | 1098.6 | 1383.8
b16 B=16, M=1024, H=16, K=128 | 1135.7 | 1371.4
f16 B=64, M=128, H=16, K=16 | 85.6 | 53.8
b16 B=64, M=128, H=16, K=16 | 85.7 | 53.8
f16 B=64, M=128, H=16, K=32 | 89.5 | 58.3
b16 B=64, M=128, H=16, K=32 | 87.5 | 58.4
f16 B=64, M=128, H=16, K=64 | 86.1 | 69.5
b16 B=64, M=128, H=16, K=64 | 85.6 | 69.3
f16 B=64, M=128, H=16, K=128 | 127.2 | 121.0
b16 B=64, M=128, H=16, K=128 | 129.7 | 119.1
f16 B=64, M=512, H=16, K=16 | 307.1 | 703.0
b16 B=64, M=512, H=16, K=16 | 403.4 | 703.5
f16 B=64, M=512, H=16, K=32 | 387.1 | 712.1
b16 B=64, M=512, H=16, K=32 | 449.8 | 711.7
f16 B=64, M=512, H=16, K=64 | 558.4 | 821.8
b16 B=64, M=512, H=16, K=64 | 619.9 | 818.0
f16 B=64, M=512, H=16, K=128 | 1200.5 | 1451.8
b16 B=64, M=512, H=16, K=128 | 1235.0 | 1434.0
f16 B=64, M=1024, H=16, K=16 | 1101.1 | 2692.5
b16 B=64, M=1024, H=16, K=16 | 1472.9 | 2691.6
f16 B=64, M=1024, H=16, K=32 | 1409.8 | 2720.0
b16 B=64, M=1024, H=16, K=32 | 1669.5 | 2715.8
f16 B=64, M=1024, H=16, K=64 | 2025.4 | 3136.1
b16 B=64, M=1024, H=16, K=64 | 2263.5 | 3105.2
f16 B=64, M=1024, H=16, K=128 | 4279.5 | 5500.2
b16 B=64, M=1024, H=16, K=128 | 4413.0 | 5406.7
Times are in microseconds (us).
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
| optimized | fctls_bflsh
1 threads: ------------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 89.1 | 66.3
b16 B=384, M=197, H=1, K=64 | 87.5 | 66.3
f16 B=1024, M=197, H=1, K=64 | 187.0 | 153.3
b16 B=1024, M=197, H=1, K=64 | 196.9 | 153.2
f16 B=32, M=197, H=16, K=64 | 103.6 | 84.7
b16 B=32, M=197, H=16, K=64 | 109.5 | 84.8
f16 B=32, M=197, H=16, K=128 | 157.6 | 126.1
b16 B=32, M=197, H=16, K=128 | 159.3 | 124.3
f16 B=16, M=197, H=16, K=64 | 88.5 | 46.3
b16 B=16, M=197, H=16, K=64 | 86.2 | 46.3
f16 B=16, M=197, H=16, K=128 | 86.1 | 68.3
b16 B=16, M=197, H=16, K=128 | 90.7 | 68.3
f16 B=1, M=4096, H=160, K=128 | 5167.6 | 7921.9
b16 B=1, M=4096, H=160, K=128 | 5335.3 | 7870.1
f16 B=2, M=4096, H=160, K=128 | 10226.7 | 15736.6
b16 B=2, M=4096, H=160, K=128 | 10540.9 | 15639.6
f16 B=1, M=8192, H=160, K=128 | 19262.0 | 31230.1
b16 B=1, M=8192, H=160, K=128 | 19908.8 | 30905.6
f16 B=2, M=8192, H=160, K=128 | 38306.6 | 62229.2
b16 B=2, M=8192, H=160, K=128 | 39533.3 | 61578.2
f16 B=1024, M=82, H=8, K=64 | 486.7 | 375.9
b16 B=1024, M=82, H=8, K=64 | 514.7 | 376.0
f16 B=150, M=256, H=16, K=64 | 333.8 | 363.1
b16 B=150, M=256, H=16, K=64 | 349.4 | 359.8
f16 B=64, M=256, H=12, K=64 | 118.3 | 124.9
b16 B=64, M=256, H=12, K=64 | 124.5 | 123.9
f16 B=256, M=4096, H=16, K=64 | 67841.8 | 99189.2
b16 B=256, M=4096, H=16, K=64 | 73773.6 | 97542.4
f16 B=8, M=2048, H=20, K=128 | 1487.1 | 1851.1
b16 B=8, M=2048, H=20, K=128 | 1533.2 | 1818.1
f16 B=16, M=128, H=16, K=16 | 87.7 | 30.3
b16 B=16, M=128, H=16, K=16 | 85.4 | 30.6
f16 B=16, M=128, H=16, K=32 | 86.8 | 30.7
b16 B=16, M=128, H=16, K=32 | 89.3 | 30.6
f16 B=16, M=128, H=16, K=64 | 85.6 | 30.8
b16 B=16, M=128, H=16, K=64 | 88.8 | 30.5
f16 B=16, M=128, H=16, K=128 | 86.6 | 34.3
b16 B=16, M=128, H=16, K=128 | 87.6 | 34.4
f16 B=16, M=512, H=16, K=16 | 86.6 | 119.7
b16 B=16, M=512, H=16, K=16 | 86.1 | 119.7
f16 B=16, M=512, H=16, K=32 | 86.2 | 124.0
b16 B=16, M=512, H=16, K=32 | 98.0 | 124.0
f16 B=16, M=512, H=16, K=64 | 120.6 | 142.7
b16 B=16, M=512, H=16, K=64 | 129.1 | 142.8
f16 B=16, M=512, H=16, K=128 | 225.3 | 241.0
b16 B=16, M=512, H=16, K=128 | 231.9 | 238.4
f16 B=16, M=1024, H=16, K=16 | 196.4 | 399.8
b16 B=16, M=1024, H=16, K=16 | 252.5 | 399.4
f16 B=16, M=1024, H=16, K=32 | 244.2 | 405.7
b16 B=16, M=1024, H=16, K=32 | 291.7 | 405.3
f16 B=16, M=1024, H=16, K=64 | 356.4 | 464.3
b16 B=16, M=1024, H=16, K=64 | 385.0 | 463.6
f16 B=16, M=1024, H=16, K=128 | 683.6 | 805.3
b16 B=16, M=1024, H=16, K=128 | 704.4 | 794.6
f16 B=64, M=128, H=16, K=16 | 85.7 | 47.1
b16 B=64, M=128, H=16, K=16 | 86.2 | 47.1
f16 B=64, M=128, H=16, K=32 | 89.4 | 50.2
b16 B=64, M=128, H=16, K=32 | 88.1 | 50.2
f16 B=64, M=128, H=16, K=64 | 85.8 | 61.1
b16 B=64, M=128, H=16, K=64 | 85.6 | 61.1
f16 B=64, M=128, H=16, K=128 | 131.1 | 109.6
b16 B=64, M=128, H=16, K=128 | 133.4 | 108.7
f16 B=64, M=512, H=16, K=16 | 227.3 | 428.1
b16 B=64, M=512, H=16, K=16 | 289.1 | 428.1
f16 B=64, M=512, H=16, K=32 | 285.0 | 434.4
b16 B=64, M=512, H=16, K=32 | 336.7 | 434.3
f16 B=64, M=512, H=16, K=64 | 415.5 | 507.7
b16 B=64, M=512, H=16, K=64 | 444.6 | 502.7
f16 B=64, M=512, H=16, K=128 | 829.4 | 919.7
b16 B=64, M=512, H=16, K=128 | 854.1 | 910.8
f16 B=64, M=1024, H=16, K=16 | 722.3 | 1497.8
b16 B=64, M=1024, H=16, K=16 | 931.8 | 1495.7
f16 B=64, M=1024, H=16, K=32 | 900.6 | 1512.9
b16 B=64, M=1024, H=16, K=32 | 1075.4 | 1513.4
f16 B=64, M=1024, H=16, K=64 | 1316.9 | 1764.7
b16 B=64, M=1024, H=16, K=64 | 1424.1 | 1739.0
f16 B=64, M=1024, H=16, K=128 | 2607.3 | 3139.7
b16 B=64, M=1024, H=16, K=128 | 2692.0 | 3106.5
Times are in microseconds (us).
[------ attention backward (attn_bias=<class 'NoneType'>) ------]
| optimized | fctls_bflsh
1 threads: ------------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 545.1 | 236.0
b16 B=384, M=197, H=1, K=64 | 366.9 | 238.5
f16 B=1024, M=197, H=1, K=64 | 908.5 | 532.5
b16 B=1024, M=197, H=1, K=64 | 915.2 | 531.5
f16 B=32, M=197, H=16, K=64 | 477.8 | 276.7
b16 B=32, M=197, H=16, K=64 | 481.0 | 277.0
f16 B=32, M=197, H=16, K=128 | 1307.9 | 634.5
b16 B=32, M=197, H=16, K=128 | 1156.2 | 635.5
f16 B=16, M=197, H=16, K=64 | 333.1 | 213.3
b16 B=16, M=197, H=16, K=64 | 310.0 | 212.3
f16 B=16, M=197, H=16, K=128 | 597.0 | 365.3
b16 B=16, M=197, H=16, K=128 | 598.4 | 366.0
f16 B=1, M=4096, H=160, K=128 | 41711.0 | 54006.5
b16 B=1, M=4096, H=160, K=128 | 41674.5 | 53975.9
f16 B=2, M=4096, H=160, K=128 | 82813.6 | 82572.8
b16 B=2, M=4096, H=160, K=128 | 83121.0 | 82540.1
f16 B=1, M=8192, H=160, K=128 | 158913.4 | 213065.0
b16 B=1, M=8192, H=160, K=128 | 159351.2 | 213233.8
f16 B=2, M=8192, H=160, K=128 | 318080.6 | 324743.8
b16 B=2, M=8192, H=160, K=128 | 318049.8 | 324995.7
f16 B=1024, M=82, H=8, K=64 | 2628.2 | 1480.0
b16 B=1024, M=82, H=8, K=64 | 2514.7 | 1502.1
f16 B=150, M=256, H=16, K=64 | 2529.2 | 1485.6
b16 B=150, M=256, H=16, K=64 | 2434.0 | 1485.3
f16 B=64, M=256, H=12, K=64 | 824.4 | 520.0
b16 B=64, M=256, H=12, K=64 | 831.4 | 517.1
f16 B=8, M=2048, H=20, K=128 | 11325.1 | 13894.0
b16 B=8, M=2048, H=20, K=128 | 11260.2 | 13895.4
f16 B=16, M=128, H=16, K=16 | 323.3 | 193.0
b16 B=16, M=128, H=16, K=16 | 339.6 | 213.9
f16 B=16, M=128, H=16, K=32 | 368.4 | 215.0
b16 B=16, M=128, H=16, K=32 | 333.9 | 210.6
f16 B=16, M=128, H=16, K=64 | 328.8 | 196.2
b16 B=16, M=128, H=16, K=64 | 314.8 | 197.2
f16 B=16, M=128, H=16, K=128 | 496.4 | 222.7
b16 B=16, M=128, H=16, K=128 | 334.9 | 215.7
f16 B=16, M=512, H=16, K=16 | 536.8 | 320.3
b16 B=16, M=512, H=16, K=16 | 340.9 | 323.6
f16 B=16, M=512, H=16, K=32 | 660.7 | 422.7
b16 B=16, M=512, H=16, K=32 | 479.8 | 424.9
f16 B=16, M=512, H=16, K=64 | 985.6 | 670.5
b16 B=16, M=512, H=16, K=64 | 824.3 | 672.4
f16 B=16, M=512, H=16, K=128 | 1864.8 | 1504.7
b16 B=16, M=512, H=16, K=128 | 1751.5 | 1508.3
f16 B=16, M=1024, H=16, K=16 | 1255.1 | 1238.7
b16 B=16, M=1024, H=16, K=16 | 1050.9 | 1241.7
f16 B=16, M=1024, H=16, K=32 | 1798.1 | 1588.3
b16 B=16, M=1024, H=16, K=32 | 1610.2 | 1596.3
f16 B=16, M=1024, H=16, K=64 | 2691.8 | 2304.6
b16 B=16, M=1024, H=16, K=64 | 2548.6 | 2310.8
f16 B=16, M=1024, H=16, K=128 | 5376.4 | 5464.7
b16 B=16, M=1024, H=16, K=128 | 5268.8 | 5473.3
f16 B=64, M=128, H=16, K=16 | 328.3 | 194.6
b16 B=64, M=128, H=16, K=16 | 320.7 | 194.9
f16 B=64, M=128, H=16, K=32 | 317.8 | 221.3
b16 B=64, M=128, H=16, K=32 | 315.1 | 199.9
f16 B=64, M=128, H=16, K=64 | 461.0 | 281.6
b16 B=64, M=128, H=16, K=64 | 465.7 | 284.7
f16 B=64, M=128, H=16, K=128 | 967.1 | 491.4
b16 B=64, M=128, H=16, K=128 | 974.3 | 496.1
f16 B=64, M=512, H=16, K=16 | 1095.1 | 1180.2
b16 B=64, M=512, H=16, K=16 | 1099.8 | 1188.7
f16 B=64, M=512, H=16, K=32 | 1657.9 | 1487.9
b16 B=64, M=512, H=16, K=32 | 1665.4 | 1497.0
f16 B=64, M=512, H=16, K=64 | 2927.7 | 2291.3
b16 B=64, M=512, H=16, K=64 | 2949.2 | 2302.8
f16 B=64, M=512, H=16, K=128 | 6573.7 | 5139.4
b16 B=64, M=512, H=16, K=128 | 6632.5 | 5169.0
f16 B=64, M=1024, H=16, K=16 | 3570.0 | 4673.2
b16 B=64, M=1024, H=16, K=16 | 3582.5 | 4671.1
f16 B=64, M=1024, H=16, K=32 | 6134.0 | 5590.1
b16 B=64, M=1024, H=16, K=32 | 6145.1 | 5607.8
f16 B=64, M=1024, H=16, K=64 | 9832.3 | 7863.1
b16 B=64, M=1024, H=16, K=64 | 9859.6 | 7890.6
f16 B=64, M=1024, H=16, K=128 | 20411.4 | 18492.4
b16 B=64, M=1024, H=16, K=128 | 20521.8 | 18545.9
Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
| optimized | fctls_bflsh
1 threads: ------------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 542.1 | 241.0
b16 B=384, M=197, H=1, K=64 | 344.0 | 236.7
f16 B=1024, M=197, H=1, K=64 | 852.4 | 539.2
b16 B=1024, M=197, H=1, K=64 | 857.6 | 544.6
f16 B=32, M=197, H=16, K=64 | 449.0 | 280.6
b16 B=32, M=197, H=16, K=64 | 451.6 | 283.2
f16 B=32, M=197, H=16, K=128 | 1161.4 | 517.8
b16 B=32, M=197, H=16, K=128 | 1030.0 | 518.8
f16 B=16, M=197, H=16, K=64 | 312.1 | 214.0
b16 B=16, M=197, H=16, K=64 | 309.8 | 213.7
f16 B=16, M=197, H=16, K=128 | 547.4 | 296.6
b16 B=16, M=197, H=16, K=128 | 548.1 | 298.1
f16 B=1, M=4096, H=160, K=128 | 25831.9 | 31365.0
b16 B=1, M=4096, H=160, K=128 | 25811.2 | 31382.5
f16 B=2, M=4096, H=160, K=128 | 51084.6 | 48413.3
b16 B=2, M=4096, H=160, K=128 | 51265.2 | 48371.8
f16 B=1, M=8192, H=160, K=128 | 91356.3 | 121976.0
b16 B=1, M=8192, H=160, K=128 | 91520.0 | 122016.8
f16 B=2, M=8192, H=160, K=128 | 181984.4 | 187070.0
b16 B=2, M=8192, H=160, K=128 | 182646.4 | 187302.9
f16 B=1024, M=82, H=8, K=64 | 2703.5 | 1502.2
b16 B=1024, M=82, H=8, K=64 | 2615.9 | 1511.7
f16 B=150, M=256, H=16, K=64 | 2380.2 | 1505.7
b16 B=150, M=256, H=16, K=64 | 2264.4 | 1521.6
f16 B=64, M=256, H=12, K=64 | 768.4 | 526.4
b16 B=64, M=256, H=12, K=64 | 774.2 | 530.1
f16 B=8, M=2048, H=20, K=128 | 7798.2 | 8299.4
b16 B=8, M=2048, H=20, K=128 | 7678.0 | 8323.0
f16 B=16, M=128, H=16, K=16 | 322.9 | 212.5
b16 B=16, M=128, H=16, K=16 | 311.7 | 192.5
f16 B=16, M=128, H=16, K=32 | 359.6 | 212.4
b16 B=16, M=128, H=16, K=32 | 333.6 | 194.5
f16 B=16, M=128, H=16, K=64 | 330.3 | 196.0
b16 B=16, M=128, H=16, K=64 | 314.8 | 196.9
f16 B=16, M=128, H=16, K=128 | 489.5 | 215.7
b16 B=16, M=128, H=16, K=128 | 343.9 | 215.5
f16 B=16, M=512, H=16, K=16 | 463.5 | 261.1
b16 B=16, M=512, H=16, K=16 | 318.1 | 263.3
f16 B=16, M=512, H=16, K=32 | 573.5 | 337.3
b16 B=16, M=512, H=16, K=32 | 391.6 | 339.6
f16 B=16, M=512, H=16, K=64 | 829.6 | 517.0
b16 B=16, M=512, H=16, K=64 | 666.4 | 521.2
f16 B=16, M=512, H=16, K=128 | 1608.1 | 1034.4
b16 B=16, M=512, H=16, K=128 | 1472.1 | 1032.4
f16 B=16, M=1024, H=16, K=16 | 866.3 | 792.3
b16 B=16, M=1024, H=16, K=16 | 684.4 | 793.2
f16 B=16, M=1024, H=16, K=32 | 1269.0 | 1022.8
b16 B=16, M=1024, H=16, K=32 | 1056.2 | 1023.4
f16 B=16, M=1024, H=16, K=64 | 1885.2 | 1563.7
b16 B=16, M=1024, H=16, K=64 | 1749.4 | 1564.6
f16 B=16, M=1024, H=16, K=128 | 4052.5 | 3428.0
b16 B=16, M=1024, H=16, K=128 | 3929.8 | 3432.8
f16 B=64, M=128, H=16, K=16 | 313.7 | 218.0
b16 B=64, M=128, H=16, K=16 | 336.9 | 217.6
f16 B=64, M=128, H=16, K=32 | 317.4 | 212.5
b16 B=64, M=128, H=16, K=32 | 318.1 | 208.0
f16 B=64, M=128, H=16, K=64 | 470.3 | 284.8
b16 B=64, M=128, H=16, K=64 | 474.6 | 286.9
f16 B=64, M=128, H=16, K=128 | 1024.8 | 504.8
b16 B=64, M=128, H=16, K=128 | 1030.2 | 507.4
f16 B=64, M=512, H=16, K=16 | 909.5 | 924.7
b16 B=64, M=512, H=16, K=16 | 913.6 | 930.5
f16 B=64, M=512, H=16, K=32 | 1454.8 | 1176.4
b16 B=64, M=512, H=16, K=32 | 1459.1 | 1181.8
f16 B=64, M=512, H=16, K=64 | 2460.3 | 1752.1
b16 B=64, M=512, H=16, K=64 | 2485.5 | 1773.0
f16 B=64, M=512, H=16, K=128 | 5503.4 | 3564.6
b16 B=64, M=512, H=16, K=128 | 5557.4 | 3592.5
f16 B=64, M=1024, H=16, K=16 | 2599.1 | 2866.5
b16 B=64, M=1024, H=16, K=16 | 2605.9 | 2868.9
f16 B=64, M=1024, H=16, K=32 | 4017.3 | 3577.7
b16 B=64, M=1024, H=16, K=32 | 4022.4 | 3592.0
f16 B=64, M=1024, H=16, K=64 | 6648.5 | 5349.7
b16 B=64, M=1024, H=16, K=64 | 6716.9 | 5374.7
f16 B=64, M=1024, H=16, K=128 | 15206.5 | 11777.5
b16 B=64, M=1024, H=16, K=128 | 15313.3 | 11814.0
Times are in microseconds (us).
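Each row above reports the optimized time and the baseline time (`fctls_bflsh` in the last table) for one shape. A small Python helper like the one below turns such rows into speedup ratios. It divides the baseline time by the optimized time, so values above 1 mean the optimized kernel is faster, and it skips rows whose baseline cell is empty. The `speedups` helper and its regex are hypothetical, not part of xFormers, and they assume exactly the row format printed here.

```python
import re

# Matches rows such as "f16 B=384, M=197, H=1, K=64 | 545.1 | 236.0".
# The baseline cell may be empty when the baseline did not run for that shape.
ROW = re.compile(
    r"^(?P<dtype>\w+) B=(?P<B>\d+), M=(?P<M>\d+), H=(?P<H>\d+), K=(?P<K>\d+)"
    r"\s*\|\s*(?P<opt>[\d.]+)\s*\|\s*(?P<base>[\d.]*)\s*$"
)


def speedups(text):
    """Return {(dtype, B, M, H, K): baseline_us / optimized_us} for rows with both timings."""
    result = {}
    for line in text.splitlines():
        m = ROW.match(line.strip())
        if m is None or not m.group("base"):
            continue  # header, separator, or a row with no baseline timing
        key = (m.group("dtype"),) + tuple(int(m.group(k)) for k in "BMHK")
        result[key] = float(m.group("base")) / float(m.group("opt"))
    return result


rows = """\
f16 B=1, M=8192, H=160, K=128 | 91356.3 | 121976.0
f16 B=64, M=128, H=16, K=128 | 1024.8 | 504.8
f16 B=2, M=8192, H=160, K=128 | 74047.2 |
"""
for shape, ratio in speedups(rows).items():
    print(shape, f"{ratio:.2f}x")
```

On these sample rows, the M=8192 shape gives a speedup of about 1.3x, the B=64, M=128, K=128 shape gives under 0.5x (the optimized kernel is roughly twice as slow), and the row without a baseline is dropped.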
TODO:
- [x] get bf16 working
- [x] get non-causal working
- [x] benchmarking
- [x] specify when op can be used
- [ ] flash bwd, triton fwd
- [ ] packed
- [ ] different block sizes for N and M
- [ ] add Triton Autotune
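The last two open items (different block sizes for N and M, and autotuning) concern the tile shapes of the kernel. The loop below is a plain NumPy sketch of the underlying FlashAttention-style computation, not the Triton kernel from this PR. It tiles queries with `block_m` and keys/values with `block_n` independently, then combines the tiles with an online softmax. The function name, the default block sizes, and the mapping of M to query tiles and N to key tiles are illustrative assumptions.

```python
import numpy as np


def tiled_attention(q, k, v, block_m=64, block_n=32, causal=False):
    """Blockwise attention with an online softmax for one (batch, head) pair.

    q: [M, K] queries, k and v: [N, K]. block_m tiles the queries and block_n
    tiles the keys/values, so the two tile sizes are chosen independently.
    """
    m_len, d = q.shape
    n_len = k.shape[0]
    scale = 1.0 / np.sqrt(d)
    out = np.empty(q.shape, dtype=np.float64)
    for m0 in range(0, m_len, block_m):
        qi = q[m0:m0 + block_m].astype(np.float64)
        rows = qi.shape[0]
        row_max = np.full(rows, -np.inf)  # running max of the scores
        row_sum = np.zeros(rows)          # running softmax denominator
        acc = np.zeros((rows, d))         # running unnormalized output
        for n0 in range(0, n_len, block_n):
            if causal and n0 > m0 + rows - 1:
                break  # this key tile (and all later ones) lies above the diagonal
            s = qi @ k[n0:n0 + block_n].T * scale
            if causal:
                q_idx = np.arange(m0, m0 + rows)[:, None]
                k_idx = np.arange(n0, n0 + s.shape[1])[None, :]
                s = np.where(k_idx <= q_idx, s, -np.inf)
            new_max = np.maximum(row_max, s.max(axis=1))
            correction = np.exp(row_max - new_max)  # rescale earlier partial sums
            p = np.exp(s - new_max[:, None])
            row_sum = row_sum * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v[n0:n0 + block_n]
            row_max = new_max
        out[m0:m0 + rows] = acc / row_sum[:, None]
    return out


# Usage: compare against a dense softmax on a ragged length
# (100 is not a multiple of either block size).
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((100, 16)) for _ in range(3))
s = q @ k.T / np.sqrt(16)
p = np.exp(s - s.max(axis=1, keepdims=True))
dense = (p / p.sum(axis=1, keepdims=True)) @ v
print(np.allclose(tiled_attention(q, k, v, block_m=64, block_n=32), dense))
```

Because only one `block_m x block_n` score tile exists at a time, the tile sizes trade register and shared-memory pressure against parallelism, which is what an autotuner would search over on real hardware.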
Before submitting
- [ ] Did you have fun?
- Make sure you had fun coding 🙃
- [ ] Did you read the contributor guideline?
- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] N/A
- [ ] Did you make sure to update the docs?
- [ ] N/A
- [ ] Did you write any new necessary tests?
- [ ] N/A
- [ ] Did you update the changelog? (if needed)
- [ ] N/A
PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
hey @dianaml0 this is great, but you would need a modern Triton, right? I've had a branch up for that for a while; some of the API has changed and we need to adapt a lot of the layers. Raising that just in case you bumped into it.
Thanks @blefaudeux, that's a good point. I have the updated Triton locally and am still facing some errors, but the update is needed regardless. Do you think you'll push the changes you have, or should I look into it?
https://github.com/facebookresearch/xformers/pull/483 should help; it should accept any modern Triton pip package!
Updated Numbers
Performance Compared to Vanilla FWD
[--------- attention (attn_bias=<class 'NoneType'>) --------]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 1258.8 | 372.9
b16 B=384, M=197, H=1, K=88 | 1270.5 | 375.4
f16 B=384, M=197, H=1, K=80 | 146.8 | 344.5
b16 B=384, M=197, H=1, K=80 | 149.5 | 346.8
f16 B=384, M=197, H=1, K=64 | 90.1 | 293.4
b16 B=384, M=197, H=1, K=64 | 94.3 | 295.4
f16 B=1024, M=197, H=1, K=88 | 3241.4 | 938.6
b16 B=1024, M=197, H=1, K=88 | 3263.3 | 945.0
f16 B=1024, M=197, H=1, K=80 | 349.2 | 865.1
b16 B=1024, M=197, H=1, K=80 | 354.9 | 871.3
f16 B=1024, M=197, H=1, K=64 | 213.8 | 729.5
b16 B=1024, M=197, H=1, K=64 | 222.1 | 735.7
f16 B=512, M=197, H=1, K=80 | 185.6 | 448.6
b16 B=512, M=197, H=1, K=80 | 188.8 | 451.6
f16 B=32, M=197, H=16, K=80 | 193.8 | 548.3
b16 B=32, M=197, H=16, K=80 | 193.9 | 551.0
f16 B=32, M=197, H=16, K=64 | 114.4 | 466.7
b16 B=32, M=197, H=16, K=64 | 120.5 | 469.2
f16 B=32, M=197, H=16, K=128 | 193.9 | 717.6
b16 B=32, M=197, H=16, K=128 | 195.0 | 720.7
f16 B=256, M=197, H=1, K=88 | 868.3 | 260.9
b16 B=256, M=197, H=1, K=88 | 870.0 | 262.3
f16 B=16, M=197, H=16, K=88 | 869.9 | 317.9
b16 B=16, M=197, H=16, K=88 | 870.8 | 319.3
f16 B=16, M=197, H=16, K=64 | 89.0 | 257.0
b16 B=16, M=197, H=16, K=64 | 88.2 | 258.2
f16 B=16, M=197, H=16, K=128 | 101.6 | 385.5
b16 B=16, M=197, H=16, K=128 | 102.2 | 387.0
f16 B=1, M=4096, H=160, K=128 | 9824.7 | 20462.2
b16 B=1, M=4096, H=160, K=128 | 10063.9 | 21439.0
f16 B=2, M=4096, H=160, K=128 | 19368.3 | 42648.8
b16 B=2, M=4096, H=160, K=128 | 19877.8 | 44515.5
f16 B=1, M=8192, H=160, K=128 | 37335.5 | 88470.3
b16 B=1, M=8192, H=160, K=128 | 38634.4 | 87261.8
f16 B=2, M=8192, H=160, K=128 | 74047.2 |
b16 B=2, M=8192, H=160, K=128 | 76836.0 |
f16 B=1024, M=82, H=8, K=64 | 460.9 | 1769.1
b16 B=1024, M=82, H=8, K=64 | 480.2 | 1863.2
f16 B=150, M=256, H=16, K=64 | 383.1 | 1725.0
b16 B=150, M=256, H=16, K=64 | 421.2 | 1760.5
f16 B=64, M=256, H=12, K=64 | 131.6 | 587.0
b16 B=64, M=256, H=12, K=64 | 145.8 | 597.5
f16 B=1, M=4096, H=16, K=40 | 30691.3 | 1943.9
b16 B=1, M=4096, H=16, K=40 | 30831.1 | 1998.1
f16 B=1, M=16384, H=16, K=40 | 422956.0 | 28855.4
b16 B=1, M=16384, H=16, K=40 | 420924.0 | 30204.7
f16 B=256, M=4096, H=16, K=64 | 118507.9 |
b16 B=256, M=4096, H=16, K=64 | 132804.6 |
f16 B=8, M=2048, H=20, K=128 | 2697.1 | 5700.7
b16 B=8, M=2048, H=20, K=128 | 2744.9 | 6041.7
f16 B=16, M=128, H=16, K=16 | 87.3 | 139.1
b16 B=16, M=128, H=16, K=16 | 87.1 | 137.9
f16 B=16, M=128, H=16, K=32 | 89.6 | 139.4
b16 B=16, M=128, H=16, K=32 | 89.2 | 138.3
f16 B=16, M=128, H=16, K=64 | 86.7 | 137.1
b16 B=16, M=128, H=16, K=64 | 87.0 | 136.9
f16 B=16, M=128, H=16, K=128 | 88.7 | 139.6
b16 B=16, M=128, H=16, K=128 | 86.6 | 140.6
f16 B=16, M=512, H=16, K=16 | 87.6 | 461.5
b16 B=16, M=512, H=16, K=16 | 99.5 | 553.9
f16 B=16, M=512, H=16, K=32 | 105.4 | 512.3
b16 B=16, M=512, H=16, K=32 | 123.2 | 594.6
f16 B=16, M=512, H=16, K=64 | 152.1 | 596.1
b16 B=16, M=512, H=16, K=64 | 169.4 | 616.0
f16 B=16, M=512, H=16, K=128 | 318.6 | 786.9
b16 B=16, M=512, H=16, K=128 | 327.6 | 804.4
f16 B=16, M=1024, H=16, K=16 | 292.0 | 1642.8
b16 B=16, M=1024, H=16, K=16 | 384.8 | 2028.9
f16 B=16, M=1024, H=16, K=32 | 369.4 | 1731.9
b16 B=16, M=1024, H=16, K=32 | 428.1 | 2126.7
f16 B=16, M=1024, H=16, K=64 | 518.5 | 2034.7
b16 B=16, M=1024, H=16, K=64 | 579.7 | 2077.2
f16 B=16, M=1024, H=16, K=128 | 1097.8 | 2421.0
b16 B=16, M=1024, H=16, K=128 | 1134.9 | 2489.6
f16 B=64, M=128, H=16, K=16 | 89.3 | 183.7
b16 B=64, M=128, H=16, K=16 | 87.1 | 185.1
f16 B=64, M=128, H=16, K=32 | 87.5 | 224.8
b16 B=64, M=128, H=16, K=32 | 88.1 | 225.5
f16 B=64, M=128, H=16, K=64 | 86.7 | 321.5
b16 B=64, M=128, H=16, K=64 | 88.1 | 323.1
f16 B=64, M=128, H=16, K=128 | 127.1 | 484.8
b16 B=64, M=128, H=16, K=128 | 129.5 | 486.1
f16 B=64, M=512, H=16, K=16 | 307.5 | 1727.2
b16 B=64, M=512, H=16, K=16 | 403.6 | 2094.0
f16 B=64, M=512, H=16, K=32 | 387.3 | 1893.5
b16 B=64, M=512, H=16, K=32 | 450.5 | 2250.3
f16 B=64, M=512, H=16, K=64 | 558.2 | 2234.4
b16 B=64, M=512, H=16, K=64 | 620.5 | 2294.3
f16 B=64, M=512, H=16, K=128 | 1201.0 | 3016.3
b16 B=64, M=512, H=16, K=128 | 1235.3 | 3069.4
f16 B=64, M=1024, H=16, K=16 | 1101.7 | 6428.2
b16 B=64, M=1024, H=16, K=16 | 1473.2 | 8038.7
f16 B=64, M=1024, H=16, K=32 | 1411.1 | 6770.6
b16 B=64, M=1024, H=16, K=32 | 1667.5 | 8383.2
f16 B=64, M=1024, H=16, K=64 | 2033.4 | 7978.2
b16 B=64, M=1024, H=16, K=64 | 2265.9 | 8160.9
f16 B=64, M=1024, H=16, K=128 | 4281.6 | 9521.1
b16 B=64, M=1024, H=16, K=128 | 4413.3 | 9886.2
Times are in microseconds (us).
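The `vanilla` column is a straightforward attention implementation that materializes the full score matrix. Below is a minimal NumPy sketch of such a reference, assuming the row labels follow the [B, M, H, K] (batch, sequence length, heads, head dimension) layout of `xformers.ops.memory_efficient_attention`. `vanilla_attention` is a hypothetical name, and the benchmark's actual baseline may differ in details such as dtype handling. The `causal=True` path mirrors `LowerTriangularMask`, where query i may only attend to keys j <= i.

```python
import numpy as np


def vanilla_attention(q, k, v, causal=False):
    """Reference attention on [B, M, H, K] arrays (batch, seq len, heads, head dim).

    Materializes the full [B, H, M, N] score tensor, which is exactly what the
    memory-efficient kernels avoid.
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = np.einsum("bmhk,bnhk->bhmn", q, k) * scale
    if causal:
        m, n = scores.shape[-2], scores.shape[-1]
        # LowerTriangularMask semantics: query i attends only to keys j <= i.
        allowed = np.tril(np.ones((m, n), dtype=bool))
        scores = np.where(allowed, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.einsum("bhmn,bnhk->bmhk", probs, v)


# Usage on a shape from the tables (B=2 instead of 384 to keep it small).
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 197, 1, 64)) for _ in range(3))
print(vanilla_attention(q, k, v, causal=True).shape)  # (2, 197, 1, 64)
```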
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
| optimized | vanilla
1 threads: ---------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 3697.6 | 446.9
b16 B=384, M=197, H=1, K=88 | 1178.0 | 452.8
f16 B=384, M=197, H=1, K=80 | 120.3 | 421.8
b16 B=384, M=197, H=1, K=80 | 122.6 | 427.6
f16 B=384, M=197, H=1, K=64 | 87.6 | 370.9
b16 B=384, M=197, H=1, K=64 | 90.4 | 376.4
f16 B=1024, M=197, H=1, K=88 | 9588.0 | 1124.2
b16 B=1024, M=197, H=1, K=88 | 2983.1 | 1140.5
f16 B=1024, M=197, H=1, K=80 | 294.1 | 1059.7
b16 B=1024, M=197, H=1, K=80 | 300.2 | 1075.1
f16 B=1024, M=197, H=1, K=64 | 187.5 | 925.9
b16 B=1024, M=197, H=1, K=64 | 197.2 | 940.5
f16 B=512, M=197, H=1, K=80 | 153.7 | 549.6
b16 B=512, M=197, H=1, K=80 | 156.5 | 557.4
f16 B=32, M=197, H=16, K=80 | 157.5 | 647.7
b16 B=32, M=197, H=16, K=80 | 160.8 | 655.5
f16 B=32, M=197, H=16, K=64 | 103.8 | 565.7
b16 B=32, M=197, H=16, K=64 | 109.3 | 573.3
f16 B=32, M=197, H=16, K=128 | 157.6 | 811.7
b16 B=32, M=197, H=16, K=128 | 159.5 | 819.3
f16 B=256, M=197, H=1, K=88 | 2552.6 | 311.0
b16 B=256, M=197, H=1, K=88 | 812.5 | 315.1
f16 B=16, M=197, H=16, K=88 | 2555.6 | 368.1
b16 B=16, M=197, H=16, K=88 | 805.8 | 372.4
f16 B=16, M=197, H=16, K=64 | 87.2 | 307.2
b16 B=16, M=197, H=16, K=64 | 88.9 | 311.0
f16 B=16, M=197, H=16, K=128 | 87.7 | 435.1
b16 B=16, M=197, H=16, K=128 | 87.1 | 438.6
f16 B=1, M=4096, H=160, K=128 | 5174.9 | 37307.0
b16 B=1, M=4096, H=160, K=128 | 5338.7 | 37824.6
f16 B=2, M=4096, H=160, K=128 | 10220.4 | 76294.3
b16 B=2, M=4096, H=160, K=128 | 10545.0 | 77265.7
f16 B=1, M=8192, H=160, K=128 | 19334.4 | 152522.4
b16 B=1, M=8192, H=160, K=128 | 19974.0 | 148646.9
f16 B=2, M=8192, H=160, K=128 | 38377.2 |
b16 B=2, M=8192, H=160, K=128 | 39641.7 |
f16 B=1024, M=82, H=8, K=64 | 488.3 | 1981.4
b16 B=1024, M=82, H=8, K=64 | 515.8 | 2083.8
f16 B=150, M=256, H=16, K=64 | 335.1 | 2402.9
b16 B=150, M=256, H=16, K=64 | 350.9 | 2445.6
f16 B=64, M=256, H=12, K=64 | 118.9 | 805.2
b16 B=64, M=256, H=12, K=64 | 124.6 | 819.3
f16 B=1, M=4096, H=16, K=40 | 14939.0 | 3432.0
b16 B=1, M=4096, H=16, K=40 | 14915.2 | 3468.0
f16 B=1, M=16384, H=16, K=40 | 222056.7 | 55382.5
b16 B=1, M=16384, H=16, K=40 | 218015.9 | 55427.4
f16 B=256, M=4096, H=16, K=64 | 67733.2 |
b16 B=256, M=4096, H=16, K=64 | 74031.3 |
f16 B=8, M=2048, H=20, K=128 | 1488.1 | 9349.8
b16 B=8, M=2048, H=20, K=128 | 1534.1 | 9538.5
f16 B=16, M=128, H=16, K=16 | 86.9 | 147.4
b16 B=16, M=128, H=16, K=16 | 87.9 | 143.6
f16 B=16, M=128, H=16, K=32 | 86.6 | 144.3
b16 B=16, M=128, H=16, K=32 | 89.8 | 143.9
f16 B=16, M=128, H=16, K=64 | 86.4 | 143.1
b16 B=16, M=128, H=16, K=64 | 86.5 | 141.4
f16 B=16, M=128, H=16, K=128 | 89.7 | 160.0
b16 B=16, M=128, H=16, K=128 | 88.3 | 161.9
f16 B=16, M=512, H=16, K=16 | 90.7 | 715.3
b16 B=16, M=512, H=16, K=16 | 87.6 | 792.0
f16 B=16, M=512, H=16, K=32 | 88.5 | 766.8
b16 B=16, M=512, H=16, K=32 | 98.1 | 832.3
f16 B=16, M=512, H=16, K=64 | 120.3 | 881.4
b16 B=16, M=512, H=16, K=64 | 129.1 | 910.7
f16 B=16, M=512, H=16, K=128 | 225.4 | 1066.2
b16 B=16, M=512, H=16, K=128 | 232.0 | 1095.8
f16 B=16, M=1024, H=16, K=16 | 196.6 | 2658.4
b16 B=16, M=1024, H=16, K=16 | 253.0 | 3042.3
f16 B=16, M=1024, H=16, K=32 | 244.2 | 2749.6
b16 B=16, M=1024, H=16, K=32 | 291.8 | 3123.0
f16 B=16, M=1024, H=16, K=64 | 355.3 | 2962.7
b16 B=16, M=1024, H=16, K=64 | 384.7 | 3287.6
f16 B=16, M=1024, H=16, K=128 | 683.0 | 3533.6
b16 B=16, M=1024, H=16, K=128 | 705.0 | 3714.0
f16 B=64, M=128, H=16, K=16 | 87.6 | 246.0
b16 B=64, M=128, H=16, K=16 | 87.2 | 251.8
f16 B=64, M=128, H=16, K=32 | 87.1 | 294.3
b16 B=64, M=128, H=16, K=32 | 87.5 | 298.8
f16 B=64, M=128, H=16, K=64 | 87.8 | 391.4
b16 B=64, M=128, H=16, K=64 | 86.8 | 396.0
f16 B=64, M=128, H=16, K=128 | 130.9 | 557.1
b16 B=64, M=128, H=16, K=128 | 133.1 | 562.3
f16 B=64, M=512, H=16, K=16 | 227.5 | 2704.1
b16 B=64, M=512, H=16, K=16 | 288.9 | 3019.5
f16 B=64, M=512, H=16, K=32 | 285.1 | 2892.2
b16 B=64, M=512, H=16, K=32 | 336.9 | 3163.7
f16 B=64, M=512, H=16, K=64 | 414.2 | 3352.4
b16 B=64, M=512, H=16, K=64 | 444.3 | 3470.6
f16 B=64, M=512, H=16, K=128 | 828.5 | 4093.5
b16 B=64, M=512, H=16, K=128 | 854.7 | 4208.4
f16 B=64, M=1024, H=16, K=16 | 721.9 | 10475.9
b16 B=64, M=1024, H=16, K=16 | 932.5 | 12026.3
f16 B=64, M=1024, H=16, K=32 | 900.6 | 10855.0
b16 B=64, M=1024, H=16, K=32 | 1075.5 | 12354.0
f16 B=64, M=1024, H=16, K=64 | 1314.2 | 11673.0
b16 B=64, M=1024, H=16, K=64 | 1423.5 | 13000.3
f16 B=64, M=1024, H=16, K=128 | 2608.3 | 13961.5
b16 B=64, M=1024, H=16, K=128 | 2693.0 | 14694.5
Times are in microseconds (us).
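Compared with the table above, the optimized kernel runs noticeably faster under `LowerTriangularMask` for long sequences: for f16 B=1, M=4096, H=160, K=128 the time drops from 9824.7 us to 5174.9 us, a ratio of about 0.53. One plausible explanation, which the PR does not state, is that a blockwise kernel can skip key tiles lying entirely above the diagonal. The sketch below counts the (query tile, key tile) pairs visited under that assumption; the helper name and the block size of 128 are hypothetical.

```python
def tile_pairs(seq_len, block_m=128, block_n=128, causal=False):
    """Count (query tile, key tile) pairs a blockwise attention kernel visits.

    With a causal mask, key tiles whose first key comes after the last query of
    the current query tile are fully masked and can be skipped.
    """
    total = 0
    for m0 in range(0, seq_len, block_m):
        last_q = min(m0 + block_m, seq_len) - 1
        for n0 in range(0, seq_len, block_n):
            if causal and n0 > last_q:
                break  # this tile and all later ones are above the diagonal
            total += 1
    return total


full = tile_pairs(4096)
causal = tile_pairs(4096, causal=True)
print(full, causal, round(causal / full, 3))
```

For M=4096 this gives 1024 tile pairs without a mask and 528 with the causal mask, a ratio of about 0.52, close to the measured 0.53.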
Performance Compared to Vanilla BWD
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 3679.0 | 820.0
b16 B=384, M=197, H=1, K=88 | 3560.9 | 823.2
f16 B=384, M=197, H=1, K=80 | 762.1 | 767.0
b16 B=384, M=197, H=1, K=80 | 762.6 | 768.6
f16 B=384, M=197, H=1, K=64 | 543.5 | 651.8
b16 B=384, M=197, H=1, K=64 | 372.8 | 652.8
f16 B=1024, M=197, H=1, K=88 | 9127.6 | 2103.8
b16 B=1024, M=197, H=1, K=88 | 9123.4 | 2104.3
f16 B=1024, M=197, H=1, K=80 | 1881.4 | 1957.8
b16 B=1024, M=197, H=1, K=80 | 1889.3 | 1957.7
f16 B=1024, M=197, H=1, K=64 | 916.7 | 1648.4
b16 B=1024, M=197, H=1, K=64 | 924.1 | 1649.6
f16 B=512, M=197, H=1, K=80 | 987.0 | 993.8
b16 B=512, M=197, H=1, K=80 | 991.6 | 995.8
f16 B=32, M=197, H=16, K=80 | 1038.9 | 1032.6
b16 B=32, M=197, H=16, K=80 | 1040.7 | 1035.2
f16 B=32, M=197, H=16, K=64 | 485.6 | 886.1
b16 B=32, M=197, H=16, K=64 | 487.9 | 887.0
f16 B=32, M=197, H=16, K=128 | 1165.6 | 1341.5
b16 B=32, M=197, H=16, K=128 | 1168.8 | 1344.7
f16 B=256, M=197, H=1, K=88 | 2364.1 | 576.1
b16 B=256, M=197, H=1, K=88 | 2363.5 | 577.3
f16 B=16, M=197, H=16, K=88 | 2378.9 | 599.2
b16 B=16, M=197, H=16, K=88 | 2377.1 | 600.3
f16 B=16, M=197, H=16, K=64 | 355.0 | 486.9
b16 B=16, M=197, H=16, K=64 | 309.7 | 488.6
f16 B=16, M=197, H=16, K=128 | 617.3 | 713.9
b16 B=16, M=197, H=16, K=128 | 620.4 | 715.8
f16 B=1, M=4096, H=160, K=128 | 41943.7 | 38971.0
b16 B=1, M=4096, H=160, K=128 | 42033.0 | 39938.5
f16 B=2, M=4096, H=160, K=128 | 83588.1 | 79017.0
b16 B=2, M=4096, H=160, K=128 | 83792.2 | 80856.7
f16 B=1, M=8192, H=160, K=128 | 160571.8 |
b16 B=1, M=8192, H=160, K=128 | 160843.2 |
f16 B=2, M=8192, H=160, K=128 | 320831.4 |
b16 B=2, M=8192, H=160, K=128 | 322116.0 |
f16 B=1024, M=82, H=8, K=64 | 2629.7 | 3609.1
b16 B=1024, M=82, H=8, K=64 | 2524.3 | 3782.4
f16 B=150, M=256, H=16, K=64 | 2577.6 | 3790.4
b16 B=150, M=256, H=16, K=64 | 2458.9 | 3837.4
f16 B=64, M=256, H=12, K=64 | 827.1 | 1254.5
b16 B=64, M=256, H=12, K=64 | 835.7 | 1267.6
f16 B=1, M=4096, H=16, K=40 | 43086.4 | 3562.3
b16 B=1, M=4096, H=16, K=40 | 43079.7 | 3588.6
f16 B=1, M=16384, H=16, K=40 | 664389.3 | 53939.0
b16 B=1, M=16384, H=16, K=40 | 663918.7 | 54291.3
f16 B=256, M=4096, H=16, K=64 | 504278.4 |
b16 B=256, M=4096, H=16, K=64 | 503469.4 |
f16 B=8, M=2048, H=20, K=128 | 11424.6 | 10535.0
b16 B=8, M=2048, H=20, K=128 | 11365.6 | 10784.8
f16 B=16, M=128, H=16, K=16 | 351.2 | 348.4
b16 B=16, M=128, H=16, K=16 | 329.6 | 346.4
f16 B=16, M=128, H=16, K=32 | 351.3 | 343.9
b16 B=16, M=128, H=16, K=32 | 321.3 | 319.8
f16 B=16, M=128, H=16, K=64 | 329.9 | 316.2
b16 B=16, M=128, H=16, K=64 | 329.8 | 340.4
f16 B=16, M=128, H=16, K=128 | 453.7 | 338.1
b16 B=16, M=128, H=16, K=128 | 332.6 | 337.6
f16 B=16, M=512, H=16, K=16 | 526.4 | 985.4
b16 B=16, M=512, H=16, K=16 | 322.8 | 1078.1
f16 B=16, M=512, H=16, K=32 | 683.4 | 1089.3
b16 B=16, M=512, H=16, K=32 | 486.9 | 1178.8
f16 B=16, M=512, H=16, K=64 | 1020.1 | 1276.7
b16 B=16, M=512, H=16, K=64 | 845.3 | 1296.3
f16 B=16, M=512, H=16, K=128 | 1900.1 | 1715.2
b16 B=16, M=512, H=16, K=128 | 1766.8 | 1742.1
f16 B=16, M=1024, H=16, K=16 | 1260.6 | 3525.5
b16 B=16, M=1024, H=16, K=16 | 1067.4 | 3953.2
f16 B=16, M=1024, H=16, K=32 | 1794.2 | 3728.4
b16 B=16, M=1024, H=16, K=32 | 1615.1 | 4153.2
f16 B=16, M=1024, H=16, K=64 | 2685.4 | 4278.5
b16 B=16, M=1024, H=16, K=64 | 2545.9 | 4349.5
f16 B=16, M=1024, H=16, K=128 | 5420.6 | 5108.0
b16 B=16, M=1024, H=16, K=128 | 5317.9 | 5240.9
f16 B=64, M=128, H=16, K=16 | 318.0 | 365.1
b16 B=64, M=128, H=16, K=16 | 333.2 | 373.0
f16 B=64, M=128, H=16, K=32 | 332.6 | 466.7
b16 B=64, M=128, H=16, K=32 | 334.3 | 470.8
f16 B=64, M=128, H=16, K=64 | 464.2 | 680.0
b16 B=64, M=128, H=16, K=64 | 469.7 | 681.6
f16 B=64, M=128, H=16, K=128 | 980.2 | 1076.1
b16 B=64, M=128, H=16, K=128 | 986.8 | 1078.6
f16 B=64, M=512, H=16, K=16 | 1109.1 | 3706.3
b16 B=64, M=512, H=16, K=16 | 1112.0 | 4077.9
f16 B=64, M=512, H=16, K=32 | 1679.8 | 4137.0
b16 B=64, M=512, H=16, K=32 | 1684.1 | 4507.5
f16 B=64, M=512, H=16, K=64 | 2988.8 | 4903.0
b16 B=64, M=512, H=16, K=64 | 3001.1 | 4991.4
f16 B=64, M=512, H=16, K=128 | 6639.6 | 6643.4
b16 B=64, M=512, H=16, K=128 | 6698.5 | 6749.8
f16 B=64, M=1024, H=16, K=16 | 3618.2 | 13888.1
b16 B=64, M=1024, H=16, K=16 | 3625.5 | 15564.5
f16 B=64, M=1024, H=16, K=32 | 6140.2 | 14750.0
b16 B=64, M=1024, H=16, K=32 | 6156.8 | 16452.7
f16 B=64, M=1024, H=16, K=64 | 9818.7 | 16864.4
b16 B=64, M=1024, H=16, K=64 | 9843.5 | 17158.3
f16 B=64, M=1024, H=16, K=128 | 20554.6 | 20403.1
b16 B=64, M=1024, H=16, K=128 | 20676.6 | 20899.8
Times are in microseconds (us).
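Some `vanilla` cells are empty: f16 B=2, M=8192, H=160, K=128 in the forward tables, and already B=1, M=8192 in the backward table above. The PR does not say why; a likely cause is running out of GPU memory, since a vanilla implementation materializes a [B, H, M, M] score tensor, and the backward pass needs further tensors of that size. The arithmetic below is a sketch that assumes 2 bytes per fp16/bf16 element and counts only a single such tensor; `score_matrix_bytes` is a hypothetical helper.

```python
def score_matrix_bytes(b, m, h, bytes_per_elem=2):
    """Size in bytes of one [B, H, M, M] attention-score tensor (fp16/bf16 = 2 bytes)."""
    return b * h * m * m * bytes_per_elem


GiB = 1024 ** 3
for b, m, h in [(1, 8192, 160), (2, 8192, 160), (256, 4096, 16)]:
    print(f"B={b}, M={m}, H={h}: {score_matrix_bytes(b, m, h) / GiB:.0f} GiB")
```

This gives 20 GiB for B=1, M=8192, H=160, 40 GiB for B=2, and 128 GiB for B=256, M=4096, H=16, per tensor, whereas the blockwise kernels only hold one tile of scores at a time.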
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=88 | 9482.2 | 821.1
b16 B=384, M=197, H=1, K=88 | 9334.4 | 823.2
f16 B=384, M=197, H=1, K=80 | 669.6 | 768.4
b16 B=384, M=197, H=1, K=80 | 672.2 | 769.8
f16 B=384, M=197, H=1, K=64 | 544.9 | 651.9
b16 B=384, M=197, H=1, K=64 | 351.0 | 653.7
f16 B=1024, M=197, H=1, K=88 | 23569.2 | 2103.4
b16 B=1024, M=197, H=1, K=88 | 23571.4 | 2105.2
f16 B=1024, M=197, H=1, K=80 | 1644.3 | 1957.5
b16 B=1024, M=197, H=1, K=80 | 1649.9 | 1957.4
f16 B=1024, M=197, H=1, K=64 | 862.7 | 1648.0
b16 B=1024, M=197, H=1, K=64 | 868.7 | 1649.4
f16 B=512, M=197, H=1, K=80 | 863.2 | 993.9
b16 B=512, M=197, H=1, K=80 | 866.3 | 996.2
f16 B=32, M=197, H=16, K=80 | 867.2 | 1035.0
b16 B=32, M=197, H=16, K=80 | 869.5 | 1035.5
f16 B=32, M=197, H=16, K=64 | 456.5 | 884.3
b16 B=32, M=197, H=16, K=64 | 458.6 | 885.4
f16 B=32, M=197, H=16, K=128 | 1042.2 | 1339.5
b16 B=32, M=197, H=16, K=128 | 1046.8 | 1343.0
f16 B=256, M=197, H=1, K=88 | 6343.3 | 575.9
b16 B=256, M=197, H=1, K=88 | 6346.2 | 578.0
f16 B=16, M=197, H=16, K=88 | 6339.5 | 600.1
b16 B=16, M=197, H=16, K=88 | 6338.8 | 601.4
f16 B=16, M=197, H=16, K=64 | 323.2 | 488.7
b16 B=16, M=197, H=16, K=64 | 309.0 | 491.1
f16 B=16, M=197, H=16, K=128 | 555.8 | 715.2
b16 B=16, M=197, H=16, K=128 | 557.6 | 717.0
f16 B=1, M=4096, H=160, K=128 | 26198.1 | 38857.2
b16 B=1, M=4096, H=160, K=128 | 26165.0 | 39906.6
f16 B=2, M=4096, H=160, K=128 | 51727.5 | 78346.7
b16 B=2, M=4096, H=160, K=128 | 51935.9 | 80796.5
f16 B=1, M=8192, H=160, K=128 | 92706.3 |
b16 B=1, M=8192, H=160, K=128 | 92957.2 |
f16 B=2, M=8192, H=160, K=128 | 184853.7 |
b16 B=2, M=8192, H=160, K=128 | 185265.1 |
f16 B=1024, M=82, H=8, K=64 | 2746.8 | 3611.4
b16 B=1024, M=82, H=8, K=64 | 2641.2 | 3784.6
f16 B=150, M=256, H=16, K=64 | 2381.3 | 3824.6
b16 B=150, M=256, H=16, K=64 | 2291.5 | 3863.9
f16 B=64, M=256, H=12, K=64 | 776.2 | 1254.8
b16 B=64, M=256, H=12, K=64 | 780.0 | 1269.1
f16 B=1, M=4096, H=16, K=40 | 6826.9 | 3563.2
b16 B=1, M=4096, H=16, K=40 | 6622.5 | 3596.4
f16 B=1, M=16384, H=16, K=40 | 94566.7 | 53887.6
b16 B=1, M=16384, H=16, K=40 | 94648.8 | 54316.2
f16 B=256, M=4096, H=16, K=64 | 274458.4 |
b16 B=256, M=4096, H=16, K=64 | 275062.8 |
f16 B=8, M=2048, H=20, K=128 | 7871.9 | 10544.8
b16 B=8, M=2048, H=20, K=128 | 7770.2 | 10794.5
f16 B=16, M=128, H=16, K=16 | 348.7 | 323.2
b16 B=16, M=128, H=16, K=16 | 326.5 | 325.8
f16 B=16, M=128, H=16, K=32 | 350.7 | 323.9
b16 B=16, M=128, H=16, K=32 | 305.9 | 300.9
f16 B=16, M=128, H=16, K=64 | 361.3 | 328.4
b16 B=16, M=128, H=16, K=64 | 309.5 | 296.7
f16 B=16, M=128, H=16, K=128 | 487.7 | 321.4
b16 B=16, M=128, H=16, K=128 | 329.5 | 305.0
f16 B=16, M=512, H=16, K=16 | 435.9 | 986.1
b16 B=16, M=512, H=16, K=16 | 333.1 | 1079.5
f16 B=16, M=512, H=16, K=32 | 589.3 | 1090.4
b16 B=16, M=512, H=16, K=32 | 394.0 | 1178.8
f16 B=16, M=512, H=16, K=64 | 846.2 | 1280.3
b16 B=16, M=512, H=16, K=64 | 670.0 | 1298.3
f16 B=16, M=512, H=16, K=128 | 1616.0 | 1712.2
b16 B=16, M=512, H=16, K=128 | 1485.7 | 1740.0
f16 B=16, M=1024, H=16, K=16 | 881.3 | 3527.4
b16 B=16, M=1024, H=16, K=16 | 686.9 | 3948.4
f16 B=16, M=1024, H=16, K=32 | 1241.7 | 3729.6
b16 B=16, M=1024, H=16, K=32 | 1059.4 | 4152.4
f16 B=16, M=1024, H=16, K=64 | 1858.0 | 4284.3
b16 B=16, M=1024, H=16, K=64 | 1752.9 | 4348.1
f16 B=16, M=1024, H=16, K=128 | 4083.5 | 5106.9
b16 B=16, M=1024, H=16, K=128 | 3969.3 | 5246.5
f16 B=64, M=128, H=16, K=16 | 306.6 | 368.1
b16 B=64, M=128, H=16, K=16 | 333.1 | 376.4
f16 B=64, M=128, H=16, K=32 | 327.0 | 466.3
b16 B=64, M=128, H=16, K=32 | 329.4 | 472.3
f16 B=64, M=128, H=16, K=64 | 475.4 | 680.6
b16 B=64, M=128, H=16, K=64 | 479.8 | 680.9
f16 B=64, M=128, H=16, K=128 | 1041.1 | 1071.8
b16 B=64, M=128, H=16, K=128 | 1046.3 | 1077.4
f16 B=64, M=512, H=16, K=16 | 914.1 | 3704.1
b16 B=64, M=512, H=16, K=16 | 916.3 | 4083.0
f16 B=64, M=512, H=16, K=32 | 1458.5 | 4134.4
b16 B=64, M=512, H=16, K=32 | 1461.1 | 4502.1
f16 B=64, M=512, H=16, K=64 | 2464.9 | 4900.6
b16 B=64, M=512, H=16, K=64 | 2490.6 | 4991.0
f16 B=64, M=512, H=16, K=128 | 5549.3 | 6642.1
b16 B=64, M=512, H=16, K=128 | 5594.7 | 6753.4
f16 B=64, M=1024, H=16, K=16 | 2605.9 | 13864.0
b16 B=64, M=1024, H=16, K=16 | 2613.2 | 15564.9
f16 B=64, M=1024, H=16, K=32 | 4026.7 | 14759.8
b16 B=64, M=1024, H=16, K=32 | 4029.4 | 16449.0
f16 B=64, M=1024, H=16, K=64 | 6661.4 | 16899.4
b16 B=64, M=1024, H=16, K=64 | 6721.7 | 17166.3
f16 B=64, M=1024, H=16, K=128 | 15343.3 | 20357.1
b16 B=64, M=1024, H=16, K=128 | 15446.7 | 20917.8
Times are in microseconds (us).
Thanks a lot for the reviews, @fmassa and @danthe3rd! I've made some changes and added an op that uses the Triton forward with the Flash backward.
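For a Triton forward to be paired with a Flash backward, the forward has to save whatever the backward consumes. In FlashAttention-style kernels this is typically the output plus the row-wise log-sum-exp (LSE) of the scaled scores, from which the backward recomputes the softmax instead of storing it. The NumPy sketch below illustrates that contract for a single head; the function names and the exact (out, lse) interface are assumptions, not the API of the op added in this PR.

```python
import numpy as np


def attention_forward(q, k, v):
    """One-head forward pass; returns the output and the row-wise LSE of the scores."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    s = q @ k.T * scale
    row_max = s.max(axis=1, keepdims=True)
    lse = (row_max + np.log(np.exp(s - row_max).sum(axis=1, keepdims=True)))[:, 0]
    out = np.exp(s - lse[:, None]) @ v  # exp(s - lse) is the softmax
    return out, lse


def attention_backward(q, k, v, out, lse, d_out):
    """Backward pass that recomputes probabilities from the saved LSE."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    p = np.exp(q @ k.T * scale - lse[:, None])  # softmax, recomputed
    d_v = p.T @ d_out
    d_p = d_out @ v.T
    delta = (d_out * out).sum(axis=1, keepdims=True)  # equals rowsum(p * d_p)
    d_s = p * (d_p - delta)  # softmax Jacobian applied row by row
    d_q = d_s @ k * scale
    d_k = d_s.T @ q * scale
    return d_q, d_k, d_v


# Usage: any forward that produces a matching (out, lse) pair could feed this backward.
rng = np.random.default_rng(1)
q, k, v, d_out = (rng.standard_normal((5, 4)) for _ in range(4))
out, lse = attention_forward(q, k, v)
d_q, d_k, d_v = attention_backward(q, k, v, out, lse, d_out)
```

The only state crossing the forward/backward boundary is one vector of length M per head, which is what makes mixing implementations feasible, provided both agree on the scaling and on how the LSE is stored.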
Forward pass for the Triton fwd + Flash bwd op:
[--------- attention (attn_bias=<class 'NoneType'>) --------]
| optimized | eager
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 91.7 | 292.3
b16 B=384, M=197, H=1, K=64 | 94.2 | 295.4
f16 B=1024, M=197, H=1, K=64 | 213.1 | 727.8
b16 B=1024, M=197, H=1, K=64 | 221.4 | 734.1
f16 B=32, M=197, H=16, K=64 | 114.2 | 466.1
b16 B=32, M=197, H=16, K=64 | 120.6 | 468.4
f16 B=32, M=197, H=16, K=128 | 194.0 | 716.8
b16 B=32, M=197, H=16, K=128 | 195.4 | 719.7
f16 B=16, M=197, H=16, K=64 | 91.5 | 256.3
b16 B=16, M=197, H=16, K=64 | 90.9 | 256.9
f16 B=16, M=197, H=16, K=128 | 102.4 | 384.7
b16 B=16, M=197, H=16, K=128 | 104.6 | 386.0
f16 B=1, M=4096, H=160, K=128 | 9785.7 | 20529.6
b16 B=1, M=4096, H=160, K=128 | 10020.5 | 21512.5
f16 B=2, M=4096, H=160, K=128 | 19305.3 | 42686.7
b16 B=2, M=4096, H=160, K=128 | 19817.6 | 44503.4
f16 B=1, M=8192, H=160, K=128 | 37278.0 | 88499.1
b16 B=1, M=8192, H=160, K=128 | 38494.5 | 87043.0
f16 B=2, M=8192, H=160, K=128 | 73829.8 |
b16 B=2, M=8192, H=160, K=128 | 76805.0 |
f16 B=1024, M=82, H=8, K=64 | 460.8 | 1769.1
b16 B=1024, M=82, H=8, K=64 | 478.7 | 1864.2
f16 B=150, M=256, H=16, K=64 | 383.5 | 1727.0
b16 B=150, M=256, H=16, K=64 | 420.7 | 1761.2
f16 B=64, M=256, H=12, K=64 | 131.7 | 588.1
b16 B=64, M=256, H=12, K=64 | 145.6 | 598.2
f16 B=256, M=4096, H=16, K=64 | 118090.6 |
b16 B=256, M=4096, H=16, K=64 | 132264.8 |
f16 B=8, M=2048, H=20, K=128 | 2692.1 | 5746.4
b16 B=8, M=2048, H=20, K=128 | 2743.3 | 6049.4
f16 B=16, M=128, H=16, K=16 | 90.0 | 145.5
b16 B=16, M=128, H=16, K=16 | 92.8 | 143.1
f16 B=16, M=128, H=16, K=32 | 93.7 | 147.5
b16 B=16, M=128, H=16, K=32 | 90.5 | 145.4
f16 B=16, M=128, H=16, K=64 | 90.4 | 146.0
b16 B=16, M=128, H=16, K=64 | 92.5 | 142.1
f16 B=16, M=128, H=16, K=128 | 91.0 | 142.3
b16 B=16, M=128, H=16, K=128 | 90.2 | 143.1
f16 B=16, M=512, H=16, K=16 | 94.1 | 462.3
b16 B=16, M=512, H=16, K=16 | 100.9 | 553.7
f16 B=16, M=512, H=16, K=32 | 105.4 | 513.1
b16 B=16, M=512, H=16, K=32 | 122.9 | 595.7
f16 B=16, M=512, H=16, K=64 | 152.1 | 596.8
b16 B=16, M=512, H=16, K=64 | 169.2 | 613.5
f16 B=16, M=512, H=16, K=128 | 319.0 | 788.5
b16 B=16, M=512, H=16, K=128 | 327.9 | 805.1
f16 B=16, M=1024, H=16, K=16 | 291.8 | 1644.3
b16 B=16, M=1024, H=16, K=16 | 384.1 | 2026.1
f16 B=16, M=1024, H=16, K=32 | 369.0 | 1734.6
b16 B=16, M=1024, H=16, K=32 | 426.7 | 2128.1
f16 B=16, M=1024, H=16, K=64 | 521.1 | 2035.3
b16 B=16, M=1024, H=16, K=64 | 578.1 | 2079.1
f16 B=16, M=1024, H=16, K=128 | 1093.8 | 2420.7
b16 B=16, M=1024, H=16, K=128 | 1131.1 | 2491.4
f16 B=64, M=128, H=16, K=16 | 90.9 | 182.8
b16 B=64, M=128, H=16, K=16 | 90.3 | 184.6
f16 B=64, M=128, H=16, K=32 | 90.6 | 225.4
b16 B=64, M=128, H=16, K=32 | 90.3 | 226.7
f16 B=64, M=128, H=16, K=64 | 90.9 | 324.6
b16 B=64, M=128, H=16, K=64 | 89.9 | 326.3
f16 B=64, M=128, H=16, K=128 | 128.4 | 485.0
b16 B=64, M=128, H=16, K=128 | 130.8 | 486.5
f16 B=64, M=512, H=16, K=16 | 305.7 | 1729.1
b16 B=64, M=512, H=16, K=16 | 401.8 | 2094.5
f16 B=64, M=512, H=16, K=32 | 386.0 | 1899.7
b16 B=64, M=512, H=16, K=32 | 448.9 | 2252.9
f16 B=64, M=512, H=16, K=64 | 561.1 | 2246.8
b16 B=64, M=512, H=16, K=64 | 618.4 | 2302.4
f16 B=64, M=512, H=16, K=128 | 1195.5 | 3013.4
b16 B=64, M=512, H=16, K=128 | 1228.0 | 3069.2
f16 B=64, M=1024, H=16, K=16 | 1097.8 | 6432.0
b16 B=64, M=1024, H=16, K=16 | 1468.3 | 8046.7
f16 B=64, M=1024, H=16, K=32 | 1405.7 | 6777.6
b16 B=64, M=1024, H=16, K=32 | 1662.4 | 8426.7
f16 B=64, M=1024, H=16, K=64 | 2035.0 | 7999.1
b16 B=64, M=1024, H=16, K=64 | 2255.6 | 8174.9
f16 B=64, M=1024, H=16, K=128 | 4269.3 | 9546.4
b16 B=64, M=1024, H=16, K=128 | 4404.0 | 9904.9
Times are in microseconds (us).
[ attention (attn_bias=<class 'xformers.ops.memory_efficient_attention.LowerTriangularMask'>) ]
| optimized | eager
1 threads: ---------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 92.4 | 369.8
b16 B=384, M=197, H=1, K=64 | 94.6 | 376.0
f16 B=1024, M=197, H=1, K=64 | 186.6 | 923.4
b16 B=1024, M=197, H=1, K=64 | 196.6 | 938.1
f16 B=32, M=197, H=16, K=64 | 104.4 | 564.2
b16 B=32, M=197, H=16, K=64 | 108.8 | 572.2
f16 B=32, M=197, H=16, K=128 | 157.9 | 812.1
b16 B=32, M=197, H=16, K=128 | 159.5 | 818.4
f16 B=16, M=197, H=16, K=64 | 93.2 | 306.1
b16 B=16, M=197, H=16, K=64 | 90.6 | 310.1
f16 B=16, M=197, H=16, K=128 | 92.2 | 433.2
b16 B=16, M=197, H=16, K=128 | 91.6 | 437.6
f16 B=1, M=4096, H=160, K=128 | 5166.1 | 37295.6
b16 B=1, M=4096, H=160, K=128 | 5326.0 | 37828.6
f16 B=2, M=4096, H=160, K=128 | 10191.6 | 76206.0
b16 B=2, M=4096, H=160, K=128 | 10507.9 | 77209.3
f16 B=1, M=8192, H=160, K=128 | 19285.9 | 152334.7
b16 B=1, M=8192, H=160, K=128 | 19893.0 | 148896.7
f16 B=2, M=8192, H=160, K=128 | 38335.1 |
b16 B=2, M=8192, H=160, K=128 | 39490.4 |
f16 B=1024, M=82, H=8, K=64 | 486.3 | 1981.2
b16 B=1024, M=82, H=8, K=64 | 514.4 | 2086.7
f16 B=150, M=256, H=16, K=64 | 334.3 | 2402.0
b16 B=150, M=256, H=16, K=64 | 350.1 | 2447.2
f16 B=64, M=256, H=12, K=64 | 118.7 | 806.5
b16 B=64, M=256, H=12, K=64 | 124.9 | 820.5
f16 B=256, M=4096, H=16, K=64 | 67692.5 |
b16 B=256, M=4096, H=16, K=64 | 73779.3 |
f16 B=8, M=2048, H=20, K=128 | 1487.2 | 9354.3
b16 B=8, M=2048, H=20, K=128 | 1529.4 | 9535.1
f16 B=16, M=128, H=16, K=16 | 94.6 | 153.5
b16 B=16, M=128, H=16, K=16 | 93.1 | 150.6
f16 B=16, M=128, H=16, K=32 | 91.0 | 154.5
b16 B=16, M=128, H=16, K=32 | 94.2 | 154.0
f16 B=16, M=128, H=16, K=64 | 93.7 | 149.8
b16 B=16, M=128, H=16, K=64 | 93.2 | 152.3
f16 B=16, M=128, H=16, K=128 | 92.0 | 160.4
b16 B=16, M=128, H=16, K=128 | 91.1 | 162.3
f16 B=16, M=512, H=16, K=16 | 93.0 | 715.4
b16 B=16, M=512, H=16, K=16 | 93.9 | 792.7
f16 B=16, M=512, H=16, K=32 | 91.0 | 767.7
b16 B=16, M=512, H=16, K=32 | 98.1 | 833.7
f16 B=16, M=512, H=16, K=64 | 120.8 | 882.7
b16 B=16, M=512, H=16, K=64 | 129.8 | 911.8
f16 B=16, M=512, H=16, K=128 | 225.5 | 1067.1
b16 B=16, M=512, H=16, K=128 | 231.7 | 1096.1
f16 B=16, M=1024, H=16, K=16 | 195.9 | 2658.6
b16 B=16, M=1024, H=16, K=16 | 252.9 | 3043.7
f16 B=16, M=1024, H=16, K=32 | 243.7 | 2752.9
b16 B=16, M=1024, H=16, K=32 | 291.1 | 3121.6
f16 B=16, M=1024, H=16, K=64 | 355.5 | 2966.1
b16 B=16, M=1024, H=16, K=64 | 384.2 | 3287.9
f16 B=16, M=1024, H=16, K=128 | 682.0 | 3533.0
b16 B=16, M=1024, H=16, K=128 | 703.3 | 3713.4
f16 B=64, M=128, H=16, K=16 | 90.4 | 245.3
b16 B=64, M=128, H=16, K=16 | 90.3 | 251.1
f16 B=64, M=128, H=16, K=32 | 91.1 | 294.9
b16 B=64, M=128, H=16, K=32 | 91.8 | 299.4
f16 B=64, M=128, H=16, K=64 | 90.1 | 393.0
b16 B=64, M=128, H=16, K=64 | 92.7 | 397.9
f16 B=64, M=128, H=16, K=128 | 132.3 | 557.3
b16 B=64, M=128, H=16, K=128 | 134.1 | 562.5
f16 B=64, M=512, H=16, K=16 | 227.0 | 2705.2
b16 B=64, M=512, H=16, K=16 | 288.3 | 3017.0
f16 B=64, M=512, H=16, K=32 | 284.5 | 2893.9
b16 B=64, M=512, H=16, K=32 | 336.8 | 3167.5
f16 B=64, M=512, H=16, K=64 | 418.1 | 3338.8
b16 B=64, M=512, H=16, K=64 | 444.2 | 3460.2
f16 B=64, M=512, H=16, K=128 | 830.2 | 4091.8
b16 B=64, M=512, H=16, K=128 | 855.8 | 4207.5
f16 B=64, M=1024, H=16, K=16 | 720.1 | 10487.1
b16 B=64, M=1024, H=16, K=16 | 929.6 | 12038.4
f16 B=64, M=1024, H=16, K=32 | 898.2 | 10852.4
b16 B=64, M=1024, H=16, K=32 | 1072.3 | 12350.3
f16 B=64, M=1024, H=16, K=64 | 1325.3 | 11682.5
b16 B=64, M=1024, H=16, K=64 | 1421.4 | 13010.3
f16 B=64, M=1024, H=16, K=128 | 2604.2 | 13959.2
b16 B=64, M=1024, H=16, K=128 | 2683.7 | 14677.4
Times are in microseconds (us).
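The vanilla column is empty for the two largest forward shapes (B=2, M=8192, H=160 and B=256, M=4096, H=16). The PR does not say why. One plausible reason is memory. A baseline that materializes the full attention matrix needs B x H x M x M elements for the scores alone, while Flash-style kernels process that matrix in tiles. The sketch below only does this arithmetic. It assumes 2-byte f16/bf16 elements and counts a single copy of the score tensor. `score_matrix_gib` is a hypothetical helper, not part of xformers.

```python
def score_matrix_gib(B: int, M: int, H: int, bytes_per_elem: int = 2) -> float:
    """Size in GiB of one dense B x H x M x M attention-score tensor.

    Assumption: fp16/bf16 storage (2 bytes per element) and a single copy only;
    a real baseline also holds q, k, v, the softmax output and autograd buffers.
    The head dimension K does not appear because the scores are M x M per head.
    """
    return B * H * M * M * bytes_per_elem / 2**30


# Shapes taken from the forward table above.
for B, M, H in [(1, 4096, 160), (1, 8192, 160), (2, 8192, 160), (256, 4096, 16)]:
    print(f"B={B}, M={M}, H={H}: {score_matrix_gib(B, M, H):.0f} GiB")
```

For these four shapes the estimate is 5, 20, 40 and 128 GiB, before counting q, k, v, the softmax output or any backward buffers.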
Backward pass timings (Triton forward, Flash backward):
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 242.0 | 650.9
b16 B=384, M=197, H=1, K=64 | 221.9 | 652.5
f16 B=1024, M=197, H=1, K=64 | 534.5 | 1647.7
b16 B=1024, M=197, H=1, K=64 | 532.9 | 1648.0
f16 B=32, M=197, H=16, K=64 | 279.1 | 882.9
b16 B=32, M=197, H=16, K=64 | 278.5 | 883.9
f16 B=32, M=197, H=16, K=128 | 638.0 | 1344.0
b16 B=32, M=197, H=16, K=128 | 640.2 | 1346.5
f16 B=16, M=197, H=16, K=64 | 235.9 | 484.8
b16 B=16, M=197, H=16, K=64 | 245.0 | 485.8
f16 B=16, M=197, H=16, K=128 | 367.5 | 711.5
b16 B=16, M=197, H=16, K=128 | 368.5 | 713.9
f16 B=1, M=4096, H=160, K=128 | 53888.1 | 38893.2
b16 B=1, M=4096, H=160, K=128 | 53989.6 | 39876.1
f16 B=2, M=4096, H=160, K=128 | 82534.0 | 78629.0
b16 B=2, M=4096, H=160, K=128 | 82718.4 | 80537.6
f16 B=1, M=8192, H=160, K=128 | 213238.1 |
b16 B=1, M=8192, H=160, K=128 | 213335.0 |
f16 B=2, M=8192, H=160, K=128 | 324974.6 |
b16 B=2, M=8192, H=160, K=128 | 325014.0 |
f16 B=1024, M=82, H=8, K=64 | 1494.2 | 3611.5
b16 B=1024, M=82, H=8, K=64 | 1515.2 | 3787.3
f16 B=150, M=256, H=16, K=64 | 1493.8 | 3800.9
b16 B=150, M=256, H=16, K=64 | 1491.3 | 3843.4
f16 B=64, M=256, H=12, K=64 | 523.3 | 1256.9
b16 B=64, M=256, H=12, K=64 | 520.9 | 1271.3
f16 B=256, M=4096, H=16, K=64 | 430986.6 |
b16 B=256, M=4096, H=16, K=64 | 430693.7 |
f16 B=8, M=2048, H=20, K=128 | 13945.0 | 10546.8
b16 B=8, M=2048, H=20, K=128 | 13939.3 | 10798.0
f16 B=16, M=128, H=16, K=16 | 196.4 | 327.9
b16 B=16, M=128, H=16, K=16 | 197.0 | 353.1
f16 B=16, M=128, H=16, K=32 | 215.2 | 350.5
b16 B=16, M=128, H=16, K=32 | 197.6 | 325.6
f16 B=16, M=128, H=16, K=64 | 195.9 | 346.2
b16 B=16, M=128, H=16, K=64 | 216.0 | 357.0
f16 B=16, M=128, H=16, K=128 | 197.1 | 322.8
b16 B=16, M=128, H=16, K=128 | 218.6 | 320.6
f16 B=16, M=512, H=16, K=16 | 321.2 | 984.5
b16 B=16, M=512, H=16, K=16 | 323.6 | 1077.2
f16 B=16, M=512, H=16, K=32 | 423.1 | 1089.6
b16 B=16, M=512, H=16, K=32 | 425.6 | 1178.5
f16 B=16, M=512, H=16, K=64 | 671.8 | 1285.5
b16 B=16, M=512, H=16, K=64 | 673.3 | 1306.2
f16 B=16, M=512, H=16, K=128 | 1512.0 | 1718.0
b16 B=16, M=512, H=16, K=128 | 1514.4 | 1748.0
f16 B=16, M=1024, H=16, K=16 | 1237.7 | 3532.8
b16 B=16, M=1024, H=16, K=16 | 1240.7 | 3957.6
f16 B=16, M=1024, H=16, K=32 | 1593.6 | 3733.2
b16 B=16, M=1024, H=16, K=32 | 1594.9 | 4156.9
f16 B=16, M=1024, H=16, K=64 | 2309.4 | 4278.0
b16 B=16, M=1024, H=16, K=64 | 2311.7 | 4350.1
f16 B=16, M=1024, H=16, K=128 | 5478.4 | 5110.4
b16 B=16, M=1024, H=16, K=128 | 5498.3 | 5251.9
f16 B=64, M=128, H=16, K=16 | 197.3 | 364.8
b16 B=64, M=128, H=16, K=16 | 215.9 | 372.4
f16 B=64, M=128, H=16, K=32 | 200.8 | 464.8
b16 B=64, M=128, H=16, K=32 | 201.4 | 470.3
f16 B=64, M=128, H=16, K=64 | 283.5 | 680.8
b16 B=64, M=128, H=16, K=64 | 286.8 | 682.9
f16 B=64, M=128, H=16, K=128 | 496.9 | 1076.1
b16 B=64, M=128, H=16, K=128 | 501.1 | 1078.8
f16 B=64, M=512, H=16, K=16 | 1181.5 | 3715.9
b16 B=64, M=512, H=16, K=16 | 1187.2 | 4087.0
f16 B=64, M=512, H=16, K=32 | 1497.3 | 4141.0
b16 B=64, M=512, H=16, K=32 | 1505.0 | 4508.7
f16 B=64, M=512, H=16, K=64 | 2296.0 | 4917.1
b16 B=64, M=512, H=16, K=64 | 2309.8 | 5008.3
f16 B=64, M=512, H=16, K=128 | 5157.6 | 6644.4
b16 B=64, M=512, H=16, K=128 | 5178.1 | 6751.8
f16 B=64, M=1024, H=16, K=16 | 4668.2 | 13869.2
b16 B=64, M=1024, H=16, K=16 | 4671.8 | 15561.4
f16 B=64, M=1024, H=16, K=32 | 5595.4 | 14761.1
b16 B=64, M=1024, H=16, K=32 | 5609.4 | 16450.7
f16 B=64, M=1024, H=16, K=64 | 7872.6 | 16891.0
b16 B=64, M=1024, H=16, K=64 | 7899.7 | 17185.1
f16 B=64, M=1024, H=16, K=128 | 18538.4 | 20376.2
b16 B=64, M=1024, H=16, K=128 | 18576.9 | 20931.0
Times are in microseconds (us).
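In this backward table the optimized column is not always faster. Every row that is slower than vanilla has K=128 and M ≥ 1024. The affected rows are B=1 and B=2 at M=4096, B=8 at M=2048, and B=16 at M=1024. Not every such row regresses, though: B=64, M=1024, K=128 is still faster. The sketch below computes speedup as vanilla time divided by optimized time for a few f16 rows copied from the table and lists the shapes where it falls below 1. The names `F16_BACKWARD`, `speedup` and `regressions` are hypothetical.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, int, int, int]  # (B, M, H, K)

# (B, M, H, K) -> (optimized_us, vanilla_us), copied from the f16 rows of the
# backward table above (attn_bias=None).
F16_BACKWARD: Dict[Shape, Tuple[float, float]] = {
    (384, 197, 1, 64): (242.0, 650.9),
    (1, 4096, 160, 128): (53888.1, 38893.2),
    (8, 2048, 20, 128): (13945.0, 10546.8),
    (16, 1024, 16, 128): (5478.4, 5110.4),
    (16, 512, 16, 128): (1512.0, 1718.0),
    (64, 1024, 16, 64): (7872.6, 16891.0),
    (64, 1024, 16, 128): (18538.4, 20376.2),
}


def speedup(optimized_us: float, vanilla_us: float) -> float:
    """Return vanilla / optimized; values above 1.0 mean the optimized path is faster."""
    return vanilla_us / optimized_us


def regressions(table: Dict[Shape, Tuple[float, float]]) -> List[Shape]:
    """Shapes where the optimized backward is slower than vanilla (speedup < 1)."""
    return sorted(shape for shape, (opt, van) in table.items() if speedup(opt, van) < 1.0)


for shape, (opt, van) in F16_BACKWARD.items():
    print("B={}, M={}, H={}, K={}".format(*shape), f"{speedup(opt, van):.2f}x")
print("slower than vanilla:", regressions(F16_BACKWARD))
```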
[ attention backward (attn_bias=<class 'xformers.ops.memory_efficient_attention.LowerTriangularMask'>) ]
| optimized | vanilla
1 threads: --------------------------------------------------
f16 B=384, M=197, H=1, K=64 | 225.1 | 652.1
b16 B=384, M=197, H=1, K=64 | 227.0 | 652.6
f16 B=1024, M=197, H=1, K=64 | 542.6 | 1647.2
b16 B=1024, M=197, H=1, K=64 | 546.3 | 1647.3
f16 B=32, M=197, H=16, K=64 | 282.4 | 885.0
b16 B=32, M=197, H=16, K=64 | 285.3 | 886.9
f16 B=32, M=197, H=16, K=128 | 521.3 | 1345.3
b16 B=32, M=197, H=16, K=128 | 521.9 | 1346.8
f16 B=16, M=197, H=16, K=64 | 243.5 | 485.8
b16 B=16, M=197, H=16, K=64 | 215.6 | 487.3
f16 B=16, M=197, H=16, K=128 | 298.9 | 714.0
b16 B=16, M=197, H=16, K=128 | 299.4 | 715.9
f16 B=1, M=4096, H=160, K=128 | 31532.4 | 38844.2
b16 B=1, M=4096, H=160, K=128 | 31579.6 | 39971.8
f16 B=2, M=4096, H=160, K=128 | 48359.2 | 78614.4
b16 B=2, M=4096, H=160, K=128 | 48379.2 | 80423.5
f16 B=1, M=8192, H=160, K=128 | 121826.9 |
b16 B=1, M=8192, H=160, K=128 | 121908.8 |
f16 B=2, M=8192, H=160, K=128 | 186970.4 |
b16 B=2, M=8192, H=160, K=128 | 186947.6 |
f16 B=1024, M=82, H=8, K=64 | 1513.1 | 3612.9
b16 B=1024, M=82, H=8, K=64 | 1524.6 | 3789.6
f16 B=150, M=256, H=16, K=64 | 1513.9 | 3821.8
b16 B=150, M=256, H=16, K=64 | 1530.0 | 3863.2
f16 B=64, M=256, H=12, K=64 | 529.9 | 1257.3
b16 B=64, M=256, H=12, K=64 | 533.3 | 1272.5
f16 B=256, M=4096, H=16, K=64 | 241098.0 |
b16 B=256, M=4096, H=16, K=64 | 241300.4 |
f16 B=8, M=2048, H=20, K=128 | 8338.3 | 10544.2
b16 B=8, M=2048, H=20, K=128 | 8347.7 | 10792.9
f16 B=16, M=128, H=16, K=16 | 197.3 | 328.7
b16 B=16, M=128, H=16, K=16 | 215.8 | 329.7
f16 B=16, M=128, H=16, K=32 | 196.4 | 308.2
b16 B=16, M=128, H=16, K=32 | 214.5 | 332.4
f16 B=16, M=128, H=16, K=64 | 221.3 | 302.9
b16 B=16, M=128, H=16, K=64 | 195.7 | 335.6
f16 B=16, M=128, H=16, K=128 | 216.6 | 324.5
b16 B=16, M=128, H=16, K=128 | 217.9 | 321.4
f16 B=16, M=512, H=16, K=16 | 261.4 | 986.3
b16 B=16, M=512, H=16, K=16 | 263.7 | 1078.2
f16 B=16, M=512, H=16, K=32 | 338.6 | 1089.9
b16 B=16, M=512, H=16, K=32 | 341.3 | 1178.4
f16 B=16, M=512, H=16, K=64 | 519.7 | 1286.0
b16 B=16, M=512, H=16, K=64 | 523.9 | 1306.3
f16 B=16, M=512, H=16, K=128 | 1036.2 | 1717.5
b16 B=16, M=512, H=16, K=128 | 1039.4 | 1744.6
f16 B=16, M=1024, H=16, K=16 | 790.6 | 3531.1
b16 B=16, M=1024, H=16, K=16 | 793.3 | 3954.9
f16 B=16, M=1024, H=16, K=32 | 1023.9 | 3736.3
b16 B=16, M=1024, H=16, K=32 | 1023.6 | 4155.7
f16 B=16, M=1024, H=16, K=64 | 1565.6 | 4279.1
b16 B=16, M=1024, H=16, K=64 | 1572.5 | 4350.4
f16 B=16, M=1024, H=16, K=128 | 3437.1 | 5117.6
b16 B=16, M=1024, H=16, K=128 | 3449.8 | 5254.7
f16 B=64, M=128, H=16, K=16 | 199.3 | 367.9
b16 B=64, M=128, H=16, K=16 | 196.3 | 376.1
f16 B=64, M=128, H=16, K=32 | 206.7 | 465.6
b16 B=64, M=128, H=16, K=32 | 214.1 | 471.7
f16 B=64, M=128, H=16, K=64 | 286.8 | 683.6
b16 B=64, M=128, H=16, K=64 | 289.0 | 685.6
f16 B=64, M=128, H=16, K=128 | 512.0 | 1075.4
b16 B=64, M=128, H=16, K=128 | 514.3 | 1078.1
f16 B=64, M=512, H=16, K=16 | 924.1 | 3713.1
b16 B=64, M=512, H=16, K=16 | 929.9 | 4090.5
f16 B=64, M=512, H=16, K=32 | 1180.3 | 4138.9
b16 B=64, M=512, H=16, K=32 | 1186.6 | 4505.3
f16 B=64, M=512, H=16, K=64 | 1761.6 | 4914.9
b16 B=64, M=512, H=16, K=64 | 1780.9 | 5004.2
f16 B=64, M=512, H=16, K=128 | 3586.7 | 6644.1
b16 B=64, M=512, H=16, K=128 | 3601.9 | 6750.4
f16 B=64, M=1024, H=16, K=16 | 2859.7 | 13877.5
b16 B=64, M=1024, H=16, K=16 | 2863.2 | 15563.7
f16 B=64, M=1024, H=16, K=32 | 3587.1 | 14759.8
b16 B=64, M=1024, H=16, K=32 | 3587.2 | 16462.9
f16 B=64, M=1024, H=16, K=64 | 5363.7 | 16889.4
b16 B=64, M=1024, H=16, K=64 | 5384.1 | 17213.8
f16 B=64, M=1024, H=16, K=128 | 11821.0 | 20391.7
b16 B=64, M=1024, H=16, K=128 | 11861.3 | 20907.0
Times are in microseconds (us).
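The table above uses `LowerTriangularMask`, which is xformers' causal mask: query i attends only to keys j <= i. The following pure-Python sketch shows what that means for a dense baseline on a single (batch, head) slice. It assumes the vanilla path is the usual reference computation: scale, q @ k^T, mask, softmax, @ v. It illustrates the semantics only and is not the benchmark code. `reference_attention` is a hypothetical name.

```python
import math
from typing import List

Matrix = List[List[float]]


def reference_attention(q: Matrix, k: Matrix, v: Matrix, causal: bool = False) -> Matrix:
    """softmax(q @ k^T / sqrt(K) + mask) @ v for one (batch, head) slice.

    q, k, v are M x K lists of floats with the same M. With causal=True,
    scores for keys j > i are set to -inf, mirroring a lower-triangular mask.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Matrix = []
    for i, qi in enumerate(q):
        # Row i of the dense M x M score matrix that a vanilla implementation materializes.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if causal:
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        m = max(scores)  # the diagonal is never masked, so m is finite
        weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total for d in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(reference_attention(q, q, v, causal=True))  # row 0 can only see key 0
```

The causal rows in the table are consistently faster than the matching no-bias rows. For example, f16 B=1, M=4096, H=160, K=128 takes 31532.4 us here versus 53888.1 us without a bias. The vanilla column barely changes (38844.2 vs 38893.2). This pattern is consistent with the optimized kernel skipping fully masked tiles while the dense baseline still computes them, although the PR does not state this explicitly.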
@fmassa Thanks a lot for the helpful comments and for taking another pass! I've pushed the corresponding changes. Is it okay to merge for now?
Codecov Report
Base: 89.79% // Head: 89.06% // Decreases project coverage by 0.73%
Coverage data is based on head (5724663) compared to base (71205ec). Patch coverage: 48.86% of modified lines in pull request are covered.
@@ Coverage Diff @@
## main #479 +/- ##
==========================================
- Coverage 89.79% 89.06% -0.74%
==========================================
Files 80 80
Lines 4839 4927 +88
==========================================
+ Hits 4345 4388 +43
- Misses 494 539 +45
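The percentages in this report follow from the hit and line counts in the diff. The arithmetic is sketched below. Here the patch figure happens to equal the net change in hits divided by the net change in lines (43 of 88). Codecov computes patch coverage over modified lines in general, so this identity is not guaranteed. The -0.74% in the diff table differs from the -0.73% headline, presumably because of the report's own rounding. `coverage_pct` is a hypothetical helper.

```python
def coverage_pct(hits: int, lines: int) -> float:
    """Percentage of tracked lines that were executed by the test suite."""
    return 100.0 * hits / lines


base = coverage_pct(4345, 4839)  # main
head = coverage_pct(4388, 4927)  # this PR
# Net added lines and hits; equals the reported patch coverage in this case.
patch = coverage_pct(4388 - 4345, 4927 - 4839)
print(f"base {base:.2f}%, head {head:.2f}%, delta {head - base:.2f}%, patch {patch:.2f}%")
```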
Flag | Coverage Δ | |
---|---|---|
Python | 89.06% <48.86%> (-0.74%) | :arrow_down: |
Flags with carried forward coverage won't be shown.
Impacted Files | Coverage Δ | |
---|---|---|
xformers/info.py | 0.00% <ø> (ø) | |
xformers/ops/__init__.py | 82.35% <ø> (ø) | |
xformers/ops/memory_efficient_attention.py | 78.05% <48.86%> (-6.66%) | :arrow_down: |