
Add Triton Flash Attention

Open dianaml0 opened this issue 1 year ago • 6 comments

What does this PR do?

Adds Triton Flash Attention
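
For readers unfamiliar with the technique, here is a minimal pure-Python sketch of the tiled "online softmax" idea that flash attention is built on. It is not the Triton kernel added by this PR: it handles a single query, and the function names and the `block` parameter are hypothetical, chosen only for illustration. The point is that keys and values can be consumed block by block while keeping just a running maximum, a running sum, and an accumulator, and the result still matches ordinary softmax attention.

```python
import math


def naive_attention(q, keys, values):
    """Reference: softmax(q . k_j / sqrt(d)) weighted sum of values, one query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(dim)]


def tiled_attention(q, keys, values, block=2):
    """Same result, but keys/values are visited one block at a time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the previous maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [0.1, -0.3, 0.5, 0.2]
keys = [[0.2, 0.1, -0.4, 0.3], [1.0, 0.0, 0.5, -0.2], [-0.6, 0.7, 0.1, 0.0],
        [0.3, 0.3, 0.3, 0.3], [0.0, -1.0, 0.2, 0.9]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
reference = naive_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, block=2)
```

Because the full row of attention scores is never stored, a real kernel can keep each block in fast on-chip memory; the sketch only demonstrates the arithmetic, not the memory behaviour.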

Performance Compared to Vanilla
[--------- attention (attn_bias=<class 'NoneType'>) --------]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     1260.4  |    372.9
      b16 B=384, M=197, H=1, K=88    |     1269.9  |    375.3
      f16 B=384, M=197, H=1, K=80    |      146.7  |    344.5
      b16 B=384, M=197, H=1, K=80    |      149.2  |    346.6
      f16 B=384, M=197, H=1, K=64    |       92.0  |    293.2
      b16 B=384, M=197, H=1, K=64    |       94.2  |    295.1
      f16 B=1024, M=197, H=1, K=88   |     3242.3  |    938.8
      b16 B=1024, M=197, H=1, K=88   |     3264.9  |    945.1
      f16 B=1024, M=197, H=1, K=80   |      348.7  |    864.5
      b16 B=1024, M=197, H=1, K=80   |      354.3  |    871.5
      f16 B=1024, M=197, H=1, K=64   |      213.7  |    729.9
      b16 B=1024, M=197, H=1, K=64   |      222.1  |    735.5
      f16 B=512, M=197, H=1, K=80    |      185.4  |    448.4
      b16 B=512, M=197, H=1, K=80    |      188.5  |    451.6
      f16 B=32, M=197, H=16, K=80    |      193.8  |    547.9
      b16 B=32, M=197, H=16, K=80    |      193.7  |    550.6
      f16 B=32, M=197, H=16, K=64    |      114.4  |    467.1
      b16 B=32, M=197, H=16, K=64    |      120.8  |    469.2
      f16 B=32, M=197, H=16, K=128   |      193.8  |    717.4
      b16 B=32, M=197, H=16, K=128   |      195.1  |    720.3
      f16 B=256, M=197, H=1, K=88    |      868.4  |    262.1
      b16 B=256, M=197, H=1, K=88    |      869.9  |    262.2
      f16 B=16, M=197, H=16, K=88    |      869.8  |    317.7
      b16 B=16, M=197, H=16, K=88    |      872.4  |    319.0
      f16 B=16, M=197, H=16, K=64    |       88.6  |    257.7
      b16 B=16, M=197, H=16, K=64    |       89.6  |    259.8
      f16 B=16, M=197, H=16, K=128   |      101.5  |    385.8
      b16 B=16, M=197, H=16, K=128   |      102.2  |    387.1
      f16 B=1, M=4096, H=160, K=128  |     9807.3  |  20343.8
      b16 B=1, M=4096, H=160, K=128  |    10034.8  |  21406.5
      f16 B=2, M=4096, H=160, K=128  |    19298.5  |  42595.1
      b16 B=2, M=4096, H=160, K=128  |    19841.7  |  44399.3
      f16 B=1, M=8192, H=160, K=128  |    37266.9  |  88426.6
      b16 B=1, M=8192, H=160, K=128  |    38570.7  |  87103.1
      f16 B=2, M=8192, H=160, K=128  |    74084.0  |
      b16 B=2, M=8192, H=160, K=128  |    76617.8  |
      f16 B=1024, M=82, H=8, K=64    |      461.1  |   1767.3
      b16 B=1024, M=82, H=8, K=64    |      479.9  |   1862.5
      f16 B=150, M=256, H=16, K=64   |      382.5  |   1725.6
      b16 B=150, M=256, H=16, K=64   |      420.9  |   1760.7
      f16 B=64, M=256, H=12, K=64    |      131.7  |    587.0
      b16 B=64, M=256, H=12, K=64    |      145.7  |    597.4
      f16 B=1, M=4096, H=16, K=40    |    30705.0  |   1937.8
      b16 B=1, M=4096, H=16, K=40    |    30832.3  |   1990.1
      f16 B=1, M=16384, H=16, K=40   |   423123.6  |  28845.9
      b16 B=1, M=16384, H=16, K=40   |   420611.6  |  30167.0
      f16 B=256, M=4096, H=16, K=64  |   118530.4  |
      b16 B=256, M=4096, H=16, K=64  |   132843.9  |
      f16 B=8, M=2048, H=20, K=128   |     2692.1  |   5704.5
      b16 B=8, M=2048, H=20, K=128   |     2743.9  |   6044.7
      f16 B=16, M=128, H=16, K=16    |       90.2  |    135.5
      b16 B=16, M=128, H=16, K=16    |       87.0  |    136.7
      f16 B=16, M=128, H=16, K=32    |       88.4  |    137.5
      b16 B=16, M=128, H=16, K=32    |       91.0  |    138.1
      f16 B=16, M=128, H=16, K=64    |       89.1  |    137.3
      b16 B=16, M=128, H=16, K=64    |       90.3  |    137.7
      f16 B=16, M=128, H=16, K=128   |       91.6  |    139.6
      b16 B=16, M=128, H=16, K=128   |       90.6  |    140.7
      f16 B=16, M=512, H=16, K=16    |       89.1  |    461.4
      b16 B=16, M=512, H=16, K=16    |       99.3  |    553.7
      f16 B=16, M=512, H=16, K=32    |      105.6  |    512.1
      b16 B=16, M=512, H=16, K=32    |      123.2  |    594.4
      f16 B=16, M=512, H=16, K=64    |      152.1  |    595.9
      b16 B=16, M=512, H=16, K=64    |      169.4  |    613.0
      f16 B=16, M=512, H=16, K=128   |      319.0  |    786.8
      b16 B=16, M=512, H=16, K=128   |      328.0  |    804.9
      f16 B=16, M=1024, H=16, K=16   |      292.1  |   1642.8
      b16 B=16, M=1024, H=16, K=16   |      385.7  |   2025.9
      f16 B=16, M=1024, H=16, K=32   |      369.5  |   1732.0
      b16 B=16, M=1024, H=16, K=32   |      428.2  |   2127.2
      f16 B=16, M=1024, H=16, K=64   |      518.4  |   2033.7
      b16 B=16, M=1024, H=16, K=64   |      579.9  |   2064.6
      f16 B=16, M=1024, H=16, K=128  |     1097.7  |   2410.3
      b16 B=16, M=1024, H=16, K=128  |     1135.3  |   2490.3
      f16 B=64, M=128, H=16, K=16    |       87.5  |    183.0
      b16 B=64, M=128, H=16, K=16    |       89.6  |    184.9
      f16 B=64, M=128, H=16, K=32    |       89.2  |    224.6
      b16 B=64, M=128, H=16, K=32    |       88.2  |    225.5
      f16 B=64, M=128, H=16, K=64    |       87.5  |    321.5
      b16 B=64, M=128, H=16, K=64    |       87.2  |    322.9
      f16 B=64, M=128, H=16, K=128   |      127.0  |    484.7
      b16 B=64, M=128, H=16, K=128   |      129.4  |    486.1
      f16 B=64, M=512, H=16, K=16    |      307.5  |   1727.1
      b16 B=64, M=512, H=16, K=16    |      403.5  |   2094.0
      f16 B=64, M=512, H=16, K=32    |      387.2  |   1888.1
      b16 B=64, M=512, H=16, K=32    |      450.6  |   2250.3
      f16 B=64, M=512, H=16, K=64    |      557.9  |   2233.9
      b16 B=64, M=512, H=16, K=64    |      620.4  |   2294.9
      f16 B=64, M=512, H=16, K=128   |     1205.0  |   3008.3
      b16 B=64, M=512, H=16, K=128   |     1235.5  |   3072.8
      f16 B=64, M=1024, H=16, K=16   |     1101.8  |   6422.7
      b16 B=64, M=1024, H=16, K=16   |     1473.2  |   8024.6
      f16 B=64, M=1024, H=16, K=32   |     1409.8  |   6744.7
      b16 B=64, M=1024, H=16, K=32   |     1667.6  |   8387.5
      f16 B=64, M=1024, H=16, K=64   |     2022.1  |   7966.1
      b16 B=64, M=1024, H=16, K=64   |     2265.4  |   8127.2
      f16 B=64, M=1024, H=16, K=128  |     4282.5  |   9524.9
      b16 B=64, M=1024, H=16, K=128  |     4413.0  |   9842.0

Times are in microseconds (us).
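
The table above is easier to read as speedup ratios. The following is a minimal sketch, assuming the row layout shown above, that parses a few rows copied from the table and computes `vanilla / optimized` (values above 1 mean the Triton kernel is faster). The helper names `parse_row` and `speedup` are hypothetical and not part of xformers.

```python
import re

# One benchmark row: dtype, shape, optimized time, baseline time (may be blank).
ROW = re.compile(
    r"^\s*(?P<dtype>\w+)\s+B=(?P<B>\d+), M=(?P<M>\d+), H=(?P<H>\d+), K=(?P<K>\d+)"
    r"\s*\|\s*(?P<optimized>[\d.]+)\s*\|\s*(?P<baseline>[\d.]*)\s*$"
)


def parse_row(line):
    """Return a dict for one table row, or None if the line is not a row."""
    m = ROW.match(line)
    if m is None:
        return None
    row = {key: int(m[key]) for key in "BMHK"}
    row["dtype"] = m["dtype"]
    row["optimized_us"] = float(m["optimized"])
    # A blank baseline column means no time was reported for that shape.
    row["baseline_us"] = float(m["baseline"]) if m["baseline"] else None
    return row


def speedup(row):
    """Baseline time divided by optimized time, or None if no baseline."""
    if row["baseline_us"] is None:
        return None
    return row["baseline_us"] / row["optimized_us"]


# Rows copied verbatim from the table above.
sample = """\
      f16 B=384, M=197, H=1, K=88    |     1260.4  |    372.9
      f16 B=384, M=197, H=1, K=80    |      146.7  |    344.5
      f16 B=2, M=8192, H=160, K=128  |    74084.0  |
      f16 B=1, M=4096, H=16, K=40    |    30705.0  |   1937.8
"""
rows = [r for r in map(parse_row, sample.splitlines()) if r is not None]
ratios = [speedup(r) for r in rows]
```

On these rows the ratio is below 1 for K=88 and K=40 and above 1 for K=80, which matches the pattern visible in the raw numbers: the shapes with K=88 and K=40 are slower than vanilla in this run.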
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: ---------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3701.7  |     447.0
      b16 B=384, M=197, H=1, K=88    |     1178.3  |     452.6
      f16 B=384, M=197, H=1, K=80    |      120.2  |     421.6
      b16 B=384, M=197, H=1, K=80    |      122.5  |     427.5
      f16 B=384, M=197, H=1, K=64    |       87.3  |     370.6
      b16 B=384, M=197, H=1, K=64    |       90.1  |     376.4
      f16 B=1024, M=197, H=1, K=88   |     9586.4  |    1124.4
      b16 B=1024, M=197, H=1, K=88   |     2981.6  |    1139.9
      f16 B=1024, M=197, H=1, K=80   |      294.0  |    1059.9
      b16 B=1024, M=197, H=1, K=80   |      300.3  |    1075.3
      f16 B=1024, M=197, H=1, K=64   |      187.5  |     925.9
      b16 B=1024, M=197, H=1, K=64   |      197.2  |     940.5
      f16 B=512, M=197, H=1, K=80    |      153.5  |     549.2
      b16 B=512, M=197, H=1, K=80    |      156.6  |     557.4
      f16 B=32, M=197, H=16, K=80    |      157.4  |     647.5
      b16 B=32, M=197, H=16, K=80    |      160.6  |     654.6
      f16 B=32, M=197, H=16, K=64    |      103.8  |     565.6
      b16 B=32, M=197, H=16, K=64    |      109.2  |     573.3
      f16 B=32, M=197, H=16, K=128   |      157.7  |     810.9
      b16 B=32, M=197, H=16, K=128   |      159.7  |     818.4
      f16 B=256, M=197, H=1, K=88    |     2551.1  |     310.8
      b16 B=256, M=197, H=1, K=88    |      813.3  |     314.9
      f16 B=16, M=197, H=16, K=88    |     2554.5  |     367.9
      b16 B=16, M=197, H=16, K=88    |      806.5  |     372.0
      f16 B=16, M=197, H=16, K=64    |       90.3  |     307.1
      b16 B=16, M=197, H=16, K=64    |       89.6  |     310.9
      f16 B=16, M=197, H=16, K=128   |       87.8  |     435.4
      b16 B=16, M=197, H=16, K=128   |       91.6  |     438.6
      f16 B=1, M=4096, H=160, K=128  |     5176.5  |   37308.5
      b16 B=1, M=4096, H=160, K=128  |     5330.5  |   37818.1
      f16 B=2, M=4096, H=160, K=128  |    10225.6  |   76322.3
      b16 B=2, M=4096, H=160, K=128  |    10547.6  |   77302.1
      f16 B=1, M=8192, H=160, K=128  |    19289.3  |  152433.6
      b16 B=1, M=8192, H=160, K=128  |    19922.2  |  148602.0
      f16 B=2, M=8192, H=160, K=128  |    38313.9  |
      b16 B=2, M=8192, H=160, K=128  |    39610.6  |
      f16 B=1024, M=82, H=8, K=64    |      488.4  |    1979.9
      b16 B=1024, M=82, H=8, K=64    |      515.8  |    2085.1
      f16 B=150, M=256, H=16, K=64   |      333.7  |    2401.0
      b16 B=150, M=256, H=16, K=64   |      349.7  |    2446.2
      f16 B=64, M=256, H=12, K=64    |      118.8  |     805.3
      b16 B=64, M=256, H=12, K=64    |      124.8  |     819.5
      f16 B=1, M=4096, H=16, K=40    |    15004.4  |    3428.4
      b16 B=1, M=4096, H=16, K=40    |    14890.7  |    3465.7
      f16 B=1, M=16384, H=16, K=40   |   221390.1  |   55383.6
      b16 B=1, M=16384, H=16, K=40   |   218101.2  |   55427.0
      f16 B=256, M=4096, H=16, K=64  |    67757.8  |
      b16 B=256, M=4096, H=16, K=64  |    74047.4  |
      f16 B=8, M=2048, H=20, K=128   |     1487.8  |    9352.9
      b16 B=8, M=2048, H=20, K=128   |     1534.3  |    9539.5
      f16 B=16, M=128, H=16, K=16    |       92.2  |     146.0
      b16 B=16, M=128, H=16, K=16    |       87.0  |     143.8
      f16 B=16, M=128, H=16, K=32    |       87.2  |     144.5
      b16 B=16, M=128, H=16, K=32    |       89.3  |     145.7
      f16 B=16, M=128, H=16, K=64    |       90.6  |     144.4
      b16 B=16, M=128, H=16, K=64    |       90.3  |     141.9
      f16 B=16, M=128, H=16, K=128   |       87.2  |     160.0
      b16 B=16, M=128, H=16, K=128   |       87.2  |     161.7
      f16 B=16, M=512, H=16, K=16    |       90.4  |     715.1
      b16 B=16, M=512, H=16, K=16    |       91.0  |     792.1
      f16 B=16, M=512, H=16, K=32    |       88.5  |     766.6
      b16 B=16, M=512, H=16, K=32    |       98.1  |     832.2
      f16 B=16, M=512, H=16, K=64    |      120.3  |     881.1
      b16 B=16, M=512, H=16, K=64    |      129.0  |     910.1
      f16 B=16, M=512, H=16, K=128   |      225.5  |    1066.3
      b16 B=16, M=512, H=16, K=128   |      231.9  |    1095.5
      f16 B=16, M=1024, H=16, K=16   |      196.5  |    2658.0
      b16 B=16, M=1024, H=16, K=16   |      253.2  |    3042.2
      f16 B=16, M=1024, H=16, K=32   |      244.3  |    2749.6
      b16 B=16, M=1024, H=16, K=32   |      291.8  |    3121.5
      f16 B=16, M=1024, H=16, K=64   |      355.3  |    2963.7
      b16 B=16, M=1024, H=16, K=64   |      384.7  |    3287.5
      f16 B=16, M=1024, H=16, K=128  |      683.0  |    3534.2
      b16 B=16, M=1024, H=16, K=128  |      705.2  |    3711.2
      f16 B=64, M=128, H=16, K=16    |       87.6  |     245.8
      b16 B=64, M=128, H=16, K=16    |       89.9  |     251.7
      f16 B=64, M=128, H=16, K=32    |       88.4  |     294.2
      b16 B=64, M=128, H=16, K=32    |       90.7  |     298.7
      f16 B=64, M=128, H=16, K=64    |       92.0  |     391.1
      b16 B=64, M=128, H=16, K=64    |       88.6  |     395.9
      f16 B=64, M=128, H=16, K=128   |      131.0  |     557.0
      b16 B=64, M=128, H=16, K=128   |      133.2  |     562.1
      f16 B=64, M=512, H=16, K=16    |      227.6  |    2704.8
      b16 B=64, M=512, H=16, K=16    |      289.3  |    3019.3
      f16 B=64, M=512, H=16, K=32    |      285.1  |    2891.7
      b16 B=64, M=512, H=16, K=32    |      336.8  |    3163.4
      f16 B=64, M=512, H=16, K=64    |      414.3  |    3351.7
      b16 B=64, M=512, H=16, K=64    |      444.2  |    3475.0
      f16 B=64, M=512, H=16, K=128   |      829.5  |    4097.1
      b16 B=64, M=512, H=16, K=128   |      854.6  |    4208.8
      f16 B=64, M=1024, H=16, K=16   |      722.1  |   10479.3
      b16 B=64, M=1024, H=16, K=16   |      931.6  |   12030.5
      f16 B=64, M=1024, H=16, K=32   |      900.9  |   10856.5
      b16 B=64, M=1024, H=16, K=32   |     1075.2  |   12361.2
      f16 B=64, M=1024, H=16, K=64   |     1313.1  |   11684.0
      b16 B=64, M=1024, H=16, K=64   |     1424.4  |   13002.0
      f16 B=64, M=1024, H=16, K=128  |     2610.0  |   13969.2
      b16 B=64, M=1024, H=16, K=128  |     2697.7  |   14687.4

Times are in microseconds (us).
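
To relate these timings to hardware throughput, the sketch below converts a forward time into an approximate TFLOP/s figure. It rests on assumptions that the PR does not state: queries and keys both have length M, only the two matrix multiplications are counted (4 * B * H * M * M * K floating-point operations, softmax ignored), and the causal case is counted as half of that. The function name is hypothetical.

```python
def attention_forward_tflops(B, M, H, K, time_us, causal=False):
    """Approximate forward throughput in TFLOP/s under the assumptions above."""
    # Q @ K^T and P @ V each cost 2 * M * M * K FLOPs per (batch, head) pair.
    flops = 4 * B * H * M * M * K
    if causal:
        # Assumed convention: a lower-triangular mask skips about half the work.
        flops /= 2
    seconds = time_us * 1e-6
    return flops / seconds / 1e12


# f16 B=1, M=4096, H=160, K=128, "optimized" column, from the two forward tables.
no_bias = attention_forward_tflops(1, 4096, 160, 128, time_us=9807.3)
lower_tri = attention_forward_tflops(1, 4096, 160, 128, time_us=5176.5, causal=True)
```

Under this counting convention, both values come out in the same range (roughly 130 to 140 TFLOP/s), which suggests that the shorter time with `LowerTriangularMask` for this shape mostly reflects skipped work, not a faster kernel. The GPU used is not stated in this section, so no comparison against peak throughput is attempted.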
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3702.8  |    820.9
      b16 B=384, M=197, H=1, K=88    |     3701.9  |    823.4
      f16 B=384, M=197, H=1, K=80    |      761.9  |    768.3
      b16 B=384, M=197, H=1, K=80    |      762.5  |    769.8
      f16 B=384, M=197, H=1, K=64    |      592.9  |    651.8
      b16 B=384, M=197, H=1, K=64    |      372.7  |    652.9
      f16 B=1024, M=197, H=1, K=88   |     9107.4  |   2103.4
      b16 B=1024, M=197, H=1, K=88   |     9115.3  |   2105.4
      f16 B=1024, M=197, H=1, K=80   |     1882.7  |   1958.4
      b16 B=1024, M=197, H=1, K=80   |     1891.6  |   1958.3
      f16 B=1024, M=197, H=1, K=64   |      916.6  |   1649.8
      b16 B=1024, M=197, H=1, K=64   |      924.4  |   1649.1
      f16 B=512, M=197, H=1, K=80    |      987.1  |    994.2
      b16 B=512, M=197, H=1, K=80    |      990.8  |    997.0
      f16 B=32, M=197, H=16, K=80    |     1037.3  |   1034.2
      b16 B=32, M=197, H=16, K=80    |     1040.1  |   1035.8
      f16 B=32, M=197, H=16, K=64    |      486.2  |    885.3
      b16 B=32, M=197, H=16, K=64    |      488.7  |    887.6
      f16 B=32, M=197, H=16, K=128   |     1165.0  |   1343.0
      b16 B=32, M=197, H=16, K=128   |     1168.3  |   1345.8
      f16 B=256, M=197, H=1, K=88    |     2360.9  |    576.3
      b16 B=256, M=197, H=1, K=88    |     2361.7  |    576.9
      f16 B=16, M=197, H=16, K=88    |     2377.4  |    599.7
      b16 B=16, M=197, H=16, K=88    |     2376.8  |    600.9
      f16 B=16, M=197, H=16, K=64    |      303.0  |    488.1
      b16 B=16, M=197, H=16, K=64    |      330.5  |    488.6
      f16 B=16, M=197, H=16, K=128   |      617.9  |    713.8
      b16 B=16, M=197, H=16, K=128   |      619.7  |    716.1
      f16 B=1, M=4096, H=160, K=128  |    41860.6  |  38927.7
      b16 B=1, M=4096, H=160, K=128  |    41971.4  |  39833.9
      f16 B=2, M=4096, H=160, K=128  |    83623.8  |  78708.1
      b16 B=2, M=4096, H=160, K=128  |    84009.4  |  80592.3
      f16 B=1, M=8192, H=160, K=128  |   160725.6  |
      b16 B=1, M=8192, H=160, K=128  |   161001.2  |
      f16 B=2, M=8192, H=160, K=128  |   321049.6  |
      b16 B=2, M=8192, H=160, K=128  |   321800.6  |
      f16 B=1024, M=82, H=8, K=64    |     2608.7  |   3609.7
      b16 B=1024, M=82, H=8, K=64    |     2525.2  |   3783.9
      f16 B=150, M=256, H=16, K=64   |     2552.5  |   3809.6
      b16 B=150, M=256, H=16, K=64   |     2464.5  |   3838.1
      f16 B=64, M=256, H=12, K=64    |      827.8  |   1257.5
      b16 B=64, M=256, H=12, K=64    |      835.1  |   1269.2
      f16 B=1, M=4096, H=16, K=40    |    43085.7  |   3556.9
      b16 B=1, M=4096, H=16, K=40    |    42898.9  |   3595.8
      f16 B=1, M=16384, H=16, K=40   |   664283.3  |  53870.2
      b16 B=1, M=16384, H=16, K=40   |   664702.6  |  54308.0
      f16 B=256, M=4096, H=16, K=64  |   503458.9  |
      b16 B=256, M=4096, H=16, K=64  |   503461.7  |
      f16 B=8, M=2048, H=20, K=128   |    11442.9  |  10536.5
      b16 B=8, M=2048, H=20, K=128   |    11362.9  |  10790.8
      f16 B=16, M=128, H=16, K=16    |      349.5  |    311.2
      b16 B=16, M=128, H=16, K=16    |      303.8  |    314.0
      f16 B=16, M=128, H=16, K=32    |      316.6  |    340.6
      b16 B=16, M=128, H=16, K=32    |      328.7  |    337.9
      f16 B=16, M=128, H=16, K=64    |      318.3  |    333.9
      b16 B=16, M=128, H=16, K=64    |      332.7  |    335.4
      f16 B=16, M=128, H=16, K=128   |      452.5  |    331.2
      b16 B=16, M=128, H=16, K=128   |      329.1  |    335.9
      f16 B=16, M=512, H=16, K=16    |      527.1  |    982.9
      b16 B=16, M=512, H=16, K=16    |      322.8  |   1078.9
      f16 B=16, M=512, H=16, K=32    |      680.9  |   1090.2
      b16 B=16, M=512, H=16, K=32    |      487.4  |   1179.3
      f16 B=16, M=512, H=16, K=64    |     1016.7  |   1276.3
      b16 B=16, M=512, H=16, K=64    |      845.2  |   1295.1
      f16 B=16, M=512, H=16, K=128   |     1891.7  |   1712.6
      b16 B=16, M=512, H=16, K=128   |     1766.5  |   1744.0
      f16 B=16, M=1024, H=16, K=16   |     1257.4  |   3532.6
      b16 B=16, M=1024, H=16, K=16   |     1066.4  |   3953.2
      f16 B=16, M=1024, H=16, K=32   |     1770.3  |   3731.3
      b16 B=16, M=1024, H=16, K=32   |     1615.7  |   4158.6
      f16 B=16, M=1024, H=16, K=64   |     2686.2  |   4281.1
      b16 B=16, M=1024, H=16, K=64   |     2545.2  |   4348.9
      f16 B=16, M=1024, H=16, K=128  |     5407.5  |   5109.9
      b16 B=16, M=1024, H=16, K=128  |     5312.2  |   5248.8
      f16 B=64, M=128, H=16, K=16    |      304.3  |    365.2
      b16 B=64, M=128, H=16, K=16    |      325.1  |    372.8
      f16 B=64, M=128, H=16, K=32    |      304.3  |    465.8
      b16 B=64, M=128, H=16, K=32    |      301.5  |    472.4
      f16 B=64, M=128, H=16, K=64    |      464.3  |    679.8
      b16 B=64, M=128, H=16, K=64    |      469.2  |    681.4
      f16 B=64, M=128, H=16, K=128   |      980.1  |   1076.0
      b16 B=64, M=128, H=16, K=128   |      987.5  |   1079.2
      f16 B=64, M=512, H=16, K=16    |     1109.7  |   3705.7
      b16 B=64, M=512, H=16, K=16    |     1111.9  |   4078.7
      f16 B=64, M=512, H=16, K=32    |     1678.8  |   4135.2
      b16 B=64, M=512, H=16, K=32    |     1684.5  |   4504.2
      f16 B=64, M=512, H=16, K=64    |     2982.7  |   4906.5
      b16 B=64, M=512, H=16, K=64    |     3001.9  |   4992.0
      f16 B=64, M=512, H=16, K=128   |     6641.2  |   6644.6
      b16 B=64, M=512, H=16, K=128   |     6697.3  |   6756.1
      f16 B=64, M=1024, H=16, K=16   |     3610.9  |  13882.0
      b16 B=64, M=1024, H=16, K=16   |     3614.8  |  15555.3
      f16 B=64, M=1024, H=16, K=32   |     6138.0  |  14767.7
      b16 B=64, M=1024, H=16, K=32   |     6147.9  |  16450.9
      f16 B=64, M=1024, H=16, K=64   |     9813.8  |  16866.6
      b16 B=64, M=1024, H=16, K=64   |     9848.6  |  17160.6
      f16 B=64, M=1024, H=16, K=128  |    20556.5  |  20371.7
      b16 B=64, M=1024, H=16, K=128  |    20674.3  |  20916.7

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     9480.6  |    822.5
      b16 B=384, M=197, H=1, K=88    |     9355.1  |    823.8
      f16 B=384, M=197, H=1, K=80    |      669.6  |    768.6
      b16 B=384, M=197, H=1, K=80    |      674.6  |    770.2
      f16 B=384, M=197, H=1, K=64    |      542.3  |    651.9
      b16 B=384, M=197, H=1, K=64    |      351.1  |    653.8
      f16 B=1024, M=197, H=1, K=88   |    23565.1  |   2105.1
      b16 B=1024, M=197, H=1, K=88   |    23556.3  |   2105.4
      f16 B=1024, M=197, H=1, K=80   |     1646.1  |   1961.7
      b16 B=1024, M=197, H=1, K=80   |     1651.9  |   1958.3
      f16 B=1024, M=197, H=1, K=64   |      861.8  |   1649.0
      b16 B=1024, M=197, H=1, K=64   |      869.0  |   1648.1
      f16 B=512, M=197, H=1, K=80    |      863.6  |    995.4
      b16 B=512, M=197, H=1, K=80    |      866.2  |    997.1
      f16 B=32, M=197, H=16, K=80    |      867.8  |   1034.4
      b16 B=32, M=197, H=16, K=80    |      869.5  |   1036.2
      f16 B=32, M=197, H=16, K=64    |      456.3  |    883.7
      b16 B=32, M=197, H=16, K=64    |      458.7  |    885.6
      f16 B=32, M=197, H=16, K=128   |     1042.0  |   1340.5
      b16 B=32, M=197, H=16, K=128   |     1045.1  |   1343.6
      f16 B=256, M=197, H=1, K=88    |     6342.1  |    577.0
      b16 B=256, M=197, H=1, K=88    |     6344.3  |    578.3
      f16 B=16, M=197, H=16, K=88    |     6335.6  |    600.3
      b16 B=16, M=197, H=16, K=88    |     6341.5  |    601.6
      f16 B=16, M=197, H=16, K=64    |      330.3  |    488.4
      b16 B=16, M=197, H=16, K=64    |      325.5  |    490.0
      f16 B=16, M=197, H=16, K=128   |      557.4  |    715.2
      b16 B=16, M=197, H=16, K=128   |      557.8  |    717.4
      f16 B=1, M=4096, H=160, K=128  |    26152.2  |  38914.3
      b16 B=1, M=4096, H=160, K=128  |    26166.4  |  39924.4
      f16 B=2, M=4096, H=160, K=128  |    51711.2  |  78726.9
      b16 B=2, M=4096, H=160, K=128  |    51925.3  |  80711.7
      f16 B=1, M=8192, H=160, K=128  |    92696.3  |
      b16 B=1, M=8192, H=160, K=128  |    92960.2  |
      f16 B=2, M=8192, H=160, K=128  |   184624.7  |
      b16 B=2, M=8192, H=160, K=128  |   185330.2  |
      f16 B=1024, M=82, H=8, K=64    |     2724.8  |   3608.8
      b16 B=1024, M=82, H=8, K=64    |     2642.1  |   3784.4
      f16 B=150, M=256, H=16, K=64   |     2379.5  |   3803.7
      b16 B=150, M=256, H=16, K=64   |     2287.7  |   3832.2
      f16 B=64, M=256, H=12, K=64    |      774.1  |   1255.6
      b16 B=64, M=256, H=12, K=64    |      781.0  |   1270.5
      f16 B=1, M=4096, H=16, K=40    |     6799.9  |   3563.7
      b16 B=1, M=4096, H=16, K=40    |     6624.0  |   3598.3
      f16 B=1, M=16384, H=16, K=40   |    94594.6  |  53902.5
      b16 B=1, M=16384, H=16, K=40   |    94692.4  |  54354.3
      f16 B=256, M=4096, H=16, K=64  |   274333.4  |
      b16 B=256, M=4096, H=16, K=64  |   274964.5  |
      f16 B=8, M=2048, H=20, K=128   |     7844.4  |  10545.5
      b16 B=8, M=2048, H=20, K=128   |     7764.3  |  10795.4
      f16 B=16, M=128, H=16, K=16    |      334.9  |    293.7
      b16 B=16, M=128, H=16, K=16    |      326.3  |    314.7
      f16 B=16, M=128, H=16, K=32    |      354.5  |    319.1
      b16 B=16, M=128, H=16, K=32    |      304.6  |    291.4
      f16 B=16, M=128, H=16, K=64    |      350.7  |    312.6
      b16 B=16, M=128, H=16, K=64    |      331.6  |    301.8
      f16 B=16, M=128, H=16, K=128   |      486.3  |    313.5
      b16 B=16, M=128, H=16, K=128   |      326.8  |    314.4
      f16 B=16, M=512, H=16, K=16    |      456.9  |    986.5
      b16 B=16, M=512, H=16, K=16    |      306.0  |   1080.0
      f16 B=16, M=512, H=16, K=32    |      566.6  |   1091.4
      b16 B=16, M=512, H=16, K=32    |      393.9  |   1179.2
      f16 B=16, M=512, H=16, K=64    |      841.5  |   1276.5
      b16 B=16, M=512, H=16, K=64    |      669.4  |   1294.9
      f16 B=16, M=512, H=16, K=128   |     1590.8  |   1713.6
      b16 B=16, M=512, H=16, K=128   |     1486.2  |   1740.0
      f16 B=16, M=1024, H=16, K=16   |      858.6  |   3529.7
      b16 B=16, M=1024, H=16, K=16   |      686.6  |   3951.6
      f16 B=16, M=1024, H=16, K=32   |     1233.9  |   3733.7
      b16 B=16, M=1024, H=16, K=32   |     1059.3  |   4154.4
      f16 B=16, M=1024, H=16, K=64   |     1875.2  |   4281.8
      b16 B=16, M=1024, H=16, K=64   |     1754.2  |   4347.6
      f16 B=16, M=1024, H=16, K=128  |     4062.9  |   5105.1
      b16 B=16, M=1024, H=16, K=128  |     3983.0  |   5247.9
      f16 B=64, M=128, H=16, K=16    |      304.4  |    368.6
      b16 B=64, M=128, H=16, K=16    |      301.9  |    376.4
      f16 B=64, M=128, H=16, K=32    |      303.9  |    466.2
      b16 B=64, M=128, H=16, K=32    |      327.3  |    471.8
      f16 B=64, M=128, H=16, K=64    |      475.8  |    682.3
      b16 B=64, M=128, H=16, K=64    |      479.3  |    679.5
      f16 B=64, M=128, H=16, K=128   |     1041.7  |   1076.2
      b16 B=64, M=128, H=16, K=128   |     1047.8  |   1077.4
      f16 B=64, M=512, H=16, K=16    |      912.1  |   3704.3
      b16 B=64, M=512, H=16, K=16    |      917.2  |   4082.8
      f16 B=64, M=512, H=16, K=32    |     1458.2  |   4133.9
      b16 B=64, M=512, H=16, K=32    |     1461.2  |   4501.9
      f16 B=64, M=512, H=16, K=64    |     2466.8  |   4903.1
      b16 B=64, M=512, H=16, K=64    |     2492.7  |   4993.7
      f16 B=64, M=512, H=16, K=128   |     5545.0  |   6642.9
      b16 B=64, M=512, H=16, K=128   |     5595.6  |   6754.2
      f16 B=64, M=1024, H=16, K=16   |     2608.6  |  13864.4
      b16 B=64, M=1024, H=16, K=16   |     2610.5  |  15560.3
      f16 B=64, M=1024, H=16, K=32   |     4025.8  |  14755.8
      b16 B=64, M=1024, H=16, K=32   |     4032.1  |  16463.7
      f16 B=64, M=1024, H=16, K=64   |     6658.9  |  16880.7
      b16 B=64, M=1024, H=16, K=64   |     6722.0  |  17150.6
      f16 B=64, M=1024, H=16, K=128  |    15345.5  |  20371.0
      b16 B=64, M=1024, H=16, K=128  |    15438.5  |  20920.1

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      542.1  |       241.0
      b16 B=384, M=197, H=1, K=64    |      344.0  |       236.7
      f16 B=1024, M=197, H=1, K=64   |      852.4  |       539.2
      b16 B=1024, M=197, H=1, K=64   |      857.6  |       544.6
      f16 B=32, M=197, H=16, K=64    |      449.0  |       280.6
      b16 B=32, M=197, H=16, K=64    |      451.6  |       283.2
      f16 B=32, M=197, H=16, K=128   |     1161.4  |       517.8
      b16 B=32, M=197, H=16, K=128   |     1030.0  |       518.8
      f16 B=16, M=197, H=16, K=64    |      312.1  |       214.0
      b16 B=16, M=197, H=16, K=64    |      309.8  |       213.7
      f16 B=16, M=197, H=16, K=128   |      547.4  |       296.6
      b16 B=16, M=197, H=16, K=128   |      548.1  |       298.1
      f16 B=1, M=4096, H=160, K=128  |    25831.9  |     31365.0
      b16 B=1, M=4096, H=160, K=128  |    25811.2  |     31382.5
      f16 B=2, M=4096, H=160, K=128  |    51084.6  |     48413.3
      b16 B=2, M=4096, H=160, K=128  |    51265.2  |     48371.8
      f16 B=1, M=8192, H=160, K=128  |    91356.3  |    121976.0
      b16 B=1, M=8192, H=160, K=128  |    91520.0  |    122016.8
      f16 B=2, M=8192, H=160, K=128  |   181984.4  |    187070.0
      b16 B=2, M=8192, H=160, K=128  |   182646.4  |    187302.9
      f16 B=1024, M=82, H=8, K=64    |     2703.5  |      1502.2
      b16 B=1024, M=82, H=8, K=64    |     2615.9  |      1511.7
      f16 B=150, M=256, H=16, K=64   |     2380.2  |      1505.7
      b16 B=150, M=256, H=16, K=64   |     2264.4  |      1521.6
      f16 B=64, M=256, H=12, K=64    |      768.4  |       526.4
      b16 B=64, M=256, H=12, K=64    |      774.2  |       530.1
      f16 B=8, M=2048, H=20, K=128   |     7798.2  |      8299.4
      b16 B=8, M=2048, H=20, K=128   |     7678.0  |      8323.0
      f16 B=16, M=128, H=16, K=16    |      322.9  |       212.5
      b16 B=16, M=128, H=16, K=16    |      311.7  |       192.5
      f16 B=16, M=128, H=16, K=32    |      359.6  |       212.4
      b16 B=16, M=128, H=16, K=32    |      333.6  |       194.5
      f16 B=16, M=128, H=16, K=64    |      330.3  |       196.0
      b16 B=16, M=128, H=16, K=64    |      314.8  |       196.9
      f16 B=16, M=128, H=16, K=128   |      489.5  |       215.7
      b16 B=16, M=128, H=16, K=128   |      343.9  |       215.5
      f16 B=16, M=512, H=16, K=16    |      463.5  |       261.1
      b16 B=16, M=512, H=16, K=16    |      318.1  |       263.3
      f16 B=16, M=512, H=16, K=32    |      573.5  |       337.3
      b16 B=16, M=512, H=16, K=32    |      391.6  |       339.6
      f16 B=16, M=512, H=16, K=64    |      829.6  |       517.0
      b16 B=16, M=512, H=16, K=64    |      666.4  |       521.2
      f16 B=16, M=512, H=16, K=128   |     1608.1  |      1034.4
      b16 B=16, M=512, H=16, K=128   |     1472.1  |      1032.4
      f16 B=16, M=1024, H=16, K=16   |      866.3  |       792.3
      b16 B=16, M=1024, H=16, K=16   |      684.4  |       793.2
      f16 B=16, M=1024, H=16, K=32   |     1269.0  |      1022.8
      b16 B=16, M=1024, H=16, K=32   |     1056.2  |      1023.4
      f16 B=16, M=1024, H=16, K=64   |     1885.2  |      1563.7
      b16 B=16, M=1024, H=16, K=64   |     1749.4  |      1564.6
      f16 B=16, M=1024, H=16, K=128  |     4052.5  |      3428.0
      b16 B=16, M=1024, H=16, K=128  |     3929.8  |      3432.8
      f16 B=64, M=128, H=16, K=16    |      313.7  |       218.0
      b16 B=64, M=128, H=16, K=16    |      336.9  |       217.6
      f16 B=64, M=128, H=16, K=32    |      317.4  |       212.5
      b16 B=64, M=128, H=16, K=32    |      318.1  |       208.0
      f16 B=64, M=128, H=16, K=64    |      470.3  |       284.8
      b16 B=64, M=128, H=16, K=64    |      474.6  |       286.9
      f16 B=64, M=128, H=16, K=128   |     1024.8  |       504.8
      b16 B=64, M=128, H=16, K=128   |     1030.2  |       507.4
      f16 B=64, M=512, H=16, K=16    |      909.5  |       924.7
      b16 B=64, M=512, H=16, K=16    |      913.6  |       930.5
      f16 B=64, M=512, H=16, K=32    |     1454.8  |      1176.4
      b16 B=64, M=512, H=16, K=32    |     1459.1  |      1181.8
      f16 B=64, M=512, H=16, K=64    |     2460.3  |      1752.1
      b16 B=64, M=512, H=16, K=64    |     2485.5  |      1773.0
      f16 B=64, M=512, H=16, K=128   |     5503.4  |      3564.6
      b16 B=64, M=512, H=16, K=128   |     5557.4  |      3592.5
      f16 B=64, M=1024, H=16, K=16   |     2599.1  |      2866.5
      b16 B=64, M=1024, H=16, K=16   |     2605.9  |      2868.9
      f16 B=64, M=1024, H=16, K=32   |     4017.3  |      3577.7
      b16 B=64, M=1024, H=16, K=32   |     4022.4  |      3592.0
      f16 B=64, M=1024, H=16, K=64   |     6648.5  |      5349.7
      b16 B=64, M=1024, H=16, K=64   |     6716.9  |      5374.7
      f16 B=64, M=1024, H=16, K=128  |    15206.5  |     11777.5
      b16 B=64, M=1024, H=16, K=128  |    15313.3  |     11814.0

Times are in microseconds (us).

Performance Compared to MemoryEfficientAttentionCutlassFwdFlashBwOp

[----------- attention (attn_bias=<class 'NoneType'>) ----------]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |       89.9  |        88.9
      b16 B=384, M=197, H=1, K=64    |       94.2  |        88.8
      f16 B=1024, M=197, H=1, K=64   |      213.5  |       215.1
      b16 B=1024, M=197, H=1, K=64   |      221.9  |       215.1
      f16 B=32, M=197, H=16, K=64    |      114.0  |       113.8
      b16 B=32, M=197, H=16, K=64    |      120.0  |       113.8
      f16 B=32, M=197, H=16, K=128   |      193.3  |       168.6
      b16 B=32, M=197, H=16, K=128   |      194.6  |       166.9
      f16 B=16, M=197, H=16, K=64    |       85.2  |        60.7
      b16 B=16, M=197, H=16, K=64    |       87.2  |        60.7
      f16 B=16, M=197, H=16, K=128   |      101.6  |        88.6
      b16 B=16, M=197, H=16, K=128   |      102.0  |        88.0
      f16 B=1, M=4096, H=160, K=128  |     9816.7  |     15435.3
      b16 B=1, M=4096, H=160, K=128  |    10046.1  |     15280.9
      f16 B=2, M=4096, H=160, K=128  |    19357.5  |     30760.0
      b16 B=2, M=4096, H=160, K=128  |    19869.8  |     30546.3
      f16 B=1, M=8192, H=160, K=128  |    37248.5  |     61367.7
      b16 B=1, M=8192, H=160, K=128  |    38471.5  |     61015.5
      f16 B=2, M=8192, H=160, K=128  |    73817.0  |    123123.0
      b16 B=2, M=8192, H=160, K=128  |    76761.6  |    121918.5
      f16 B=1024, M=82, H=8, K=64    |      459.6  |       455.0
      b16 B=1024, M=82, H=8, K=64    |      477.1  |       455.1
      f16 B=150, M=256, H=16, K=64   |      383.4  |       521.7
      b16 B=150, M=256, H=16, K=64   |      419.8  |       518.6
      f16 B=64, M=256, H=12, K=64    |      131.4  |       176.2
      b16 B=64, M=256, H=12, K=64    |      145.1  |       173.2
      f16 B=256, M=4096, H=16, K=64  |   118450.0  |    191728.6
      b16 B=256, M=4096, H=16, K=64  |   131960.6  |    190391.4
      f16 B=8, M=2048, H=20, K=128   |     2694.7  |      3380.7
      b16 B=8, M=2048, H=20, K=128   |     2738.6  |      3332.9
      f16 B=16, M=128, H=16, K=16    |       90.9  |        31.2
      b16 B=16, M=128, H=16, K=16    |       89.4  |        30.2
      f16 B=16, M=128, H=16, K=32    |       87.6  |        30.7
      b16 B=16, M=128, H=16, K=32    |       90.0  |        30.7
      f16 B=16, M=128, H=16, K=64    |       88.7  |        30.2
      b16 B=16, M=128, H=16, K=64    |       89.4  |        30.2
      f16 B=16, M=128, H=16, K=128   |       87.2  |        35.8
      b16 B=16, M=128, H=16, K=128   |       89.3  |        35.8
      f16 B=16, M=512, H=16, K=16    |       89.2  |       180.0
      b16 B=16, M=512, H=16, K=16    |       99.6  |       179.9
      f16 B=16, M=512, H=16, K=32    |      105.3  |       186.2
      b16 B=16, M=512, H=16, K=32    |      123.0  |       186.1
      f16 B=16, M=512, H=16, K=64    |      152.1  |       213.7
      b16 B=16, M=512, H=16, K=64    |      169.3  |       213.7
      f16 B=16, M=512, H=16, K=128   |      318.6  |       367.5
      b16 B=16, M=512, H=16, K=128   |      327.5  |       364.1
      f16 B=16, M=1024, H=16, K=16   |      291.8  |       686.6
      b16 B=16, M=1024, H=16, K=16   |      384.4  |       689.4
      f16 B=16, M=1024, H=16, K=32   |      369.7  |       693.3
      b16 B=16, M=1024, H=16, K=32   |      427.3  |       695.3
      f16 B=16, M=1024, H=16, K=64   |      518.2  |       802.4
      b16 B=16, M=1024, H=16, K=64   |      579.3  |       794.1
      f16 B=16, M=1024, H=16, K=128  |     1098.6  |      1383.8
      b16 B=16, M=1024, H=16, K=128  |     1135.7  |      1371.4
      f16 B=64, M=128, H=16, K=16    |       85.6  |        53.8
      b16 B=64, M=128, H=16, K=16    |       85.7  |        53.8
      f16 B=64, M=128, H=16, K=32    |       89.5  |        58.3
      b16 B=64, M=128, H=16, K=32    |       87.5  |        58.4
      f16 B=64, M=128, H=16, K=64    |       86.1  |        69.5
      b16 B=64, M=128, H=16, K=64    |       85.6  |        69.3
      f16 B=64, M=128, H=16, K=128   |      127.2  |       121.0
      b16 B=64, M=128, H=16, K=128   |      129.7  |       119.1
      f16 B=64, M=512, H=16, K=16    |      307.1  |       703.0
      b16 B=64, M=512, H=16, K=16    |      403.4  |       703.5
      f16 B=64, M=512, H=16, K=32    |      387.1  |       712.1
      b16 B=64, M=512, H=16, K=32    |      449.8  |       711.7
      f16 B=64, M=512, H=16, K=64    |      558.4  |       821.8
      b16 B=64, M=512, H=16, K=64    |      619.9  |       818.0
      f16 B=64, M=512, H=16, K=128   |     1200.5  |      1451.8
      b16 B=64, M=512, H=16, K=128   |     1235.0  |      1434.0
      f16 B=64, M=1024, H=16, K=16   |     1101.1  |      2692.5
      b16 B=64, M=1024, H=16, K=16   |     1472.9  |      2691.6
      f16 B=64, M=1024, H=16, K=32   |     1409.8  |      2720.0
      b16 B=64, M=1024, H=16, K=32   |     1669.5  |      2715.8
      f16 B=64, M=1024, H=16, K=64   |     2025.4  |      3136.1
      b16 B=64, M=1024, H=16, K=64   |     2263.5  |      3105.2
      f16 B=64, M=1024, H=16, K=128  |     4279.5  |      5500.2
      b16 B=64, M=1024, H=16, K=128  |     4413.0  |      5406.7

Times are in microseconds (us).
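
To read the tables above as relative speedups rather than raw timings, the rows can be parsed and the two columns divided. The following is a minimal sketch, not part of the PR: the names `Row`, `parse_row` and `speedup` are hypothetical helpers, and the regular expression assumes the row layout shown in these tables (dtype, `B=`, `M=`, `H=`, `K=`, then two `|`-separated timings, with the second cell possibly empty as in some later tables).

```python
import re
from typing import NamedTuple, Optional


class Row(NamedTuple):
    dtype: str                     # "f16" or "b16"
    B: int
    M: int
    H: int
    K: int
    optimized_us: float            # first timing column
    baseline_us: Optional[float]   # second timing column, None if the cell is empty


# Hypothetical pattern matching one benchmark row as printed above.
ROW_RE = re.compile(
    r"^\s*(f16|b16)\s+B=(\d+),\s*M=(\d+),\s*H=(\d+),\s*K=(\d+)\s*\|"
    r"\s*([\d.]+)\s*\|\s*([\d.]+)?\s*$"
)


def parse_row(line: str) -> Optional[Row]:
    """Parse one table row; return None for headers, rulers and footers."""
    m = ROW_RE.match(line)
    if m is None:
        return None
    dtype, b, seq, h, k, opt, base = m.groups()
    return Row(dtype, int(b), int(seq), int(h), int(k), float(opt),
               float(base) if base else None)


def speedup(row: Row) -> Optional[float]:
    """baseline / optimized; a value above 1 means the optimized op is faster."""
    if row.baseline_us is None:
        return None
    return row.baseline_us / row.optimized_us


# Two rows copied from the non-causal forward table above.
sample = """\
      f16 B=16, M=128, H=16, K=16    |       90.9  |        31.2
      f16 B=1, M=8192, H=160, K=128  |    37248.5  |     61367.7
"""
for line in sample.splitlines():
    row = parse_row(line)
    if row is not None:
        print(f"{row.dtype} B={row.B} M={row.M} H={row.H} K={row.K}: "
              f"{speedup(row):.2f}x")
```

On these two rows the ratio is below 1 for the short sequence (M=128) and above 1 for the long one (M=8192), which matches the general pattern in the table: the optimized op tends to lose at small sizes and win at large ones.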
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      89.1   |       66.3
      b16 B=384, M=197, H=1, K=64    |      87.5   |       66.3
      f16 B=1024, M=197, H=1, K=64   |     187.0   |      153.3
      b16 B=1024, M=197, H=1, K=64   |     196.9   |      153.2
      f16 B=32, M=197, H=16, K=64    |     103.6   |       84.7
      b16 B=32, M=197, H=16, K=64    |     109.5   |       84.8
      f16 B=32, M=197, H=16, K=128   |     157.6   |      126.1
      b16 B=32, M=197, H=16, K=128   |     159.3   |      124.3
      f16 B=16, M=197, H=16, K=64    |      88.5   |       46.3
      b16 B=16, M=197, H=16, K=64    |      86.2   |       46.3
      f16 B=16, M=197, H=16, K=128   |      86.1   |       68.3
      b16 B=16, M=197, H=16, K=128   |      90.7   |       68.3
      f16 B=1, M=4096, H=160, K=128  |    5167.6   |     7921.9
      b16 B=1, M=4096, H=160, K=128  |    5335.3   |     7870.1
      f16 B=2, M=4096, H=160, K=128  |   10226.7   |    15736.6
      b16 B=2, M=4096, H=160, K=128  |   10540.9   |    15639.6
      f16 B=1, M=8192, H=160, K=128  |   19262.0   |    31230.1
      b16 B=1, M=8192, H=160, K=128  |   19908.8   |    30905.6
      f16 B=2, M=8192, H=160, K=128  |   38306.6   |    62229.2
      b16 B=2, M=8192, H=160, K=128  |   39533.3   |    61578.2
      f16 B=1024, M=82, H=8, K=64    |     486.7   |      375.9
      b16 B=1024, M=82, H=8, K=64    |     514.7   |      376.0
      f16 B=150, M=256, H=16, K=64   |     333.8   |      363.1
      b16 B=150, M=256, H=16, K=64   |     349.4   |      359.8
      f16 B=64, M=256, H=12, K=64    |     118.3   |      124.9
      b16 B=64, M=256, H=12, K=64    |     124.5   |      123.9
      f16 B=256, M=4096, H=16, K=64  |   67841.8   |    99189.2
      b16 B=256, M=4096, H=16, K=64  |   73773.6   |    97542.4
      f16 B=8, M=2048, H=20, K=128   |    1487.1   |     1851.1
      b16 B=8, M=2048, H=20, K=128   |    1533.2   |     1818.1
      f16 B=16, M=128, H=16, K=16    |      87.7   |       30.3
      b16 B=16, M=128, H=16, K=16    |      85.4   |       30.6
      f16 B=16, M=128, H=16, K=32    |      86.8   |       30.7
      b16 B=16, M=128, H=16, K=32    |      89.3   |       30.6
      f16 B=16, M=128, H=16, K=64    |      85.6   |       30.8
      b16 B=16, M=128, H=16, K=64    |      88.8   |       30.5
      f16 B=16, M=128, H=16, K=128   |      86.6   |       34.3
      b16 B=16, M=128, H=16, K=128   |      87.6   |       34.4
      f16 B=16, M=512, H=16, K=16    |      86.6   |      119.7
      b16 B=16, M=512, H=16, K=16    |      86.1   |      119.7
      f16 B=16, M=512, H=16, K=32    |      86.2   |      124.0
      b16 B=16, M=512, H=16, K=32    |      98.0   |      124.0
      f16 B=16, M=512, H=16, K=64    |     120.6   |      142.7
      b16 B=16, M=512, H=16, K=64    |     129.1   |      142.8
      f16 B=16, M=512, H=16, K=128   |     225.3   |      241.0
      b16 B=16, M=512, H=16, K=128   |     231.9   |      238.4
      f16 B=16, M=1024, H=16, K=16   |     196.4   |      399.8
      b16 B=16, M=1024, H=16, K=16   |     252.5   |      399.4
      f16 B=16, M=1024, H=16, K=32   |     244.2   |      405.7
      b16 B=16, M=1024, H=16, K=32   |     291.7   |      405.3
      f16 B=16, M=1024, H=16, K=64   |     356.4   |      464.3
      b16 B=16, M=1024, H=16, K=64   |     385.0   |      463.6
      f16 B=16, M=1024, H=16, K=128  |     683.6   |      805.3
      b16 B=16, M=1024, H=16, K=128  |     704.4   |      794.6
      f16 B=64, M=128, H=16, K=16    |      85.7   |       47.1
      b16 B=64, M=128, H=16, K=16    |      86.2   |       47.1
      f16 B=64, M=128, H=16, K=32    |      89.4   |       50.2
      b16 B=64, M=128, H=16, K=32    |      88.1   |       50.2
      f16 B=64, M=128, H=16, K=64    |      85.8   |       61.1
      b16 B=64, M=128, H=16, K=64    |      85.6   |       61.1
      f16 B=64, M=128, H=16, K=128   |     131.1   |      109.6
      b16 B=64, M=128, H=16, K=128   |     133.4   |      108.7
      f16 B=64, M=512, H=16, K=16    |     227.3   |      428.1
      b16 B=64, M=512, H=16, K=16    |     289.1   |      428.1
      f16 B=64, M=512, H=16, K=32    |     285.0   |      434.4
      b16 B=64, M=512, H=16, K=32    |     336.7   |      434.3
      f16 B=64, M=512, H=16, K=64    |     415.5   |      507.7
      b16 B=64, M=512, H=16, K=64    |     444.6   |      502.7
      f16 B=64, M=512, H=16, K=128   |     829.4   |      919.7
      b16 B=64, M=512, H=16, K=128   |     854.1   |      910.8
      f16 B=64, M=1024, H=16, K=16   |     722.3   |     1497.8
      b16 B=64, M=1024, H=16, K=16   |     931.8   |     1495.7
      f16 B=64, M=1024, H=16, K=32   |     900.6   |     1512.9
      b16 B=64, M=1024, H=16, K=32   |    1075.4   |     1513.4
      f16 B=64, M=1024, H=16, K=64   |    1316.9   |     1764.7
      b16 B=64, M=1024, H=16, K=64   |    1424.1   |     1739.0
      f16 B=64, M=1024, H=16, K=128  |    2607.3   |     3139.7
      b16 B=64, M=1024, H=16, K=128  |    2692.0   |     3106.5

Times are in microseconds (us).
[------ attention backward (attn_bias=<class 'NoneType'>) ------]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      545.1  |       236.0
      b16 B=384, M=197, H=1, K=64    |      366.9  |       238.5
      f16 B=1024, M=197, H=1, K=64   |      908.5  |       532.5
      b16 B=1024, M=197, H=1, K=64   |      915.2  |       531.5
      f16 B=32, M=197, H=16, K=64    |      477.8  |       276.7
      b16 B=32, M=197, H=16, K=64    |      481.0  |       277.0
      f16 B=32, M=197, H=16, K=128   |     1307.9  |       634.5
      b16 B=32, M=197, H=16, K=128   |     1156.2  |       635.5
      f16 B=16, M=197, H=16, K=64    |      333.1  |       213.3
      b16 B=16, M=197, H=16, K=64    |      310.0  |       212.3
      f16 B=16, M=197, H=16, K=128   |      597.0  |       365.3
      b16 B=16, M=197, H=16, K=128   |      598.4  |       366.0
      f16 B=1, M=4096, H=160, K=128  |    41711.0  |     54006.5
      b16 B=1, M=4096, H=160, K=128  |    41674.5  |     53975.9
      f16 B=2, M=4096, H=160, K=128  |    82813.6  |     82572.8
      b16 B=2, M=4096, H=160, K=128  |    83121.0  |     82540.1
      f16 B=1, M=8192, H=160, K=128  |   158913.4  |    213065.0
      b16 B=1, M=8192, H=160, K=128  |   159351.2  |    213233.8
      f16 B=2, M=8192, H=160, K=128  |   318080.6  |    324743.8
      b16 B=2, M=8192, H=160, K=128  |   318049.8  |    324995.7
      f16 B=1024, M=82, H=8, K=64    |     2628.2  |      1480.0
      b16 B=1024, M=82, H=8, K=64    |     2514.7  |      1502.1
      f16 B=150, M=256, H=16, K=64   |     2529.2  |      1485.6
      b16 B=150, M=256, H=16, K=64   |     2434.0  |      1485.3
      f16 B=64, M=256, H=12, K=64    |      824.4  |       520.0
      b16 B=64, M=256, H=12, K=64    |      831.4  |       517.1
      f16 B=8, M=2048, H=20, K=128   |    11325.1  |     13894.0
      b16 B=8, M=2048, H=20, K=128   |    11260.2  |     13895.4
      f16 B=16, M=128, H=16, K=16    |      323.3  |       193.0
      b16 B=16, M=128, H=16, K=16    |      339.6  |       213.9
      f16 B=16, M=128, H=16, K=32    |      368.4  |       215.0
      b16 B=16, M=128, H=16, K=32    |      333.9  |       210.6
      f16 B=16, M=128, H=16, K=64    |      328.8  |       196.2
      b16 B=16, M=128, H=16, K=64    |      314.8  |       197.2
      f16 B=16, M=128, H=16, K=128   |      496.4  |       222.7
      b16 B=16, M=128, H=16, K=128   |      334.9  |       215.7
      f16 B=16, M=512, H=16, K=16    |      536.8  |       320.3
      b16 B=16, M=512, H=16, K=16    |      340.9  |       323.6
      f16 B=16, M=512, H=16, K=32    |      660.7  |       422.7
      b16 B=16, M=512, H=16, K=32    |      479.8  |       424.9
      f16 B=16, M=512, H=16, K=64    |      985.6  |       670.5
      b16 B=16, M=512, H=16, K=64    |      824.3  |       672.4
      f16 B=16, M=512, H=16, K=128   |     1864.8  |      1504.7
      b16 B=16, M=512, H=16, K=128   |     1751.5  |      1508.3
      f16 B=16, M=1024, H=16, K=16   |     1255.1  |      1238.7
      b16 B=16, M=1024, H=16, K=16   |     1050.9  |      1241.7
      f16 B=16, M=1024, H=16, K=32   |     1798.1  |      1588.3
      b16 B=16, M=1024, H=16, K=32   |     1610.2  |      1596.3
      f16 B=16, M=1024, H=16, K=64   |     2691.8  |      2304.6
      b16 B=16, M=1024, H=16, K=64   |     2548.6  |      2310.8
      f16 B=16, M=1024, H=16, K=128  |     5376.4  |      5464.7
      b16 B=16, M=1024, H=16, K=128  |     5268.8  |      5473.3
      f16 B=64, M=128, H=16, K=16    |      328.3  |       194.6
      b16 B=64, M=128, H=16, K=16    |      320.7  |       194.9
      f16 B=64, M=128, H=16, K=32    |      317.8  |       221.3
      b16 B=64, M=128, H=16, K=32    |      315.1  |       199.9
      f16 B=64, M=128, H=16, K=64    |      461.0  |       281.6
      b16 B=64, M=128, H=16, K=64    |      465.7  |       284.7
      f16 B=64, M=128, H=16, K=128   |      967.1  |       491.4
      b16 B=64, M=128, H=16, K=128   |      974.3  |       496.1
      f16 B=64, M=512, H=16, K=16    |     1095.1  |      1180.2
      b16 B=64, M=512, H=16, K=16    |     1099.8  |      1188.7
      f16 B=64, M=512, H=16, K=32    |     1657.9  |      1487.9
      b16 B=64, M=512, H=16, K=32    |     1665.4  |      1497.0
      f16 B=64, M=512, H=16, K=64    |     2927.7  |      2291.3
      b16 B=64, M=512, H=16, K=64    |     2949.2  |      2302.8
      f16 B=64, M=512, H=16, K=128   |     6573.7  |      5139.4
      b16 B=64, M=512, H=16, K=128   |     6632.5  |      5169.0
      f16 B=64, M=1024, H=16, K=16   |     3570.0  |      4673.2
      b16 B=64, M=1024, H=16, K=16   |     3582.5  |      4671.1
      f16 B=64, M=1024, H=16, K=32   |     6134.0  |      5590.1
      b16 B=64, M=1024, H=16, K=32   |     6145.1  |      5607.8
      f16 B=64, M=1024, H=16, K=64   |     9832.3  |      7863.1
      b16 B=64, M=1024, H=16, K=64   |     9859.6  |      7890.6
      f16 B=64, M=1024, H=16, K=128  |    20411.4  |     18492.4
      b16 B=64, M=1024, H=16, K=128  |    20521.8  |     18545.9

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  fctls_bflsh
1 threads: ------------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      542.1  |       241.0
      b16 B=384, M=197, H=1, K=64    |      344.0  |       236.7
      f16 B=1024, M=197, H=1, K=64   |      852.4  |       539.2
      b16 B=1024, M=197, H=1, K=64   |      857.6  |       544.6
      f16 B=32, M=197, H=16, K=64    |      449.0  |       280.6
      b16 B=32, M=197, H=16, K=64    |      451.6  |       283.2
      f16 B=32, M=197, H=16, K=128   |     1161.4  |       517.8
      b16 B=32, M=197, H=16, K=128   |     1030.0  |       518.8
      f16 B=16, M=197, H=16, K=64    |      312.1  |       214.0
      b16 B=16, M=197, H=16, K=64    |      309.8  |       213.7
      f16 B=16, M=197, H=16, K=128   |      547.4  |       296.6
      b16 B=16, M=197, H=16, K=128   |      548.1  |       298.1
      f16 B=1, M=4096, H=160, K=128  |    25831.9  |     31365.0
      b16 B=1, M=4096, H=160, K=128  |    25811.2  |     31382.5
      f16 B=2, M=4096, H=160, K=128  |    51084.6  |     48413.3
      b16 B=2, M=4096, H=160, K=128  |    51265.2  |     48371.8
      f16 B=1, M=8192, H=160, K=128  |    91356.3  |    121976.0
      b16 B=1, M=8192, H=160, K=128  |    91520.0  |    122016.8
      f16 B=2, M=8192, H=160, K=128  |   181984.4  |    187070.0
      b16 B=2, M=8192, H=160, K=128  |   182646.4  |    187302.9
      f16 B=1024, M=82, H=8, K=64    |     2703.5  |      1502.2
      b16 B=1024, M=82, H=8, K=64    |     2615.9  |      1511.7
      f16 B=150, M=256, H=16, K=64   |     2380.2  |      1505.7
      b16 B=150, M=256, H=16, K=64   |     2264.4  |      1521.6
      f16 B=64, M=256, H=12, K=64    |      768.4  |       526.4
      b16 B=64, M=256, H=12, K=64    |      774.2  |       530.1
      f16 B=8, M=2048, H=20, K=128   |     7798.2  |      8299.4
      b16 B=8, M=2048, H=20, K=128   |     7678.0  |      8323.0
      f16 B=16, M=128, H=16, K=16    |      322.9  |       212.5
      b16 B=16, M=128, H=16, K=16    |      311.7  |       192.5
      f16 B=16, M=128, H=16, K=32    |      359.6  |       212.4
      b16 B=16, M=128, H=16, K=32    |      333.6  |       194.5
      f16 B=16, M=128, H=16, K=64    |      330.3  |       196.0
      b16 B=16, M=128, H=16, K=64    |      314.8  |       196.9
      f16 B=16, M=128, H=16, K=128   |      489.5  |       215.7
      b16 B=16, M=128, H=16, K=128   |      343.9  |       215.5
      f16 B=16, M=512, H=16, K=16    |      463.5  |       261.1
      b16 B=16, M=512, H=16, K=16    |      318.1  |       263.3
      f16 B=16, M=512, H=16, K=32    |      573.5  |       337.3
      b16 B=16, M=512, H=16, K=32    |      391.6  |       339.6
      f16 B=16, M=512, H=16, K=64    |      829.6  |       517.0
      b16 B=16, M=512, H=16, K=64    |      666.4  |       521.2
      f16 B=16, M=512, H=16, K=128   |     1608.1  |      1034.4
      b16 B=16, M=512, H=16, K=128   |     1472.1  |      1032.4
      f16 B=16, M=1024, H=16, K=16   |      866.3  |       792.3
      b16 B=16, M=1024, H=16, K=16   |      684.4  |       793.2
      f16 B=16, M=1024, H=16, K=32   |     1269.0  |      1022.8
      b16 B=16, M=1024, H=16, K=32   |     1056.2  |      1023.4
      f16 B=16, M=1024, H=16, K=64   |     1885.2  |      1563.7
      b16 B=16, M=1024, H=16, K=64   |     1749.4  |      1564.6
      f16 B=16, M=1024, H=16, K=128  |     4052.5  |      3428.0
      b16 B=16, M=1024, H=16, K=128  |     3929.8  |      3432.8
      f16 B=64, M=128, H=16, K=16    |      313.7  |       218.0
      b16 B=64, M=128, H=16, K=16    |      336.9  |       217.6
      f16 B=64, M=128, H=16, K=32    |      317.4  |       212.5
      b16 B=64, M=128, H=16, K=32    |      318.1  |       208.0
      f16 B=64, M=128, H=16, K=64    |      470.3  |       284.8
      b16 B=64, M=128, H=16, K=64    |      474.6  |       286.9
      f16 B=64, M=128, H=16, K=128   |     1024.8  |       504.8
      b16 B=64, M=128, H=16, K=128   |     1030.2  |       507.4
      f16 B=64, M=512, H=16, K=16    |      909.5  |       924.7
      b16 B=64, M=512, H=16, K=16    |      913.6  |       930.5
      f16 B=64, M=512, H=16, K=32    |     1454.8  |      1176.4
      b16 B=64, M=512, H=16, K=32    |     1459.1  |      1181.8
      f16 B=64, M=512, H=16, K=64    |     2460.3  |      1752.1
      b16 B=64, M=512, H=16, K=64    |     2485.5  |      1773.0
      f16 B=64, M=512, H=16, K=128   |     5503.4  |      3564.6
      b16 B=64, M=512, H=16, K=128   |     5557.4  |      3592.5
      f16 B=64, M=1024, H=16, K=16   |     2599.1  |      2866.5
      b16 B=64, M=1024, H=16, K=16   |     2605.9  |      2868.9
      f16 B=64, M=1024, H=16, K=32   |     4017.3  |      3577.7
      b16 B=64, M=1024, H=16, K=32   |     4022.4  |      3592.0
      f16 B=64, M=1024, H=16, K=64   |     6648.5  |      5349.7
      b16 B=64, M=1024, H=16, K=64   |     6716.9  |      5374.7
      f16 B=64, M=1024, H=16, K=128  |    15206.5  |     11777.5
      b16 B=64, M=1024, H=16, K=128  |    15313.3  |     11814.0

Times are in microseconds (us).

TODO:

  • [x] get bf16 working
  • [x] get non-causal working
  • [x] benchmarking
  • [x] specify when op can be used
  • [ ] flash bwd, triton fwd
  • [ ] packed
  • [ ] different block sizes for N and M
  • [ ] add Triton Autotune

Before submitting

  • [ ] Did you have fun?
    • Make sure you had fun coding 🙃
  • [ ] Did you read the contributor guideline?
  • [ ] Was this discussed/approved via a GitHub issue? (not needed for typo fixes or doc improvements)
    • [ ] N/A
  • [ ] Did you make sure to update the docs?
    • [ ] N/A
  • [ ] Did you write any new necessary tests?
    • [ ] N/A
  • [ ] Did you update the changelog? (if needed)
    • [ ] N/A

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

dianaml0 (Oct 11 '22 00:10)

Hey @dianaml0, this is great, but you would need a modern Triton, right? I've had a branch up for that for a while; some of the API changed and we need to adapt a lot of the layers. Raising it just in case you bump into that.

blefaudeux (Oct 13 '22 15:10)

Thanks @blefaudeux, that's a good point. I have the updated Triton locally and am still facing some errors, so those changes are still needed. Do you think you'll push the changes you have, or should I look into it?

dianaml0 (Oct 13 '22 20:10)

https://github.com/facebookresearch/xformers/pull/483 should help; it should accept any modern Triton pip package!

blefaudeux (Oct 15 '22 20:10)
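
Since the thread above turns on which Triton pip package is installed, a quick way to see what an environment actually has is to query the package metadata. This is a minimal sketch using only the standard library; `installed_version` is a hypothetical helper, and the thread does not say which version counts as "modern", so the snippet only reports what it finds rather than enforcing a minimum.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of a pip package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("triton")
if version is None:
    print("triton is not installed; Triton-based ops would be unavailable")
else:
    print(f"triton {version} is installed")
```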

Updated Numbers

Performance Compared to Vanilla FWD
[--------- attention (attn_bias=<class 'NoneType'>) --------]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     1258.8  |    372.9
      b16 B=384, M=197, H=1, K=88    |     1270.5  |    375.4
      f16 B=384, M=197, H=1, K=80    |      146.8  |    344.5
      b16 B=384, M=197, H=1, K=80    |      149.5  |    346.8
      f16 B=384, M=197, H=1, K=64    |       90.1  |    293.4
      b16 B=384, M=197, H=1, K=64    |       94.3  |    295.4
      f16 B=1024, M=197, H=1, K=88   |     3241.4  |    938.6
      b16 B=1024, M=197, H=1, K=88   |     3263.3  |    945.0
      f16 B=1024, M=197, H=1, K=80   |      349.2  |    865.1
      b16 B=1024, M=197, H=1, K=80   |      354.9  |    871.3
      f16 B=1024, M=197, H=1, K=64   |      213.8  |    729.5
      b16 B=1024, M=197, H=1, K=64   |      222.1  |    735.7
      f16 B=512, M=197, H=1, K=80    |      185.6  |    448.6
      b16 B=512, M=197, H=1, K=80    |      188.8  |    451.6
      f16 B=32, M=197, H=16, K=80    |      193.8  |    548.3
      b16 B=32, M=197, H=16, K=80    |      193.9  |    551.0
      f16 B=32, M=197, H=16, K=64    |      114.4  |    466.7
      b16 B=32, M=197, H=16, K=64    |      120.5  |    469.2
      f16 B=32, M=197, H=16, K=128   |      193.9  |    717.6
      b16 B=32, M=197, H=16, K=128   |      195.0  |    720.7
      f16 B=256, M=197, H=1, K=88    |      868.3  |    260.9
      b16 B=256, M=197, H=1, K=88    |      870.0  |    262.3
      f16 B=16, M=197, H=16, K=88    |      869.9  |    317.9
      b16 B=16, M=197, H=16, K=88    |      870.8  |    319.3
      f16 B=16, M=197, H=16, K=64    |       89.0  |    257.0
      b16 B=16, M=197, H=16, K=64    |       88.2  |    258.2
      f16 B=16, M=197, H=16, K=128   |      101.6  |    385.5
      b16 B=16, M=197, H=16, K=128   |      102.2  |    387.0
      f16 B=1, M=4096, H=160, K=128  |     9824.7  |  20462.2
      b16 B=1, M=4096, H=160, K=128  |    10063.9  |  21439.0
      f16 B=2, M=4096, H=160, K=128  |    19368.3  |  42648.8
      b16 B=2, M=4096, H=160, K=128  |    19877.8  |  44515.5
      f16 B=1, M=8192, H=160, K=128  |    37335.5  |  88470.3
      b16 B=1, M=8192, H=160, K=128  |    38634.4  |  87261.8
      f16 B=2, M=8192, H=160, K=128  |    74047.2  |
      b16 B=2, M=8192, H=160, K=128  |    76836.0  |
      f16 B=1024, M=82, H=8, K=64    |      460.9  |   1769.1
      b16 B=1024, M=82, H=8, K=64    |      480.2  |   1863.2
      f16 B=150, M=256, H=16, K=64   |      383.1  |   1725.0
      b16 B=150, M=256, H=16, K=64   |      421.2  |   1760.5
      f16 B=64, M=256, H=12, K=64    |      131.6  |    587.0
      b16 B=64, M=256, H=12, K=64    |      145.8  |    597.5
      f16 B=1, M=4096, H=16, K=40    |    30691.3  |   1943.9
      b16 B=1, M=4096, H=16, K=40    |    30831.1  |   1998.1
      f16 B=1, M=16384, H=16, K=40   |   422956.0  |  28855.4
      b16 B=1, M=16384, H=16, K=40   |   420924.0  |  30204.7
      f16 B=256, M=4096, H=16, K=64  |   118507.9  |
      b16 B=256, M=4096, H=16, K=64  |   132804.6  |
      f16 B=8, M=2048, H=20, K=128   |     2697.1  |   5700.7
      b16 B=8, M=2048, H=20, K=128   |     2744.9  |   6041.7
      f16 B=16, M=128, H=16, K=16    |       87.3  |    139.1
      b16 B=16, M=128, H=16, K=16    |       87.1  |    137.9
      f16 B=16, M=128, H=16, K=32    |       89.6  |    139.4
      b16 B=16, M=128, H=16, K=32    |       89.2  |    138.3
      f16 B=16, M=128, H=16, K=64    |       86.7  |    137.1
      b16 B=16, M=128, H=16, K=64    |       87.0  |    136.9
      f16 B=16, M=128, H=16, K=128   |       88.7  |    139.6
      b16 B=16, M=128, H=16, K=128   |       86.6  |    140.6
      f16 B=16, M=512, H=16, K=16    |       87.6  |    461.5
      b16 B=16, M=512, H=16, K=16    |       99.5  |    553.9
      f16 B=16, M=512, H=16, K=32    |      105.4  |    512.3
      b16 B=16, M=512, H=16, K=32    |      123.2  |    594.6
      f16 B=16, M=512, H=16, K=64    |      152.1  |    596.1
      b16 B=16, M=512, H=16, K=64    |      169.4  |    616.0
      f16 B=16, M=512, H=16, K=128   |      318.6  |    786.9
      b16 B=16, M=512, H=16, K=128   |      327.6  |    804.4
      f16 B=16, M=1024, H=16, K=16   |      292.0  |   1642.8
      b16 B=16, M=1024, H=16, K=16   |      384.8  |   2028.9
      f16 B=16, M=1024, H=16, K=32   |      369.4  |   1731.9
      b16 B=16, M=1024, H=16, K=32   |      428.1  |   2126.7
      f16 B=16, M=1024, H=16, K=64   |      518.5  |   2034.7
      b16 B=16, M=1024, H=16, K=64   |      579.7  |   2077.2
      f16 B=16, M=1024, H=16, K=128  |     1097.8  |   2421.0
      b16 B=16, M=1024, H=16, K=128  |     1134.9  |   2489.6
      f16 B=64, M=128, H=16, K=16    |       89.3  |    183.7
      b16 B=64, M=128, H=16, K=16    |       87.1  |    185.1
      f16 B=64, M=128, H=16, K=32    |       87.5  |    224.8
      b16 B=64, M=128, H=16, K=32    |       88.1  |    225.5
      f16 B=64, M=128, H=16, K=64    |       86.7  |    321.5
      b16 B=64, M=128, H=16, K=64    |       88.1  |    323.1
      f16 B=64, M=128, H=16, K=128   |      127.1  |    484.8
      b16 B=64, M=128, H=16, K=128   |      129.5  |    486.1
      f16 B=64, M=512, H=16, K=16    |      307.5  |   1727.2
      b16 B=64, M=512, H=16, K=16    |      403.6  |   2094.0
      f16 B=64, M=512, H=16, K=32    |      387.3  |   1893.5
      b16 B=64, M=512, H=16, K=32    |      450.5  |   2250.3
      f16 B=64, M=512, H=16, K=64    |      558.2  |   2234.4
      b16 B=64, M=512, H=16, K=64    |      620.5  |   2294.3
      f16 B=64, M=512, H=16, K=128   |     1201.0  |   3016.3
      b16 B=64, M=512, H=16, K=128   |     1235.3  |   3069.4
      f16 B=64, M=1024, H=16, K=16   |     1101.7  |   6428.2
      b16 B=64, M=1024, H=16, K=16   |     1473.2  |   8038.7
      f16 B=64, M=1024, H=16, K=32   |     1411.1  |   6770.6
      b16 B=64, M=1024, H=16, K=32   |     1667.5  |   8383.2
      f16 B=64, M=1024, H=16, K=64   |     2033.4  |   7978.2
      b16 B=64, M=1024, H=16, K=64   |     2265.9  |   8160.9
      f16 B=64, M=1024, H=16, K=128  |     4281.6  |   9521.1
      b16 B=64, M=1024, H=16, K=128  |     4413.3  |   9886.2

Times are in microseconds (us).
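
One pattern in the table above is that the optimized op is much slower than vanilla for a few head dimensions. The sketch below groups a handful of f16 rows copied from that table by K and lists the K values where vanilla wins in every sampled row. The PR does not state why these head dimensions are slow, so this only surfaces the pattern in the numbers; the row selection is an illustrative subset, not the full table.

```python
from collections import defaultdict

# (K, optimized_us, vanilla_us): f16 rows copied from the table above.
rows = [
    (88, 1258.8, 372.9),      # B=384,  M=197,   H=1
    (80, 146.8, 344.5),       # B=384,  M=197,   H=1
    (64, 90.1, 293.4),        # B=384,  M=197,   H=1
    (88, 3241.4, 938.6),      # B=1024, M=197,   H=1
    (40, 30691.3, 1943.9),    # B=1,    M=4096,  H=16
    (40, 422956.0, 28855.4),  # B=1,    M=16384, H=16
    (128, 9824.7, 20462.2),   # B=1,    M=4096,  H=160
]

# Speedup of optimized over vanilla, grouped by head dimension K.
by_k = defaultdict(list)
for k, optimized_us, vanilla_us in rows:
    by_k[k].append(vanilla_us / optimized_us)

# Head dimensions where the optimized op is slower in every sampled row.
slower = sorted(k for k, ratios in by_k.items() if max(ratios) < 1.0)
print(slower)  # → [40, 88]
```

A check like this could inform the "specify when op can be used" item in the TODO list, for example by restricting dispatch to head dimensions where the op is competitive.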
[ attention (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: ---------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3697.6  |     446.9
      b16 B=384, M=197, H=1, K=88    |     1178.0  |     452.8
      f16 B=384, M=197, H=1, K=80    |      120.3  |     421.8
      b16 B=384, M=197, H=1, K=80    |      122.6  |     427.6
      f16 B=384, M=197, H=1, K=64    |       87.6  |     370.9
      b16 B=384, M=197, H=1, K=64    |       90.4  |     376.4
      f16 B=1024, M=197, H=1, K=88   |     9588.0  |    1124.2
      b16 B=1024, M=197, H=1, K=88   |     2983.1  |    1140.5
      f16 B=1024, M=197, H=1, K=80   |      294.1  |    1059.7
      b16 B=1024, M=197, H=1, K=80   |      300.2  |    1075.1
      f16 B=1024, M=197, H=1, K=64   |      187.5  |     925.9
      b16 B=1024, M=197, H=1, K=64   |      197.2  |     940.5
      f16 B=512, M=197, H=1, K=80    |      153.7  |     549.6
      b16 B=512, M=197, H=1, K=80    |      156.5  |     557.4
      f16 B=32, M=197, H=16, K=80    |      157.5  |     647.7
      b16 B=32, M=197, H=16, K=80    |      160.8  |     655.5
      f16 B=32, M=197, H=16, K=64    |      103.8  |     565.7
      b16 B=32, M=197, H=16, K=64    |      109.3  |     573.3
      f16 B=32, M=197, H=16, K=128   |      157.6  |     811.7
      b16 B=32, M=197, H=16, K=128   |      159.5  |     819.3
      f16 B=256, M=197, H=1, K=88    |     2552.6  |     311.0
      b16 B=256, M=197, H=1, K=88    |      812.5  |     315.1
      f16 B=16, M=197, H=16, K=88    |     2555.6  |     368.1
      b16 B=16, M=197, H=16, K=88    |      805.8  |     372.4
      f16 B=16, M=197, H=16, K=64    |       87.2  |     307.2
      b16 B=16, M=197, H=16, K=64    |       88.9  |     311.0
      f16 B=16, M=197, H=16, K=128   |       87.7  |     435.1
      b16 B=16, M=197, H=16, K=128   |       87.1  |     438.6
      f16 B=1, M=4096, H=160, K=128  |     5174.9  |   37307.0
      b16 B=1, M=4096, H=160, K=128  |     5338.7  |   37824.6
      f16 B=2, M=4096, H=160, K=128  |    10220.4  |   76294.3
      b16 B=2, M=4096, H=160, K=128  |    10545.0  |   77265.7
      f16 B=1, M=8192, H=160, K=128  |    19334.4  |  152522.4
      b16 B=1, M=8192, H=160, K=128  |    19974.0  |  148646.9
      f16 B=2, M=8192, H=160, K=128  |    38377.2  |
      b16 B=2, M=8192, H=160, K=128  |    39641.7  |
      f16 B=1024, M=82, H=8, K=64    |      488.3  |    1981.4
      b16 B=1024, M=82, H=8, K=64    |      515.8  |    2083.8
      f16 B=150, M=256, H=16, K=64   |      335.1  |    2402.9
      b16 B=150, M=256, H=16, K=64   |      350.9  |    2445.6
      f16 B=64, M=256, H=12, K=64    |      118.9  |     805.2
      b16 B=64, M=256, H=12, K=64    |      124.6  |     819.3
      f16 B=1, M=4096, H=16, K=40    |    14939.0  |    3432.0
      b16 B=1, M=4096, H=16, K=40    |    14915.2  |    3468.0
      f16 B=1, M=16384, H=16, K=40   |   222056.7  |   55382.5
      b16 B=1, M=16384, H=16, K=40   |   218015.9  |   55427.4
      f16 B=256, M=4096, H=16, K=64  |    67733.2  |
      b16 B=256, M=4096, H=16, K=64  |    74031.3  |
      f16 B=8, M=2048, H=20, K=128   |     1488.1  |    9349.8
      b16 B=8, M=2048, H=20, K=128   |     1534.1  |    9538.5
      f16 B=16, M=128, H=16, K=16    |       86.9  |     147.4
      b16 B=16, M=128, H=16, K=16    |       87.9  |     143.6
      f16 B=16, M=128, H=16, K=32    |       86.6  |     144.3
      b16 B=16, M=128, H=16, K=32    |       89.8  |     143.9
      f16 B=16, M=128, H=16, K=64    |       86.4  |     143.1
      b16 B=16, M=128, H=16, K=64    |       86.5  |     141.4
      f16 B=16, M=128, H=16, K=128   |       89.7  |     160.0
      b16 B=16, M=128, H=16, K=128   |       88.3  |     161.9
      f16 B=16, M=512, H=16, K=16    |       90.7  |     715.3
      b16 B=16, M=512, H=16, K=16    |       87.6  |     792.0
      f16 B=16, M=512, H=16, K=32    |       88.5  |     766.8
      b16 B=16, M=512, H=16, K=32    |       98.1  |     832.3
      f16 B=16, M=512, H=16, K=64    |      120.3  |     881.4
      b16 B=16, M=512, H=16, K=64    |      129.1  |     910.7
      f16 B=16, M=512, H=16, K=128   |      225.4  |    1066.2
      b16 B=16, M=512, H=16, K=128   |      232.0  |    1095.8
      f16 B=16, M=1024, H=16, K=16   |      196.6  |    2658.4
      b16 B=16, M=1024, H=16, K=16   |      253.0  |    3042.3
      f16 B=16, M=1024, H=16, K=32   |      244.2  |    2749.6
      b16 B=16, M=1024, H=16, K=32   |      291.8  |    3123.0
      f16 B=16, M=1024, H=16, K=64   |      355.3  |    2962.7
      b16 B=16, M=1024, H=16, K=64   |      384.7  |    3287.6
      f16 B=16, M=1024, H=16, K=128  |      683.0  |    3533.6
      b16 B=16, M=1024, H=16, K=128  |      705.0  |    3714.0
      f16 B=64, M=128, H=16, K=16    |       87.6  |     246.0
      b16 B=64, M=128, H=16, K=16    |       87.2  |     251.8
      f16 B=64, M=128, H=16, K=32    |       87.1  |     294.3
      b16 B=64, M=128, H=16, K=32    |       87.5  |     298.8
      f16 B=64, M=128, H=16, K=64    |       87.8  |     391.4
      b16 B=64, M=128, H=16, K=64    |       86.8  |     396.0
      f16 B=64, M=128, H=16, K=128   |      130.9  |     557.1
      b16 B=64, M=128, H=16, K=128   |      133.1  |     562.3
      f16 B=64, M=512, H=16, K=16    |      227.5  |    2704.1
      b16 B=64, M=512, H=16, K=16    |      288.9  |    3019.5
      f16 B=64, M=512, H=16, K=32    |      285.1  |    2892.2
      b16 B=64, M=512, H=16, K=32    |      336.9  |    3163.7
      f16 B=64, M=512, H=16, K=64    |      414.2  |    3352.4
      b16 B=64, M=512, H=16, K=64    |      444.3  |    3470.6
      f16 B=64, M=512, H=16, K=128   |      828.5  |    4093.5
      b16 B=64, M=512, H=16, K=128   |      854.7  |    4208.4
      f16 B=64, M=1024, H=16, K=16   |      721.9  |   10475.9
      b16 B=64, M=1024, H=16, K=16   |      932.5  |   12026.3
      f16 B=64, M=1024, H=16, K=32   |      900.6  |   10855.0
      b16 B=64, M=1024, H=16, K=32   |     1075.5  |   12354.0
      f16 B=64, M=1024, H=16, K=64   |     1314.2  |   11673.0
      b16 B=64, M=1024, H=16, K=64   |     1423.5  |   13000.3
      f16 B=64, M=1024, H=16, K=128  |     2608.3  |   13961.5
      b16 B=64, M=1024, H=16, K=128  |     2693.0  |   14694.5

Times are in microseconds (us).
Performance Compared to Vanilla BWD
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     3679.0  |    820.0
      b16 B=384, M=197, H=1, K=88    |     3560.9  |    823.2
      f16 B=384, M=197, H=1, K=80    |      762.1  |    767.0
      b16 B=384, M=197, H=1, K=80    |      762.6  |    768.6
      f16 B=384, M=197, H=1, K=64    |      543.5  |    651.8
      b16 B=384, M=197, H=1, K=64    |      372.8  |    652.8
      f16 B=1024, M=197, H=1, K=88   |     9127.6  |   2103.8
      b16 B=1024, M=197, H=1, K=88   |     9123.4  |   2104.3
      f16 B=1024, M=197, H=1, K=80   |     1881.4  |   1957.8
      b16 B=1024, M=197, H=1, K=80   |     1889.3  |   1957.7
      f16 B=1024, M=197, H=1, K=64   |      916.7  |   1648.4
      b16 B=1024, M=197, H=1, K=64   |      924.1  |   1649.6
      f16 B=512, M=197, H=1, K=80    |      987.0  |    993.8
      b16 B=512, M=197, H=1, K=80    |      991.6  |    995.8
      f16 B=32, M=197, H=16, K=80    |     1038.9  |   1032.6
      b16 B=32, M=197, H=16, K=80    |     1040.7  |   1035.2
      f16 B=32, M=197, H=16, K=64    |      485.6  |    886.1
      b16 B=32, M=197, H=16, K=64    |      487.9  |    887.0
      f16 B=32, M=197, H=16, K=128   |     1165.6  |   1341.5
      b16 B=32, M=197, H=16, K=128   |     1168.8  |   1344.7
      f16 B=256, M=197, H=1, K=88    |     2364.1  |    576.1
      b16 B=256, M=197, H=1, K=88    |     2363.5  |    577.3
      f16 B=16, M=197, H=16, K=88    |     2378.9  |    599.2
      b16 B=16, M=197, H=16, K=88    |     2377.1  |    600.3
      f16 B=16, M=197, H=16, K=64    |      355.0  |    486.9
      b16 B=16, M=197, H=16, K=64    |      309.7  |    488.6
      f16 B=16, M=197, H=16, K=128   |      617.3  |    713.9
      b16 B=16, M=197, H=16, K=128   |      620.4  |    715.8
      f16 B=1, M=4096, H=160, K=128  |    41943.7  |  38971.0
      b16 B=1, M=4096, H=160, K=128  |    42033.0  |  39938.5
      f16 B=2, M=4096, H=160, K=128  |    83588.1  |  79017.0
      b16 B=2, M=4096, H=160, K=128  |    83792.2  |  80856.7
      f16 B=1, M=8192, H=160, K=128  |   160571.8  |
      b16 B=1, M=8192, H=160, K=128  |   160843.2  |
      f16 B=2, M=8192, H=160, K=128  |   320831.4  |
      b16 B=2, M=8192, H=160, K=128  |   322116.0  |
      f16 B=1024, M=82, H=8, K=64    |     2629.7  |   3609.1
      b16 B=1024, M=82, H=8, K=64    |     2524.3  |   3782.4
      f16 B=150, M=256, H=16, K=64   |     2577.6  |   3790.4
      b16 B=150, M=256, H=16, K=64   |     2458.9  |   3837.4
      f16 B=64, M=256, H=12, K=64    |      827.1  |   1254.5
      b16 B=64, M=256, H=12, K=64    |      835.7  |   1267.6
      f16 B=1, M=4096, H=16, K=40    |    43086.4  |   3562.3
      b16 B=1, M=4096, H=16, K=40    |    43079.7  |   3588.6
      f16 B=1, M=16384, H=16, K=40   |   664389.3  |  53939.0
      b16 B=1, M=16384, H=16, K=40   |   663918.7  |  54291.3
      f16 B=256, M=4096, H=16, K=64  |   504278.4  |
      b16 B=256, M=4096, H=16, K=64  |   503469.4  |
      f16 B=8, M=2048, H=20, K=128   |    11424.6  |  10535.0
      b16 B=8, M=2048, H=20, K=128   |    11365.6  |  10784.8
      f16 B=16, M=128, H=16, K=16    |      351.2  |    348.4
      b16 B=16, M=128, H=16, K=16    |      329.6  |    346.4
      f16 B=16, M=128, H=16, K=32    |      351.3  |    343.9
      b16 B=16, M=128, H=16, K=32    |      321.3  |    319.8
      f16 B=16, M=128, H=16, K=64    |      329.9  |    316.2
      b16 B=16, M=128, H=16, K=64    |      329.8  |    340.4
      f16 B=16, M=128, H=16, K=128   |      453.7  |    338.1
      b16 B=16, M=128, H=16, K=128   |      332.6  |    337.6
      f16 B=16, M=512, H=16, K=16    |      526.4  |    985.4
      b16 B=16, M=512, H=16, K=16    |      322.8  |   1078.1
      f16 B=16, M=512, H=16, K=32    |      683.4  |   1089.3
      b16 B=16, M=512, H=16, K=32    |      486.9  |   1178.8
      f16 B=16, M=512, H=16, K=64    |     1020.1  |   1276.7
      b16 B=16, M=512, H=16, K=64    |      845.3  |   1296.3
      f16 B=16, M=512, H=16, K=128   |     1900.1  |   1715.2
      b16 B=16, M=512, H=16, K=128   |     1766.8  |   1742.1
      f16 B=16, M=1024, H=16, K=16   |     1260.6  |   3525.5
      b16 B=16, M=1024, H=16, K=16   |     1067.4  |   3953.2
      f16 B=16, M=1024, H=16, K=32   |     1794.2  |   3728.4
      b16 B=16, M=1024, H=16, K=32   |     1615.1  |   4153.2
      f16 B=16, M=1024, H=16, K=64   |     2685.4  |   4278.5
      b16 B=16, M=1024, H=16, K=64   |     2545.9  |   4349.5
      f16 B=16, M=1024, H=16, K=128  |     5420.6  |   5108.0
      b16 B=16, M=1024, H=16, K=128  |     5317.9  |   5240.9
      f16 B=64, M=128, H=16, K=16    |      318.0  |    365.1
      b16 B=64, M=128, H=16, K=16    |      333.2  |    373.0
      f16 B=64, M=128, H=16, K=32    |      332.6  |    466.7
      b16 B=64, M=128, H=16, K=32    |      334.3  |    470.8
      f16 B=64, M=128, H=16, K=64    |      464.2  |    680.0
      b16 B=64, M=128, H=16, K=64    |      469.7  |    681.6
      f16 B=64, M=128, H=16, K=128   |      980.2  |   1076.1
      b16 B=64, M=128, H=16, K=128   |      986.8  |   1078.6
      f16 B=64, M=512, H=16, K=16    |     1109.1  |   3706.3
      b16 B=64, M=512, H=16, K=16    |     1112.0  |   4077.9
      f16 B=64, M=512, H=16, K=32    |     1679.8  |   4137.0
      b16 B=64, M=512, H=16, K=32    |     1684.1  |   4507.5
      f16 B=64, M=512, H=16, K=64    |     2988.8  |   4903.0
      b16 B=64, M=512, H=16, K=64    |     3001.1  |   4991.4
      f16 B=64, M=512, H=16, K=128   |     6639.6  |   6643.4
      b16 B=64, M=512, H=16, K=128   |     6698.5  |   6749.8
      f16 B=64, M=1024, H=16, K=16   |     3618.2  |  13888.1
      b16 B=64, M=1024, H=16, K=16   |     3625.5  |  15564.5
      f16 B=64, M=1024, H=16, K=32   |     6140.2  |  14750.0
      b16 B=64, M=1024, H=16, K=32   |     6156.8  |  16452.7
      f16 B=64, M=1024, H=16, K=64   |     9818.7  |  16864.4
      b16 B=64, M=1024, H=16, K=64   |     9843.5  |  17158.3
      f16 B=64, M=1024, H=16, K=128  |    20554.6  |  20403.1
      b16 B=64, M=1024, H=16, K=128  |    20676.6  |  20899.8

Times are in microseconds (us).
[ attention backward (attn_bias=<class 'xformers.ops.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=88    |     9482.2  |    821.1
      b16 B=384, M=197, H=1, K=88    |     9334.4  |    823.2
      f16 B=384, M=197, H=1, K=80    |      669.6  |    768.4
      b16 B=384, M=197, H=1, K=80    |      672.2  |    769.8
      f16 B=384, M=197, H=1, K=64    |      544.9  |    651.9
      b16 B=384, M=197, H=1, K=64    |      351.0  |    653.7
      f16 B=1024, M=197, H=1, K=88   |    23569.2  |   2103.4
      b16 B=1024, M=197, H=1, K=88   |    23571.4  |   2105.2
      f16 B=1024, M=197, H=1, K=80   |     1644.3  |   1957.5
      b16 B=1024, M=197, H=1, K=80   |     1649.9  |   1957.4
      f16 B=1024, M=197, H=1, K=64   |      862.7  |   1648.0
      b16 B=1024, M=197, H=1, K=64   |      868.7  |   1649.4
      f16 B=512, M=197, H=1, K=80    |      863.2  |    993.9
      b16 B=512, M=197, H=1, K=80    |      866.3  |    996.2
      f16 B=32, M=197, H=16, K=80    |      867.2  |   1035.0
      b16 B=32, M=197, H=16, K=80    |      869.5  |   1035.5
      f16 B=32, M=197, H=16, K=64    |      456.5  |    884.3
      b16 B=32, M=197, H=16, K=64    |      458.6  |    885.4
      f16 B=32, M=197, H=16, K=128   |     1042.2  |   1339.5
      b16 B=32, M=197, H=16, K=128   |     1046.8  |   1343.0
      f16 B=256, M=197, H=1, K=88    |     6343.3  |    575.9
      b16 B=256, M=197, H=1, K=88    |     6346.2  |    578.0
      f16 B=16, M=197, H=16, K=88    |     6339.5  |    600.1
      b16 B=16, M=197, H=16, K=88    |     6338.8  |    601.4
      f16 B=16, M=197, H=16, K=64    |      323.2  |    488.7
      b16 B=16, M=197, H=16, K=64    |      309.0  |    491.1
      f16 B=16, M=197, H=16, K=128   |      555.8  |    715.2
      b16 B=16, M=197, H=16, K=128   |      557.6  |    717.0
      f16 B=1, M=4096, H=160, K=128  |    26198.1  |  38857.2
      b16 B=1, M=4096, H=160, K=128  |    26165.0  |  39906.6
      f16 B=2, M=4096, H=160, K=128  |    51727.5  |  78346.7
      b16 B=2, M=4096, H=160, K=128  |    51935.9  |  80796.5
      f16 B=1, M=8192, H=160, K=128  |    92706.3  |
      b16 B=1, M=8192, H=160, K=128  |    92957.2  |
      f16 B=2, M=8192, H=160, K=128  |   184853.7  |
      b16 B=2, M=8192, H=160, K=128  |   185265.1  |
      f16 B=1024, M=82, H=8, K=64    |     2746.8  |   3611.4
      b16 B=1024, M=82, H=8, K=64    |     2641.2  |   3784.6
      f16 B=150, M=256, H=16, K=64   |     2381.3  |   3824.6
      b16 B=150, M=256, H=16, K=64   |     2291.5  |   3863.9
      f16 B=64, M=256, H=12, K=64    |      776.2  |   1254.8
      b16 B=64, M=256, H=12, K=64    |      780.0  |   1269.1
      f16 B=1, M=4096, H=16, K=40    |     6826.9  |   3563.2
      b16 B=1, M=4096, H=16, K=40    |     6622.5  |   3596.4
      f16 B=1, M=16384, H=16, K=40   |    94566.7  |  53887.6
      b16 B=1, M=16384, H=16, K=40   |    94648.8  |  54316.2
      f16 B=256, M=4096, H=16, K=64  |   274458.4  |
      b16 B=256, M=4096, H=16, K=64  |   275062.8  |
      f16 B=8, M=2048, H=20, K=128   |     7871.9  |  10544.8
      b16 B=8, M=2048, H=20, K=128   |     7770.2  |  10794.5
      f16 B=16, M=128, H=16, K=16    |      348.7  |    323.2
      b16 B=16, M=128, H=16, K=16    |      326.5  |    325.8
      f16 B=16, M=128, H=16, K=32    |      350.7  |    323.9
      b16 B=16, M=128, H=16, K=32    |      305.9  |    300.9
      f16 B=16, M=128, H=16, K=64    |      361.3  |    328.4
      b16 B=16, M=128, H=16, K=64    |      309.5  |    296.7
      f16 B=16, M=128, H=16, K=128   |      487.7  |    321.4
      b16 B=16, M=128, H=16, K=128   |      329.5  |    305.0
      f16 B=16, M=512, H=16, K=16    |      435.9  |    986.1
      b16 B=16, M=512, H=16, K=16    |      333.1  |   1079.5
      f16 B=16, M=512, H=16, K=32    |      589.3  |   1090.4
      b16 B=16, M=512, H=16, K=32    |      394.0  |   1178.8
      f16 B=16, M=512, H=16, K=64    |      846.2  |   1280.3
      b16 B=16, M=512, H=16, K=64    |      670.0  |   1298.3
      f16 B=16, M=512, H=16, K=128   |     1616.0  |   1712.2
      b16 B=16, M=512, H=16, K=128   |     1485.7  |   1740.0
      f16 B=16, M=1024, H=16, K=16   |      881.3  |   3527.4
      b16 B=16, M=1024, H=16, K=16   |      686.9  |   3948.4
      f16 B=16, M=1024, H=16, K=32   |     1241.7  |   3729.6
      b16 B=16, M=1024, H=16, K=32   |     1059.4  |   4152.4
      f16 B=16, M=1024, H=16, K=64   |     1858.0  |   4284.3
      b16 B=16, M=1024, H=16, K=64   |     1752.9  |   4348.1
      f16 B=16, M=1024, H=16, K=128  |     4083.5  |   5106.9
      b16 B=16, M=1024, H=16, K=128  |     3969.3  |   5246.5
      f16 B=64, M=128, H=16, K=16    |      306.6  |    368.1
      b16 B=64, M=128, H=16, K=16    |      333.1  |    376.4
      f16 B=64, M=128, H=16, K=32    |      327.0  |    466.3
      b16 B=64, M=128, H=16, K=32    |      329.4  |    472.3
      f16 B=64, M=128, H=16, K=64    |      475.4  |    680.6
      b16 B=64, M=128, H=16, K=64    |      479.8  |    680.9
      f16 B=64, M=128, H=16, K=128   |     1041.1  |   1071.8
      b16 B=64, M=128, H=16, K=128   |     1046.3  |   1077.4
      f16 B=64, M=512, H=16, K=16    |      914.1  |   3704.1
      b16 B=64, M=512, H=16, K=16    |      916.3  |   4083.0
      f16 B=64, M=512, H=16, K=32    |     1458.5  |   4134.4
      b16 B=64, M=512, H=16, K=32    |     1461.1  |   4502.1
      f16 B=64, M=512, H=16, K=64    |     2464.9  |   4900.6
      b16 B=64, M=512, H=16, K=64    |     2490.6  |   4991.0
      f16 B=64, M=512, H=16, K=128   |     5549.3  |   6642.1
      b16 B=64, M=512, H=16, K=128   |     5594.7  |   6753.4
      f16 B=64, M=1024, H=16, K=16   |     2605.9  |  13864.0
      b16 B=64, M=1024, H=16, K=16   |     2613.2  |  15564.9
      f16 B=64, M=1024, H=16, K=32   |     4026.7  |  14759.8
      b16 B=64, M=1024, H=16, K=32   |     4029.4  |  16449.0
      f16 B=64, M=1024, H=16, K=64   |     6661.4  |  16899.4
      b16 B=64, M=1024, H=16, K=64   |     6721.7  |  17166.3
      f16 B=64, M=1024, H=16, K=128  |    15343.3  |  20357.1
      b16 B=64, M=1024, H=16, K=128  |    15446.7  |  20917.8

Times are in microseconds (us).
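
For reading the tables above, here is a minimal sketch (not part of this PR) that parses one benchmark row and computes the baseline-to-optimized time ratio. The names `Row`, `parse_row`, and `speedup` are hypothetical helpers introduced only for illustration. A ratio above 1 means the optimized kernel was faster; rows whose baseline cell is blank have no measurement and yield `None`.

```python
import re
from typing import NamedTuple, Optional


class Row(NamedTuple):
    """One line of the benchmark table (hypothetical helper type)."""
    dtype: str
    B: int
    M: int
    H: int
    K: int
    optimized_us: float
    baseline_us: Optional[float]  # None when the baseline cell is blank


# Matches rows such as
# "f16 B=64, M=1024, H=16, K=16   |     2605.9  |  13864.0"
# and rows whose last cell is empty.
ROW_RE = re.compile(
    r"^\s*(f16|b16)\s+B=(\d+),\s*M=(\d+),\s*H=(\d+),\s*K=(\d+)"
    r"\s*\|\s*([\d.]+)\s*\|\s*([\d.]*)\s*$"
)


def parse_row(line: str) -> Optional[Row]:
    """Return a Row for a data line, or None for headers and separators."""
    m = ROW_RE.match(line)
    if m is None:
        return None
    dtype, b, seq, h, k, opt, base = m.groups()
    return Row(dtype, int(b), int(seq), int(h), int(k),
               float(opt), float(base) if base else None)


def speedup(row: Row) -> Optional[float]:
    """Baseline time divided by optimized time; None if no baseline."""
    if row.baseline_us is None:
        return None
    return row.baseline_us / row.optimized_us


if __name__ == "__main__":
    # Three rows copied from the LowerTriangularMask backward table.
    sample = [
        "      f16 B=64, M=1024, H=16, K=16   |     2605.9  |  13864.0",
        "      f16 B=384, M=197, H=1, K=88    |     9482.2  |    821.1",
        "      f16 B=1, M=8192, H=160, K=128  |    92706.3  |",
    ]
    for line in sample:
        row = parse_row(line)
        ratio = speedup(row)
        label = "n/a" if ratio is None else f"{ratio:.2f}x"
        print(f"{row.dtype} B={row.B} M={row.M} H={row.H} K={row.K}: {label}")
```

On these three rows the ratio is well above 1 for `K=16`, well below 1 for `K=88`, and unavailable for the `M=8192` row, which matches what the raw timings show.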

dianaml0, Nov 15 '22 21:11

Thanks a lot for the reviews, @fmassa and @danthe3rd! I've made some changes and added an op that uses the Triton forward with the Flash backward.

dianaml0, Nov 29 '22 00:11

Forward pass for the Triton fwd + Flash bwd op:
[--------- attention (attn_bias=<class 'NoneType'>) --------]
                                     |  optimized  |   eager
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |       91.7  |    292.3
      b16 B=384, M=197, H=1, K=64    |       94.2  |    295.4
      f16 B=1024, M=197, H=1, K=64   |      213.1  |    727.8
      b16 B=1024, M=197, H=1, K=64   |      221.4  |    734.1
      f16 B=32, M=197, H=16, K=64    |      114.2  |    466.1
      b16 B=32, M=197, H=16, K=64    |      120.6  |    468.4
      f16 B=32, M=197, H=16, K=128   |      194.0  |    716.8
      b16 B=32, M=197, H=16, K=128   |      195.4  |    719.7
      f16 B=16, M=197, H=16, K=64    |       91.5  |    256.3
      b16 B=16, M=197, H=16, K=64    |       90.9  |    256.9
      f16 B=16, M=197, H=16, K=128   |      102.4  |    384.7
      b16 B=16, M=197, H=16, K=128   |      104.6  |    386.0
      f16 B=1, M=4096, H=160, K=128  |     9785.7  |  20529.6
      b16 B=1, M=4096, H=160, K=128  |    10020.5  |  21512.5
      f16 B=2, M=4096, H=160, K=128  |    19305.3  |  42686.7
      b16 B=2, M=4096, H=160, K=128  |    19817.6  |  44503.4
      f16 B=1, M=8192, H=160, K=128  |    37278.0  |  88499.1
      b16 B=1, M=8192, H=160, K=128  |    38494.5  |  87043.0
      f16 B=2, M=8192, H=160, K=128  |    73829.8  |
      b16 B=2, M=8192, H=160, K=128  |    76805.0  |
      f16 B=1024, M=82, H=8, K=64    |      460.8  |   1769.1
      b16 B=1024, M=82, H=8, K=64    |      478.7  |   1864.2
      f16 B=150, M=256, H=16, K=64   |      383.5  |   1727.0
      b16 B=150, M=256, H=16, K=64   |      420.7  |   1761.2
      f16 B=64, M=256, H=12, K=64    |      131.7  |    588.1
      b16 B=64, M=256, H=12, K=64    |      145.6  |    598.2
      f16 B=256, M=4096, H=16, K=64  |   118090.6  |
      b16 B=256, M=4096, H=16, K=64  |   132264.8  |
      f16 B=8, M=2048, H=20, K=128   |     2692.1  |   5746.4
      b16 B=8, M=2048, H=20, K=128   |     2743.3  |   6049.4
      f16 B=16, M=128, H=16, K=16    |       90.0  |    145.5
      b16 B=16, M=128, H=16, K=16    |       92.8  |    143.1
      f16 B=16, M=128, H=16, K=32    |       93.7  |    147.5
      b16 B=16, M=128, H=16, K=32    |       90.5  |    145.4
      f16 B=16, M=128, H=16, K=64    |       90.4  |    146.0
      b16 B=16, M=128, H=16, K=64    |       92.5  |    142.1
      f16 B=16, M=128, H=16, K=128   |       91.0  |    142.3
      b16 B=16, M=128, H=16, K=128   |       90.2  |    143.1
      f16 B=16, M=512, H=16, K=16    |       94.1  |    462.3
      b16 B=16, M=512, H=16, K=16    |      100.9  |    553.7
      f16 B=16, M=512, H=16, K=32    |      105.4  |    513.1
      b16 B=16, M=512, H=16, K=32    |      122.9  |    595.7
      f16 B=16, M=512, H=16, K=64    |      152.1  |    596.8
      b16 B=16, M=512, H=16, K=64    |      169.2  |    613.5
      f16 B=16, M=512, H=16, K=128   |      319.0  |    788.5
      b16 B=16, M=512, H=16, K=128   |      327.9  |    805.1
      f16 B=16, M=1024, H=16, K=16   |      291.8  |   1644.3
      b16 B=16, M=1024, H=16, K=16   |      384.1  |   2026.1
      f16 B=16, M=1024, H=16, K=32   |      369.0  |   1734.6
      b16 B=16, M=1024, H=16, K=32   |      426.7  |   2128.1
      f16 B=16, M=1024, H=16, K=64   |      521.1  |   2035.3
      b16 B=16, M=1024, H=16, K=64   |      578.1  |   2079.1
      f16 B=16, M=1024, H=16, K=128  |     1093.8  |   2420.7
      b16 B=16, M=1024, H=16, K=128  |     1131.1  |   2491.4
      f16 B=64, M=128, H=16, K=16    |       90.9  |    182.8
      b16 B=64, M=128, H=16, K=16    |       90.3  |    184.6
      f16 B=64, M=128, H=16, K=32    |       90.6  |    225.4
      b16 B=64, M=128, H=16, K=32    |       90.3  |    226.7
      f16 B=64, M=128, H=16, K=64    |       90.9  |    324.6
      b16 B=64, M=128, H=16, K=64    |       89.9  |    326.3
      f16 B=64, M=128, H=16, K=128   |      128.4  |    485.0
      b16 B=64, M=128, H=16, K=128   |      130.8  |    486.5
      f16 B=64, M=512, H=16, K=16    |      305.7  |   1729.1
      b16 B=64, M=512, H=16, K=16    |      401.8  |   2094.5
      f16 B=64, M=512, H=16, K=32    |      386.0  |   1899.7
      b16 B=64, M=512, H=16, K=32    |      448.9  |   2252.9
      f16 B=64, M=512, H=16, K=64    |      561.1  |   2246.8
      b16 B=64, M=512, H=16, K=64    |      618.4  |   2302.4
      f16 B=64, M=512, H=16, K=128   |     1195.5  |   3013.4
      b16 B=64, M=512, H=16, K=128   |     1228.0  |   3069.2
      f16 B=64, M=1024, H=16, K=16   |     1097.8  |   6432.0
      b16 B=64, M=1024, H=16, K=16   |     1468.3  |   8046.7
      f16 B=64, M=1024, H=16, K=32   |     1405.7  |   6777.6
      b16 B=64, M=1024, H=16, K=32   |     1662.4  |   8426.7
      f16 B=64, M=1024, H=16, K=64   |     2035.0  |   7999.1
      b16 B=64, M=1024, H=16, K=64   |     2255.6  |   8174.9
      f16 B=64, M=1024, H=16, K=128  |     4269.3  |   9546.4
      b16 B=64, M=1024, H=16, K=128  |     4404.0  |   9904.9

Times are in microseconds (us).

[ attention (attn_bias=<class 'xformers.ops.memory_efficient_attention.LowerTriangularMask'>) ]
                                     |  optimized  |   eager
1 threads: ---------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      92.4   |     369.8
      b16 B=384, M=197, H=1, K=64    |      94.6   |     376.0
      f16 B=1024, M=197, H=1, K=64   |     186.6   |     923.4
      b16 B=1024, M=197, H=1, K=64   |     196.6   |     938.1
      f16 B=32, M=197, H=16, K=64    |     104.4   |     564.2
      b16 B=32, M=197, H=16, K=64    |     108.8   |     572.2
      f16 B=32, M=197, H=16, K=128   |     157.9   |     812.1
      b16 B=32, M=197, H=16, K=128   |     159.5   |     818.4
      f16 B=16, M=197, H=16, K=64    |      93.2   |     306.1
      b16 B=16, M=197, H=16, K=64    |      90.6   |     310.1
      f16 B=16, M=197, H=16, K=128   |      92.2   |     433.2
      b16 B=16, M=197, H=16, K=128   |      91.6   |     437.6
      f16 B=1, M=4096, H=160, K=128  |    5166.1   |   37295.6
      b16 B=1, M=4096, H=160, K=128  |    5326.0   |   37828.6
      f16 B=2, M=4096, H=160, K=128  |   10191.6   |   76206.0
      b16 B=2, M=4096, H=160, K=128  |   10507.9   |   77209.3
      f16 B=1, M=8192, H=160, K=128  |   19285.9   |  152334.7
      b16 B=1, M=8192, H=160, K=128  |   19893.0   |  148896.7
      f16 B=2, M=8192, H=160, K=128  |   38335.1   |
      b16 B=2, M=8192, H=160, K=128  |   39490.4   |
      f16 B=1024, M=82, H=8, K=64    |     486.3   |    1981.2
      b16 B=1024, M=82, H=8, K=64    |     514.4   |    2086.7
      f16 B=150, M=256, H=16, K=64   |     334.3   |    2402.0
      b16 B=150, M=256, H=16, K=64   |     350.1   |    2447.2
      f16 B=64, M=256, H=12, K=64    |     118.7   |     806.5
      b16 B=64, M=256, H=12, K=64    |     124.9   |     820.5
      f16 B=256, M=4096, H=16, K=64  |   67692.5   |
      b16 B=256, M=4096, H=16, K=64  |   73779.3   |
      f16 B=8, M=2048, H=20, K=128   |    1487.2   |    9354.3
      b16 B=8, M=2048, H=20, K=128   |    1529.4   |    9535.1
      f16 B=16, M=128, H=16, K=16    |      94.6   |     153.5
      b16 B=16, M=128, H=16, K=16    |      93.1   |     150.6
      f16 B=16, M=128, H=16, K=32    |      91.0   |     154.5
      b16 B=16, M=128, H=16, K=32    |      94.2   |     154.0
      f16 B=16, M=128, H=16, K=64    |      93.7   |     149.8
      b16 B=16, M=128, H=16, K=64    |      93.2   |     152.3
      f16 B=16, M=128, H=16, K=128   |      92.0   |     160.4
      b16 B=16, M=128, H=16, K=128   |      91.1   |     162.3
      f16 B=16, M=512, H=16, K=16    |      93.0   |     715.4
      b16 B=16, M=512, H=16, K=16    |      93.9   |     792.7
      f16 B=16, M=512, H=16, K=32    |      91.0   |     767.7
      b16 B=16, M=512, H=16, K=32    |      98.1   |     833.7
      f16 B=16, M=512, H=16, K=64    |     120.8   |     882.7
      b16 B=16, M=512, H=16, K=64    |     129.8   |     911.8
      f16 B=16, M=512, H=16, K=128   |     225.5   |    1067.1
      b16 B=16, M=512, H=16, K=128   |     231.7   |    1096.1
      f16 B=16, M=1024, H=16, K=16   |     195.9   |    2658.6
      b16 B=16, M=1024, H=16, K=16   |     252.9   |    3043.7
      f16 B=16, M=1024, H=16, K=32   |     243.7   |    2752.9
      b16 B=16, M=1024, H=16, K=32   |     291.1   |    3121.6
      f16 B=16, M=1024, H=16, K=64   |     355.5   |    2966.1
      b16 B=16, M=1024, H=16, K=64   |     384.2   |    3287.9
      f16 B=16, M=1024, H=16, K=128  |     682.0   |    3533.0
      b16 B=16, M=1024, H=16, K=128  |     703.3   |    3713.4
      f16 B=64, M=128, H=16, K=16    |      90.4   |     245.3
      b16 B=64, M=128, H=16, K=16    |      90.3   |     251.1
      f16 B=64, M=128, H=16, K=32    |      91.1   |     294.9
      b16 B=64, M=128, H=16, K=32    |      91.8   |     299.4
      f16 B=64, M=128, H=16, K=64    |      90.1   |     393.0
      b16 B=64, M=128, H=16, K=64    |      92.7   |     397.9
      f16 B=64, M=128, H=16, K=128   |     132.3   |     557.3
      b16 B=64, M=128, H=16, K=128   |     134.1   |     562.5
      f16 B=64, M=512, H=16, K=16    |     227.0   |    2705.2
      b16 B=64, M=512, H=16, K=16    |     288.3   |    3017.0
      f16 B=64, M=512, H=16, K=32    |     284.5   |    2893.9
      b16 B=64, M=512, H=16, K=32    |     336.8   |    3167.5
      f16 B=64, M=512, H=16, K=64    |     418.1   |    3338.8
      b16 B=64, M=512, H=16, K=64    |     444.2   |    3460.2
      f16 B=64, M=512, H=16, K=128   |     830.2   |    4091.8
      b16 B=64, M=512, H=16, K=128   |     855.8   |    4207.5
      f16 B=64, M=1024, H=16, K=16   |     720.1   |   10487.1
      b16 B=64, M=1024, H=16, K=16   |     929.6   |   12038.4
      f16 B=64, M=1024, H=16, K=32   |     898.2   |   10852.4
      b16 B=64, M=1024, H=16, K=32   |    1072.3   |   12350.3
      f16 B=64, M=1024, H=16, K=64   |    1325.3   |   11682.5
      b16 B=64, M=1024, H=16, K=64   |    1421.4   |   13010.3
      f16 B=64, M=1024, H=16, K=128  |    2604.2   |   13959.2
      b16 B=64, M=1024, H=16, K=128  |    2683.7   |   14677.4

Times are in microseconds (us).
Backward pass for the Triton fwd + Flash bwd op:
[---- attention backward (attn_bias=<class 'NoneType'>) ----]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      242.0  |    650.9
      b16 B=384, M=197, H=1, K=64    |      221.9  |    652.5
      f16 B=1024, M=197, H=1, K=64   |      534.5  |   1647.7
      b16 B=1024, M=197, H=1, K=64   |      532.9  |   1648.0
      f16 B=32, M=197, H=16, K=64    |      279.1  |    882.9
      b16 B=32, M=197, H=16, K=64    |      278.5  |    883.9
      f16 B=32, M=197, H=16, K=128   |      638.0  |   1344.0
      b16 B=32, M=197, H=16, K=128   |      640.2  |   1346.5
      f16 B=16, M=197, H=16, K=64    |      235.9  |    484.8
      b16 B=16, M=197, H=16, K=64    |      245.0  |    485.8
      f16 B=16, M=197, H=16, K=128   |      367.5  |    711.5
      b16 B=16, M=197, H=16, K=128   |      368.5  |    713.9
      f16 B=1, M=4096, H=160, K=128  |    53888.1  |  38893.2
      b16 B=1, M=4096, H=160, K=128  |    53989.6  |  39876.1
      f16 B=2, M=4096, H=160, K=128  |    82534.0  |  78629.0
      b16 B=2, M=4096, H=160, K=128  |    82718.4  |  80537.6
      f16 B=1, M=8192, H=160, K=128  |   213238.1  |
      b16 B=1, M=8192, H=160, K=128  |   213335.0  |
      f16 B=2, M=8192, H=160, K=128  |   324974.6  |
      b16 B=2, M=8192, H=160, K=128  |   325014.0  |
      f16 B=1024, M=82, H=8, K=64    |     1494.2  |   3611.5
      b16 B=1024, M=82, H=8, K=64    |     1515.2  |   3787.3
      f16 B=150, M=256, H=16, K=64   |     1493.8  |   3800.9
      b16 B=150, M=256, H=16, K=64   |     1491.3  |   3843.4
      f16 B=64, M=256, H=12, K=64    |      523.3  |   1256.9
      b16 B=64, M=256, H=12, K=64    |      520.9  |   1271.3
      f16 B=256, M=4096, H=16, K=64  |   430986.6  |
      b16 B=256, M=4096, H=16, K=64  |   430693.7  |
      f16 B=8, M=2048, H=20, K=128   |    13945.0  |  10546.8
      b16 B=8, M=2048, H=20, K=128   |    13939.3  |  10798.0
      f16 B=16, M=128, H=16, K=16    |      196.4  |    327.9
      b16 B=16, M=128, H=16, K=16    |      197.0  |    353.1
      f16 B=16, M=128, H=16, K=32    |      215.2  |    350.5
      b16 B=16, M=128, H=16, K=32    |      197.6  |    325.6
      f16 B=16, M=128, H=16, K=64    |      195.9  |    346.2
      b16 B=16, M=128, H=16, K=64    |      216.0  |    357.0
      f16 B=16, M=128, H=16, K=128   |      197.1  |    322.8
      b16 B=16, M=128, H=16, K=128   |      218.6  |    320.6
      f16 B=16, M=512, H=16, K=16    |      321.2  |    984.5
      b16 B=16, M=512, H=16, K=16    |      323.6  |   1077.2
      f16 B=16, M=512, H=16, K=32    |      423.1  |   1089.6
      b16 B=16, M=512, H=16, K=32    |      425.6  |   1178.5
      f16 B=16, M=512, H=16, K=64    |      671.8  |   1285.5
      b16 B=16, M=512, H=16, K=64    |      673.3  |   1306.2
      f16 B=16, M=512, H=16, K=128   |     1512.0  |   1718.0
      b16 B=16, M=512, H=16, K=128   |     1514.4  |   1748.0
      f16 B=16, M=1024, H=16, K=16   |     1237.7  |   3532.8
      b16 B=16, M=1024, H=16, K=16   |     1240.7  |   3957.6
      f16 B=16, M=1024, H=16, K=32   |     1593.6  |   3733.2
      b16 B=16, M=1024, H=16, K=32   |     1594.9  |   4156.9
      f16 B=16, M=1024, H=16, K=64   |     2309.4  |   4278.0
      b16 B=16, M=1024, H=16, K=64   |     2311.7  |   4350.1
      f16 B=16, M=1024, H=16, K=128  |     5478.4  |   5110.4
      b16 B=16, M=1024, H=16, K=128  |     5498.3  |   5251.9
      f16 B=64, M=128, H=16, K=16    |      197.3  |    364.8
      b16 B=64, M=128, H=16, K=16    |      215.9  |    372.4
      f16 B=64, M=128, H=16, K=32    |      200.8  |    464.8
      b16 B=64, M=128, H=16, K=32    |      201.4  |    470.3
      f16 B=64, M=128, H=16, K=64    |      283.5  |    680.8
      b16 B=64, M=128, H=16, K=64    |      286.8  |    682.9
      f16 B=64, M=128, H=16, K=128   |      496.9  |   1076.1
      b16 B=64, M=128, H=16, K=128   |      501.1  |   1078.8
      f16 B=64, M=512, H=16, K=16    |     1181.5  |   3715.9
      b16 B=64, M=512, H=16, K=16    |     1187.2  |   4087.0
      f16 B=64, M=512, H=16, K=32    |     1497.3  |   4141.0
      b16 B=64, M=512, H=16, K=32    |     1505.0  |   4508.7
      f16 B=64, M=512, H=16, K=64    |     2296.0  |   4917.1
      b16 B=64, M=512, H=16, K=64    |     2309.8  |   5008.3
      f16 B=64, M=512, H=16, K=128   |     5157.6  |   6644.4
      b16 B=64, M=512, H=16, K=128   |     5178.1  |   6751.8
      f16 B=64, M=1024, H=16, K=16   |     4668.2  |  13869.2
      b16 B=64, M=1024, H=16, K=16   |     4671.8  |  15561.4
      f16 B=64, M=1024, H=16, K=32   |     5595.4  |  14761.1
      b16 B=64, M=1024, H=16, K=32   |     5609.4  |  16450.7
      f16 B=64, M=1024, H=16, K=64   |     7872.6  |  16891.0
      b16 B=64, M=1024, H=16, K=64   |     7899.7  |  17185.1
      f16 B=64, M=1024, H=16, K=128  |    18538.4  |  20376.2
      b16 B=64, M=1024, H=16, K=128  |    18576.9  |  20931.0

Times are in microseconds (us).

[ attention backward (attn_bias=<class 'xformers.ops.memory_efficient_attention.LowerTriangularMask'>) ]
                                     |  optimized  |  vanilla
1 threads: --------------------------------------------------
      f16 B=384, M=197, H=1, K=64    |      225.1  |    652.1
      b16 B=384, M=197, H=1, K=64    |      227.0  |    652.6
      f16 B=1024, M=197, H=1, K=64   |      542.6  |   1647.2
      b16 B=1024, M=197, H=1, K=64   |      546.3  |   1647.3
      f16 B=32, M=197, H=16, K=64    |      282.4  |    885.0
      b16 B=32, M=197, H=16, K=64    |      285.3  |    886.9
      f16 B=32, M=197, H=16, K=128   |      521.3  |   1345.3
      b16 B=32, M=197, H=16, K=128   |      521.9  |   1346.8
      f16 B=16, M=197, H=16, K=64    |      243.5  |    485.8
      b16 B=16, M=197, H=16, K=64    |      215.6  |    487.3
      f16 B=16, M=197, H=16, K=128   |      298.9  |    714.0
      b16 B=16, M=197, H=16, K=128   |      299.4  |    715.9
      f16 B=1, M=4096, H=160, K=128  |    31532.4  |  38844.2
      b16 B=1, M=4096, H=160, K=128  |    31579.6  |  39971.8
      f16 B=2, M=4096, H=160, K=128  |    48359.2  |  78614.4
      b16 B=2, M=4096, H=160, K=128  |    48379.2  |  80423.5
      f16 B=1, M=8192, H=160, K=128  |   121826.9  |
      b16 B=1, M=8192, H=160, K=128  |   121908.8  |
      f16 B=2, M=8192, H=160, K=128  |   186970.4  |
      b16 B=2, M=8192, H=160, K=128  |   186947.6  |
      f16 B=1024, M=82, H=8, K=64    |     1513.1  |   3612.9
      b16 B=1024, M=82, H=8, K=64    |     1524.6  |   3789.6
      f16 B=150, M=256, H=16, K=64   |     1513.9  |   3821.8
      b16 B=150, M=256, H=16, K=64   |     1530.0  |   3863.2
      f16 B=64, M=256, H=12, K=64    |      529.9  |   1257.3
      b16 B=64, M=256, H=12, K=64    |      533.3  |   1272.5
      f16 B=256, M=4096, H=16, K=64  |   241098.0  |
      b16 B=256, M=4096, H=16, K=64  |   241300.4  |
      f16 B=8, M=2048, H=20, K=128   |     8338.3  |  10544.2
      b16 B=8, M=2048, H=20, K=128   |     8347.7  |  10792.9
      f16 B=16, M=128, H=16, K=16    |      197.3  |    328.7
      b16 B=16, M=128, H=16, K=16    |      215.8  |    329.7
      f16 B=16, M=128, H=16, K=32    |      196.4  |    308.2
      b16 B=16, M=128, H=16, K=32    |      214.5  |    332.4
      f16 B=16, M=128, H=16, K=64    |      221.3  |    302.9
      b16 B=16, M=128, H=16, K=64    |      195.7  |    335.6
      f16 B=16, M=128, H=16, K=128   |      216.6  |    324.5
      b16 B=16, M=128, H=16, K=128   |      217.9  |    321.4
      f16 B=16, M=512, H=16, K=16    |      261.4  |    986.3
      b16 B=16, M=512, H=16, K=16    |      263.7  |   1078.2
      f16 B=16, M=512, H=16, K=32    |      338.6  |   1089.9
      b16 B=16, M=512, H=16, K=32    |      341.3  |   1178.4
      f16 B=16, M=512, H=16, K=64    |      519.7  |   1286.0
      b16 B=16, M=512, H=16, K=64    |      523.9  |   1306.3
      f16 B=16, M=512, H=16, K=128   |     1036.2  |   1717.5
      b16 B=16, M=512, H=16, K=128   |     1039.4  |   1744.6
      f16 B=16, M=1024, H=16, K=16   |      790.6  |   3531.1
      b16 B=16, M=1024, H=16, K=16   |      793.3  |   3954.9
      f16 B=16, M=1024, H=16, K=32   |     1023.9  |   3736.3
      b16 B=16, M=1024, H=16, K=32   |     1023.6  |   4155.7
      f16 B=16, M=1024, H=16, K=64   |     1565.6  |   4279.1
      b16 B=16, M=1024, H=16, K=64   |     1572.5  |   4350.4
      f16 B=16, M=1024, H=16, K=128  |     3437.1  |   5117.6
      b16 B=16, M=1024, H=16, K=128  |     3449.8  |   5254.7
      f16 B=64, M=128, H=16, K=16    |      199.3  |    367.9
      b16 B=64, M=128, H=16, K=16    |      196.3  |    376.1
      f16 B=64, M=128, H=16, K=32    |      206.7  |    465.6
      b16 B=64, M=128, H=16, K=32    |      214.1  |    471.7
      f16 B=64, M=128, H=16, K=64    |      286.8  |    683.6
      b16 B=64, M=128, H=16, K=64    |      289.0  |    685.6
      f16 B=64, M=128, H=16, K=128   |      512.0  |   1075.4
      b16 B=64, M=128, H=16, K=128   |      514.3  |   1078.1
      f16 B=64, M=512, H=16, K=16    |      924.1  |   3713.1
      b16 B=64, M=512, H=16, K=16    |      929.9  |   4090.5
      f16 B=64, M=512, H=16, K=32    |     1180.3  |   4138.9
      b16 B=64, M=512, H=16, K=32    |     1186.6  |   4505.3
      f16 B=64, M=512, H=16, K=64    |     1761.6  |   4914.9
      b16 B=64, M=512, H=16, K=64    |     1780.9  |   5004.2
      f16 B=64, M=512, H=16, K=128   |     3586.7  |   6644.1
      b16 B=64, M=512, H=16, K=128   |     3601.9  |   6750.4
      f16 B=64, M=1024, H=16, K=16   |     2859.7  |  13877.5
      b16 B=64, M=1024, H=16, K=16   |     2863.2  |  15563.7
      f16 B=64, M=1024, H=16, K=32   |     3587.1  |  14759.8
      b16 B=64, M=1024, H=16, K=32   |     3587.2  |  16462.9
      f16 B=64, M=1024, H=16, K=64   |     5363.7  |  16889.4
      b16 B=64, M=1024, H=16, K=64   |     5384.1  |  17213.8
      f16 B=64, M=1024, H=16, K=128  |    11821.0  |  20391.7
      b16 B=64, M=1024, H=16, K=128  |    11861.3  |  20907.0

Times are in microseconds (us).
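
To read the tables above as speedups rather than raw timings, a minimal sketch in Python follows. It assumes each row keeps the `config | optimized | vanilla` layout shown above. The helper names `parse_row` and `speedup` are hypothetical and are not part of xformers or the benchmark script. Rows whose vanilla column is empty are reported as `n/a`, since the source does not say why those measurements are missing.

```python
from typing import Optional, Tuple


def parse_row(line: str) -> Tuple[str, float, Optional[float]]:
    """Split one benchmark row into (config, optimized_us, vanilla_us)."""
    config, optimized, vanilla = (part.strip() for part in line.split("|"))
    # An empty vanilla column means no timing was reported for that config.
    return config, float(optimized), (float(vanilla) if vanilla else None)


def speedup(line: str) -> Optional[float]:
    """Return vanilla / optimized, or None when the vanilla column is empty."""
    _, optimized, vanilla = parse_row(line)
    return None if vanilla is None else vanilla / optimized


# Three rows copied from the backward-pass table above.
rows = [
    "      f16 B=64, M=1024, H=16, K=16   |     2859.7  |  13877.5",
    "      f16 B=16, M=128, H=16, K=64    |      221.3  |    302.9",
    "      f16 B=1, M=8192, H=160, K=128  |   121826.9  |",
]

for row in rows:
    ratio = speedup(row)
    label = "n/a" if ratio is None else f"{ratio:.2f}x"
    print(f"{parse_row(row)[0]:<32} {label}")
```

A ratio above 1 means the optimized kernel was faster than vanilla for that configuration.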

dianaml0, Nov 29 '22 01:11

@fmassa Thanks a lot for the helpful comments and for taking another pass! I've updated with related changes. Okay to merge for now?

dianaml0, Dec 06 '22 02:12

Codecov Report

Base: 89.79% // Head: 89.06% // Decreases project coverage by 0.73% :warning:

Coverage data is based on head (5724663) compared to base (71205ec). Patch coverage: 48.86% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #479      +/-   ##
==========================================
- Coverage   89.79%   89.06%   -0.74%     
==========================================
  Files          80       80              
  Lines        4839     4927      +88     
==========================================
+ Hits         4345     4388      +43     
- Misses        494      539      +45     
Flag    |  Coverage Δ
Python  |  89.06% <48.86%> (-0.74%) :arrow_down:

Impacted Files                              |  Coverage Δ
xformers/info.py                            |  0.00% <ø> (ø)
xformers/ops/__init__.py                    |  82.35% <ø> (ø)
xformers/ops/memory_efficient_attention.py  |  78.05% <48.86%> (-6.66%) :arrow_down:
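
The percentages in the report above can be cross-checked from the Hits and Lines counts in the coverage diff. The sketch below is plain arithmetic, not Codecov's implementation. In particular, deriving patch coverage from the difference in hits and lines is an assumption that happens to reproduce the reported 48.86% here; Codecov itself computes patch coverage from the lines modified in the pull request.

```python
def pct(hits: int, lines: int) -> float:
    """Coverage as a percentage of covered lines."""
    return 100.0 * hits / lines


# Counts taken from the coverage diff above.
base_hits, base_lines = 4345, 4839   # main
head_hits, head_lines = 4388, 4927   # PR #479

base_cov = pct(base_hits, base_lines)
head_cov = pct(head_hits, head_lines)

# Assumption: the +88 lines and +43 hits all belong to the patch.
patch_cov = pct(head_hits - base_hits, head_lines - base_lines)

print(f"base  {base_cov:.2f}%")
print(f"head  {head_cov:.2f}%")
print(f"patch {patch_cov:.2f}%")
```

The same counts give the Misses row as well: 4927 - 4388 = 539 uncovered lines at head, up from 4839 - 4345 = 494 at base.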


codecov-commenter, Dec 06 '22 03:12