XNNPACK

QB4W MLAL GEMM Kernels

Open • GregoryComer opened this pull request 8 months ago • 1 comment

This pull request adds blockwise 4-bit (qb4w) GEMM microkernels targeting ARM Neon via the MLAL instruction family.
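For readers less familiar with the blockwise format, here is a scalar C sketch of what a qb4w GEMM computes. Each output channel's kc weights are split into blocks of bl. Each block has its own scale, so the integer dot product is scaled and accumulated one block at a time.

The sketch rests on several assumptions:
- int8 activations with a per-row zero point and scale.
- Two 4-bit weights per byte, low nibble first, stored unsigned with a zero point of 8.
- float32 block scales.

The function `qb4w_gemm_ref` and this memory layout are hypothetical. They do not reflect XNNPACK's packed-weight format or microkernel signature.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical scalar reference for a blockwise 4-bit (qb4w) GEMM.
// Not XNNPACK's interface or packing; layout below is an assumption.
//   a:       m x kc int8 activations, row-major
//   a_zp:    per-row activation zero point
//   a_scale: per-row activation scale
//   w:       n x kc 4-bit weights, two per byte, low nibble first,
//            unsigned with an assumed zero point of 8
//   w_scale: n x (kc / bl) per-block weight scales (float32, assumed)
//   c:       m x n float output
static void qb4w_gemm_ref(size_t m, size_t n, size_t kc, size_t bl,
                          const int8_t* a, const int32_t* a_zp,
                          const float* a_scale, const uint8_t* w,
                          const float* w_scale, float* c) {
  assert(bl > 0 && kc % bl == 0);  // blocks must tile kc exactly
  assert(kc % 2 == 0);             // keeps each row's nibbles byte-aligned
  const size_t num_blocks = kc / bl;
  for (size_t i = 0; i < m; i++) {
    for (size_t j = 0; j < n; j++) {
      float acc = 0.0f;
      for (size_t b = 0; b < num_blocks; b++) {
        // Integer dot product over one block. The MLAL kernels presumably
        // do this part with widening multiply-accumulate instructions.
        int32_t block_acc = 0;
        for (size_t k = b * bl; k < (b + 1) * bl; k++) {
          const uint8_t byte = w[(j * kc + k) / 2];
          const int32_t nibble = (k % 2 == 0) ? (byte & 0xF) : (byte >> 4);
          const int32_t wk = nibble - 8;  // remove assumed zero point
          const int32_t ak = (int32_t) a[i * kc + k] - a_zp[i];
          block_acc += ak * wk;
        }
        // Each block has its own scale, so the partial sum is converted
        // to float and scaled before moving on to the next block.
        acc += (float) block_acc * w_scale[j * num_blocks + b];
      }
      c[i * n + j] = acc * a_scale[i];
    }
  }
}
```

When bl equals kc there is a single block per output channel. The per-block scale then reduces to a per-channel scale, which is the qc4w case.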

Note: This PR includes one commit from https://github.com/google/XNNPACK/pull/6557 (Test generation update for qb4w). I'm opening this PR before that one merges so that the review process can start early.

Tests and benchmarks were run on an Apple M1 Pro Mac and a Samsung S22. The benchmark data includes qc4w results for comparison. A blockwise kernel with block size (bl) equal to kc is functionally equivalent to qc4w, so qc4w is a reasonable performance baseline. I expect qb4w with bl=256 to be slightly slower than qc4w for two reasons:
- The weights take slightly more memory (~4.125 bits per weight vs. ~4 bits per weight for qc4w).
- The block loop adds a small amount of overhead.
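The bits-per-weight figures above come from amortizing per-block metadata over the block: bits/weight = 4 + metadata_bits / bl. The ~4.125 figure for bl=256 corresponds to 32 bits of metadata per block (0.125 × 256 = 32). The sketch below takes that 32-bit value as an assumption, since the PR does not describe the packed layout. `bits_per_weight` is a hypothetical helper.

```c
#include <assert.h>
#include <stddef.h>

// Hypothetical helper: effective storage cost per weight for 4-bit
// blockwise quantization. 4 bits for the weight itself plus the per-block
// metadata (scale, etc.) amortized over the bl weights in the block.
// meta_bits is an assumed parameter, not taken from XNNPACK's packing.
static double bits_per_weight(size_t bl, size_t meta_bits) {
  assert(bl > 0);
  return 4.0 + (double) meta_bits / (double) bl;
}
```

Under that assumption:
- bl=256 gives 4.125 bits per weight.
- bl=32 gives 5.0.
- bl=kc=4096 (one block per output channel, i.e. the qc4w case) gives about 4.008, close to the ~4 bits per weight quoted for qc4w.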

S22-A510:

Average OPS, `llm` benchmark, m=128, k=4096:

| tile_size | kernel_type | n=11008, bl=32 | n=11008, bl=256 | n=11008, qc4w | n=32000, bl=32 | n=32000, bl=256 | n=32000, qc4w |
|---|---|---|---|---|---|---|---|
| 1x16 | neon_mlal_lane | 6.67 | 7.01 | 7.04 | 6.70 | 7.00 | 7.05 |
| 1x16 | neon_mlal_lane_prfm | 6.64 | 7.03 | 6.94 | 6.72 | 7.00 | 7.09 |
| 1x16 | neonfp16arith_mlal_lane | 6.76 | 7.05 | 7.12 | 6.78 | 7.05 | 7.13 |
| 1x16 | neonfp16arith_mlal_lane_prfm | 6.48 | 7.05 | 7.15 | 6.48 | 7.06 | 7.14 |
| 2x16 | neon_mlal_lane | 8.32 | 8.93 | 9.07 | 8.31 | 8.96 | 9.08 |
| 2x16 | neon_mlal_lane_prfm | 8.20 | 8.83 | 8.90 | 8.22 | 8.84 | 8.92 |
| 2x16 | neonfp16arith_mlal_lane | 8.41 | 9.00 | 9.12 | 8.41 | 9.03 | 9.11 |
| 2x16 | neonfp16arith_mlal_lane_prfm | 8.36 | 8.95 | 8.95 | 8.37 | 8.96 | 8.94 |
| 3x16 | neon_mlal_lane | 9.09 | 9.96 | 10.15 | 9.08 | 9.94 | 10.09 |
| 3x16 | neon_mlal_lane_prfm | 9.06 | 9.77 | 9.96 | 9.03 | 9.78 | 9.98 |
| 3x16 | neonfp16arith_mlal_lane | 8.88 | 9.62 | 10.20 | 8.90 | 9.62 | 10.20 |
| 3x16 | neonfp16arith_mlal_lane_prfm | 9.17 | 9.90 | 10.01 | 9.17 | 9.91 | 10.02 |
| 4x16 | neon_mlal_lane | 9.66 | 10.51 | 10.51 | 9.63 | 10.48 | 10.31 |
| 4x16 | neon_mlal_lane_prfm | 9.63 | 10.42 | 10.55 | 9.63 | 10.47 | 10.55 |
| 4x16 | neonfp16arith_mlal_lane | 9.75 | 10.56 | 10.64 | 9.77 | 10.57 | 10.64 |
| 4x16 | neonfp16arith_mlal_lane_prfm | 9.76 | 10.51 | 10.59 | 9.76 | 10.50 | 10.59 |
| 6x16 | neon_mlal_lane | 7.73 | 8.55 | 8.81 | 7.77 | 8.53 | 8.81 |
| 6x16 | neon_mlal_lane_prfm | 7.61 | 8.32 | 8.27 | 7.58 | 8.33 | 8.32 |
| 6x16 | neonfp16arith_mlal_lane | 7.69 | 8.43 | 7.93 | 7.69 | 8.44 | 7.94 |
| 6x16 | neonfp16arith_mlal_lane_prfm | 7.63 | 8.38 | 8.03 | 7.63 | 8.38 | 8.03 |


Average OPS, m=3136 (`mobilenet_v3_small`: n=16, k=16; `resnet50`: n=256, k=64):

| tile_size | kernel_type | mobilenet_v3_small, bl=16 | mobilenet_v3_small, qc4w | resnet50, bl=64 | resnet50, qc4w |
|---|---|---|---|---|---|
| 1x16 | neon_mlal_lane | 4.50 | 4.68 | 6.63 | 6.67 |
| 1x16 | neon_mlal_lane_prfm | 4.53 | 4.69 | 6.64 | 6.65 |
| 1x16 | neonfp16arith_mlal_lane | 4.38 | 4.58 | 6.51 | 6.48 |
| 1x16 | neonfp16arith_mlal_lane_prfm | 4.37 | 4.58 | 6.53 | 6.51 |
| 2x16 | neon_mlal_lane | 5.17 | 5.81 | 8.12 | 8.37 |
| 2x16 | neon_mlal_lane_prfm | 5.15 | 5.73 | 8.10 | 8.21 |
| 2x16 | neonfp16arith_mlal_lane | 5.29 | 5.83 | 8.11 | 8.36 |
| 2x16 | neonfp16arith_mlal_lane_prfm | 5.26 | 5.78 | 8.10 | 8.19 |
| 3x16 | neon_mlal_lane | 5.68 | 6.55 | 8.96 | 9.37 |
| 3x16 | neon_mlal_lane_prfm | 5.68 | 6.47 | 8.92 | 9.26 |
| 3x16 | neonfp16arith_mlal_lane | 5.73 | 6.75 | 8.47 | 9.32 |
| 3x16 | neonfp16arith_mlal_lane_prfm | 5.85 | 6.66 | 8.74 | 9.04 |
| 4x16 | neon_mlal_lane | 6.20 | 6.92 | 9.47 | 9.64 |
| 4x16 | neon_mlal_lane_prfm | 6.21 | 6.89 | 9.42 | 9.66 |
| 4x16 | neonfp16arith_mlal_lane | 6.33 | 7.03 | 9.20 | 9.57 |
| 4x16 | neonfp16arith_mlal_lane_prfm | 6.34 | 7.03 | 9.21 | 9.56 |
| 6x16 | neon_mlal_lane | 5.36 | 5.97 | 7.90 | 8.23 |
| 6x16 | neon_mlal_lane_prfm | 5.28 | 5.71 | 7.73 | 7.84 |
| 6x16 | neonfp16arith_mlal_lane | 5.37 | 5.52 | 7.60 | 7.37 |
| 6x16 | neonfp16arith_mlal_lane_prfm | 5.34 | 5.63 | 7.59 | 7.48 |

S22-A710:

Average OPS, `llm` benchmark, m=128, k=4096:

| tile_size | kernel_type | n=11008, bl=32 | n=11008, bl=256 | n=11008, qc4w | n=32000, bl=32 | n=32000, bl=256 | n=32000, qc4w |
|---|---|---|---|---|---|---|---|
| 1x16 | neon_mlal_lane | 16.56 | 17.82 | 18.09 | 16.37 | 17.79 | 18.10 |
| 1x16 | neon_mlal_lane_prfm | 17.02 | 17.66 | 18.26 | 17.03 | 17.70 | 18.29 |
| 1x16 | neonfp16arith_mlal_lane | 16.57 | 17.80 | 18.10 | 16.55 | 17.80 | 18.08 |
| 1x16 | neonfp16arith_mlal_lane_prfm | 17.03 | 17.66 | 18.25 | 17.01 | 17.69 | 18.29 |
| 2x16 | neon_mlal_lane | 16.73 | 17.58 | 17.72 | 16.75 | 17.58 | 17.75 |
| 2x16 | neon_mlal_lane_prfm | 16.64 | 17.41 | 17.76 | 16.68 | 17.41 | 17.77 |
| 2x16 | neonfp16arith_mlal_lane | 16.77 | 17.57 | 17.72 | 16.79 | 17.59 | 17.75 |
| 2x16 | neonfp16arith_mlal_lane_prfm | 16.83 | 17.61 | 17.86 | 16.81 | 17.63 | 17.89 |
| 3x16 | neon_mlal_lane | 16.92 | 17.79 | 17.91 | 16.93 | 17.82 | 17.94 |
| 3x16 | neon_mlal_lane_prfm | 16.89 | 17.65 | 17.76 | 16.92 | 17.67 | 17.78 |
| 3x16 | neonfp16arith_mlal_lane | 17.01 | 17.95 | 17.91 | 17.02 | 17.98 | 17.94 |
| 3x16 | neonfp16arith_mlal_lane_prfm | 17.16 | 17.98 | 17.84 | 17.17 | 18.01 | 17.87 |
| 4x16 | neon_mlal_lane | 18.22 | 19.16 | 19.41 | 18.24 | 19.18 | 19.44 |
| 4x16 | neon_mlal_lane_prfm | 18.23 | 19.24 | 19.48 | 18.24 | 19.25 | 19.50 |
| 4x16 | neonfp16arith_mlal_lane | 18.18 | 19.17 | 19.41 | 18.20 | 19.20 | 19.43 |
| 4x16 | neonfp16arith_mlal_lane_prfm | 18.23 | 19.26 | 19.32 | 18.24 | 19.29 | 19.35 |
| 6x16 | neon_mlal_lane | 16.42 | 17.62 | 17.96 | 16.43 | 17.62 | 17.97 |
| 6x16 | neon_mlal_lane_prfm | 16.41 | 17.75 | 17.97 | 16.60 | 17.83 | 17.98 |
| 6x16 | neonfp16arith_mlal_lane | 16.77 | 17.81 | 17.62 | 16.78 | 17.82 | 17.63 |
| 6x16 | neonfp16arith_mlal_lane_prfm | 16.79 | 18.05 | 17.92 | 16.97 | 18.14 | 17.97 |


Average OPS, m=3136 (`mobilenet_v3_small`: n=16, k=16; `resnet50`: n=256, k=64):

| tile_size | kernel_type | mobilenet_v3_small, bl=16 | mobilenet_v3_small, qc4w | resnet50, bl=64 | resnet50, qc4w |
|---|---|---|---|---|---|
| 1x16 | neon_mlal_lane | 12.70 | 13.24 | 17.12 | 17.26 |
| 1x16 | neon_mlal_lane_prfm | 12.82 | 12.53 | 16.65 | 17.26 |
| 1x16 | neonfp16arith_mlal_lane | 12.77 | 13.22 | 17.15 | 17.42 |
| 1x16 | neonfp16arith_mlal_lane_prfm | 12.60 | 12.55 | 16.67 | 17.28 |
| 2x16 | neon_mlal_lane | 13.18 | 13.42 | 16.56 | 16.65 |
| 2x16 | neon_mlal_lane_prfm | 12.96 | 13.22 | 16.31 | 16.68 |
| 2x16 | neonfp16arith_mlal_lane | 12.86 | 13.31 | 16.42 | 16.57 |
| 2x16 | neonfp16arith_mlal_lane_prfm | 12.99 | 13.11 | 16.47 | 16.67 |
| 3x16 | neon_mlal_lane | 12.93 | 13.58 | 16.68 | 16.89 |
| 3x16 | neon_mlal_lane_prfm | 13.18 | 13.48 | 16.62 | 16.76 |
| 3x16 | neonfp16arith_mlal_lane | 13.18 | 13.23 | 16.99 | 16.82 |
| 3x16 | neonfp16arith_mlal_lane_prfm | 13.19 | 13.00 | 16.95 | 16.75 |
| 4x16 | neon_mlal_lane | 13.50 | 14.50 | 17.69 | 18.06 |
| 4x16 | neon_mlal_lane_prfm | 13.47 | 14.32 | 17.70 | 18.12 |
| 4x16 | neonfp16arith_mlal_lane | 13.63 | 14.03 | 17.77 | 17.93 |
| 4x16 | neonfp16arith_mlal_lane_prfm | 13.74 | 13.92 | 17.78 | 17.90 |
| 6x16 | neon_mlal_lane | 12.55 | 13.04 | 16.60 | 17.02 |
| 6x16 | neon_mlal_lane_prfm | 12.53 | 12.83 | 16.77 | 16.99 |
| 6x16 | neonfp16arith_mlal_lane | 12.76 | 12.89 | 16.88 | 16.63 |
| 6x16 | neonfp16arith_mlal_lane_prfm | 12.96 | 12.99 | 17.16 | 16.93 |

M1 Pro:

Average OPS, `llm` benchmark:

| tile_size | kernel_type | n=16 k=1024 bl=32 | n=16 k=1024 bl=256 | n=16 k=1024 qc4w | n=128 k=1024 bl=32 | n=128 k=1024 bl=256 | n=128 k=1024 qc4w | n=4096 k=1024 bl=32 | n=4096 k=1024 bl=256 | n=4096 k=1024 qc4w | n=11008 k=4096 bl=32 | n=11008 k=4096 bl=256 | n=11008 k=4096 qc4w | n=32000 k=4096 bl=32 | n=32000 k=4096 bl=256 | n=32000 k=4096 qc4w |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1x16 | neon_mlal_lane | 46.49 | 36.38 | 33.20 | 47.53 | 36.40 | 33.46 | 47.26 | 36.52 | 33.81 | 46.33 | 36.05 | 33.16 | 46.33 | 36.20 | 32.76 |
| 1x16 | neon_mlal_lane_prfm | 46.51 | 36.29 | 33.91 | 47.05 | 36.66 | 33.85 | 47.59 | 36.80 | 33.99 | 46.19 | 36.07 | 32.78 | 45.87 | 35.61 | 32.72 |
| 1x16 | neonfp16arith_mlal_lane | 46.79 | 36.20 | 34.13 | 46.99 | 36.62 | 34.43 | 47.75 | 36.27 | 34.48 | 47.14 | 36.56 | 33.60 | 46.66 | 36.35 | 33.31 |
| 1x16 | neonfp16arith_mlal_lane_prfm | 46.98 | 36.50 | 34.06 | 47.47 | 36.48 | 34.40 | 47.47 | 36.90 | 34.26 | 46.80 | 36.33 | 33.60 | 46.19 | 36.30 | 33.20 |
| 2x16 | neon_mlal_lane | 62.10 | 63.50 | 62.07 | 62.59 | 64.22 | 63.56 | 62.82 | 65.02 | 63.22 | 61.86 | 64.06 | 63.12 | 61.93 | 63.72 | 62.71 |
| 2x16 | neon_mlal_lane_prfm | 61.93 | 64.13 | 63.35 | 62.72 | 65.03 | 63.75 | 62.63 | 65.20 | 63.37 | 61.79 | 63.31 | 62.85 | 61.57 | 63.33 | 62.43 |
| 2x16 | neonfp16arith_mlal_lane | 62.18 | 64.35 | 64.17 | 62.81 | 64.96 | 64.41 | 63.20 | 65.27 | 64.84 | 62.39 | 64.58 | 64.44 | 62.13 | 63.96 | 64.08 |
| 2x16 | neonfp16arith_mlal_lane_prfm | 62.30 | 64.20 | 64.23 | 63.23 | 65.00 | 64.72 | 62.93 | 65.10 | 64.36 | 62.54 | 64.69 | 64.48 | 62.23 | 64.22 | 63.86 |
| 3x16 | neon_mlal_lane | 68.19 | 70.97 | 70.32 | 69.53 | 72.57 | 71.19 | 69.80 | 73.03 | 70.87 | 68.71 | 71.96 | 70.75 | 68.72 | 72.19 | 70.66 |
| 3x16 | neon_mlal_lane_prfm | 68.72 | 72.04 | 70.42 | 69.45 | 72.53 | 71.89 | 69.20 | 72.23 | 72.09 | 68.16 | 72.35 | 71.09 | 69.14 | 71.94 | 71.19 |
| 3x16 | neonfp16arith_mlal_lane | 68.99 | 71.04 | 71.31 | 69.06 | 72.50 | 72.37 | 69.76 | 72.61 | 72.25 | 69.24 | 72.18 | 72.19 | 69.11 | 72.27 | 72.02 |
| 3x16 | neonfp16arith_mlal_lane_prfm | 69.05 | 72.09 | 72.29 | 69.17 | 72.77 | 73.28 | 69.31 | 72.54 | 73.66 | 68.63 | 72.57 | 72.50 | 69.11 | 72.41 | 72.56 |
| 4x16 | neon_mlal_lane | 73.01 | 76.40 | 76.34 | 74.10 | 78.11 | 76.02 | 73.25 | 77.36 | 77.21 | 73.34 | 77.51 | 77.73 | 73.13 | 77.21 | 76.30 |
| 4x16 | neon_mlal_lane_prfm | 71.94 | 76.93 | 75.82 | 74.12 | 78.03 | 77.34 | 74.06 | 78.26 | 78.19 | 72.60 | 77.52 | 76.21 | 72.75 | 76.46 | 76.05 |
| 4x16 | neonfp16arith_mlal_lane | 72.88 | 76.46 | 77.81 | 74.07 | 78.01 | 78.68 | 74.15 | 78.02 | 78.70 | 73.61 | 78.06 | 78.43 | 73.79 | 77.50 | 78.27 |
| 4x16 | neonfp16arith_mlal_lane_prfm | 72.47 | 77.23 | 78.14 | 74.01 | 78.05 | 78.68 | 73.85 | 78.21 | 78.74 | 73.50 | 77.81 | 77.88 | 73.59 | 77.75 | 77.42 |
| 6x16 | neon_mlal_lane | 69.36 | 74.63 | 73.26 | 70.49 | 76.05 | 75.26 | 70.75 | 75.67 | 74.48 | 70.09 | 75.65 | 75.02 | 69.83 | 75.13 | 75.04 |
| 6x16 | neon_mlal_lane_prfm | 69.62 | 74.85 | 74.37 | 70.52 | 76.02 | 74.33 | 70.48 | 75.68 | 74.64 | 70.16 | 75.61 | 74.71 | 69.27 | 75.10 | 74.98 |
| 6x16 | neonfp16arith_mlal_lane | 68.99 | 72.90 | 75.09 | 69.85 | 75.53 | 76.66 | 69.81 | 76.01 | 75.93 | 69.86 | 76.03 | 76.03 | 69.88 | 75.83 | 76.19 |
| 6x16 | neonfp16arith_mlal_lane_prfm | 68.77 | 74.72 | 75.75 | 69.95 | 75.69 | 75.74 | 70.03 | 76.24 | 76.87 | 69.61 | 75.54 | 75.97 | 69.93 | 75.93 | 76.13 |


Average OPS (`mobilenet_v3_small`: n=16, k=16; `resnet50`: n=256, k=64):

| tile_size | kernel_type | mobilenet_v3_small, bl=16 | mobilenet_v3_small, qc4w | resnet50, bl=64 | resnet50, qc4w |
|---|---|---|---|---|---|
| 1x16 | neon_mlal_lane | 31.70 | 35.02 | 40.32 | 41.65 |
| 1x16 | neon_mlal_lane_prfm | 31.65 | 34.76 | 41.13 | 42.04 |
| 1x16 | neonfp16arith_mlal_lane | 32.45 | 35.23 | 40.69 | 41.57 |
| 1x16 | neonfp16arith_mlal_lane_prfm | 32.23 | 34.91 | 40.59 | 41.13 |
| 2x16 | neon_mlal_lane | 43.30 | 46.53 | 61.10 | 61.46 |
| 2x16 | neon_mlal_lane_prfm | 43.48 | 46.57 | 61.05 | 61.16 |
| 2x16 | neonfp16arith_mlal_lane | 44.51 | 47.56 | 61.22 | 61.96 |
| 2x16 | neonfp16arith_mlal_lane_prfm | 44.75 | 47.74 | 60.47 | 61.77 |
| 3x16 | neon_mlal_lane | 46.83 | 50.96 | 67.69 | 68.11 |
| 3x16 | neon_mlal_lane_prfm | 46.72 | 50.33 | 67.16 | 67.79 |
| 3x16 | neonfp16arith_mlal_lane | 48.29 | 51.39 | 68.20 | 67.86 |
| 3x16 | neonfp16arith_mlal_lane_prfm | 48.27 | 52.23 | 68.06 | 69.08 |
| 4x16 | neon_mlal_lane | 49.24 | 54.32 | 71.31 | 71.85 |
| 4x16 | neon_mlal_lane_prfm | 49.72 | 52.65 | 71.31 | 72.34 |
| 4x16 | neonfp16arith_mlal_lane | 51.00 | 55.01 | 71.71 | 73.42 |
| 4x16 | neonfp16arith_mlal_lane_prfm | 50.61 | 54.88 | 72.02 | 73.29 |
| 6x16 | neon_mlal_lane | 48.58 | 52.95 | 68.97 | 72.31 |
| 6x16 | neon_mlal_lane_prfm | 48.17 | 53.05 | 69.81 | 72.80 |
| 6x16 | neonfp16arith_mlal_lane | 50.60 | 55.86 | 72.41 | 73.35 |
| 6x16 | neonfp16arith_mlal_lane_prfm | 50.53 | 55.72 | 72.45 | 73.49 |

GregoryComer, Jun 17 '24 19:06