XNNPACK
QB4W MLAL GEMM Kernels
This pull request adds blockwise 4-bit (qb4w) GEMM microkernels targeting ARM Neon via the MLAL instruction family.
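As a rough illustration of what a blockwise 4-bit GEMM computes, here is a minimal scalar sketch. It is not the kernel code from this PR: the function name `qb4w_gemm_ref`, the unpacked weight layout, and the omission of activation zero points, input scales, bias, and nibble packing are all simplifying assumptions made for the example. The integer accumulator per block stands in for the widening multiply-accumulate that MLAL performs on vectors.

```python
def qb4w_gemm_ref(a, w, scales, block_size):
    """Hypothetical scalar reference for a blockwise-quantized GEMM.

    a:      M x K list of int8 activations
    w:      N x K list of signed 4-bit weights, already unpacked to ints in [-8, 7]
    scales: N x (K // block_size) list of float scales, one per weight block
    """
    m_rows, n_cols, k = len(a), len(w), len(a[0])
    assert k % block_size == 0, "K must be a multiple of the block size"
    out = [[0.0] * n_cols for _ in range(m_rows)]
    for m in range(m_rows):
        for n in range(n_cols):
            total = 0.0
            for b in range(k // block_size):
                acc = 0  # integer accumulator for one block
                for i in range(b * block_size, (b + 1) * block_size):
                    acc += a[m][i] * w[n][i]
                # The block loop: scale each block's integer sum separately.
                total += acc * scales[n][b]
            out[m][n] = total
    return out


a = [[1, -2, 3, 4]]
w = [[1, 2, -3, 4], [0, -1, 7, -8]]

# Two blocks of size 2 per output channel.
print(qb4w_gemm_ref(a, w, [[0.5, 0.25], [1.0, 2.0]], 2))  # [[0.25, -20.0]]

# One block spanning all of K: a single scale per output channel.
print(qb4w_gemm_ref(a, w, [[0.5], [1.0]], 4))  # [[2.0, -9.0]]
```

When `block_size` equals K, the block loop runs once and each output channel has exactly one scale, which is the per-channel case.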
Note: This PR includes one commit from https://github.com/google/XNNPACK/pull/6557 (Test generation update for qb4w). I'm opening this PR before that one merges so that review can start in parallel.
Tests and benchmarks were run on a Mac M1 Pro and a Samsung S22. The benchmark data includes qc4w results for comparison. Blockwise kernels with block_size equal to kc are functionally equivalent to qc4w, so qc4w provides a reasonable performance baseline. I expect qb4w with bl=256 to be slightly slower than qc4w for two reasons: the slightly larger memory footprint (~4.125 bits per weight vs. ~4 bits per weight for qc4w) and the small overhead of the block loop.
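The ~4.125 bits/weight figure can be reproduced with a short calculation. This sketch assumes each block stores one 32-bit scale alongside its 4-bit weights; that scale width is an assumption chosen because it matches the number quoted above, and the helper name `bits_per_weight` is hypothetical.

```python
def bits_per_weight(block_size, weight_bits=4, scale_bits=32):
    """Average storage per weight when `block_size` weights share one scale.

    scale_bits=32 is an assumption, not taken from the kernel sources.
    """
    return weight_bits + scale_bits / block_size


print(bits_per_weight(256))   # 4.125
print(bits_per_weight(32))    # 5.0
print(bits_per_weight(4096))  # 4.0078125 (one scale per channel when kc = 4096)
```

Under the same assumption, bl=32 costs about 5 bits per weight, which is in line with the bl=32 columns trailing bl=256 in most of the S22 results below.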
S22-A510:
AVERAGE of OPS, benchmark: llm, m = 128. Column headers give n, k, and bl.

tile_size | kernel_type | n=11008, k=4096, bl=32 | n=11008, k=4096, bl=256 | n=11008, k=4096, qc4w | n=32000, k=4096, bl=32 | n=32000, k=4096, bl=256 | n=32000, k=4096, qc4w
---|---|---|---|---|---|---|---
1x16 | neon_mlal_lane | 6.67 | 7.01 | 7.04 | 6.70 | 7.00 | 7.05
1x16 | neon_mlal_lane_prfm | 6.64 | 7.03 | 6.94 | 6.72 | 7.00 | 7.09
1x16 | neonfp16arith_mlal_lane | 6.76 | 7.05 | 7.12 | 6.78 | 7.05 | 7.13
1x16 | neonfp16arith_mlal_lane_prfm | 6.48 | 7.05 | 7.15 | 6.48 | 7.06 | 7.14
2x16 | neon_mlal_lane | 8.32 | 8.93 | 9.07 | 8.31 | 8.96 | 9.08
2x16 | neon_mlal_lane_prfm | 8.20 | 8.83 | 8.90 | 8.22 | 8.84 | 8.92
2x16 | neonfp16arith_mlal_lane | 8.41 | 9.00 | 9.12 | 8.41 | 9.03 | 9.11
2x16 | neonfp16arith_mlal_lane_prfm | 8.36 | 8.95 | 8.95 | 8.37 | 8.96 | 8.94
3x16 | neon_mlal_lane | 9.09 | 9.96 | 10.15 | 9.08 | 9.94 | 10.09
3x16 | neon_mlal_lane_prfm | 9.06 | 9.77 | 9.96 | 9.03 | 9.78 | 9.98
3x16 | neonfp16arith_mlal_lane | 8.88 | 9.62 | 10.20 | 8.90 | 9.62 | 10.20
3x16 | neonfp16arith_mlal_lane_prfm | 9.17 | 9.90 | 10.01 | 9.17 | 9.91 | 10.02
4x16 | neon_mlal_lane | 9.66 | 10.51 | 10.51 | 9.63 | 10.48 | 10.31
4x16 | neon_mlal_lane_prfm | 9.63 | 10.42 | 10.55 | 9.63 | 10.47 | 10.55
4x16 | neonfp16arith_mlal_lane | 9.75 | 10.56 | 10.64 | 9.77 | 10.57 | 10.64
4x16 | neonfp16arith_mlal_lane_prfm | 9.76 | 10.51 | 10.59 | 9.76 | 10.50 | 10.59
6x16 | neon_mlal_lane | 7.73 | 8.55 | 8.81 | 7.77 | 8.53 | 8.81
6x16 | neon_mlal_lane_prfm | 7.61 | 8.32 | 8.27 | 7.58 | 8.33 | 8.32
6x16 | neonfp16arith_mlal_lane | 7.69 | 8.43 | 7.93 | 7.69 | 8.44 | 7.94
6x16 | neonfp16arith_mlal_lane_prfm | 7.63 | 8.38 | 8.03 | 7.63 | 8.38 | 8.03

AVERAGE of OPS, benchmarks: mobilenet_v3_small and resnet50, m = 3136 for both. Column headers give the benchmark, n, k, and bl.

tile_size | kernel_type | mobilenet_v3_small (n=16, k=16), bl=16 | mobilenet_v3_small (n=16, k=16), qc4w | resnet50 (n=256, k=64), bl=64 | resnet50 (n=256, k=64), qc4w
---|---|---|---|---|---
1x16 | neon_mlal_lane | 4.50 | 4.68 | 6.63 | 6.67
1x16 | neon_mlal_lane_prfm | 4.53 | 4.69 | 6.64 | 6.65
1x16 | neonfp16arith_mlal_lane | 4.38 | 4.58 | 6.51 | 6.48
1x16 | neonfp16arith_mlal_lane_prfm | 4.37 | 4.58 | 6.53 | 6.51
2x16 | neon_mlal_lane | 5.17 | 5.81 | 8.12 | 8.37
2x16 | neon_mlal_lane_prfm | 5.15 | 5.73 | 8.10 | 8.21
2x16 | neonfp16arith_mlal_lane | 5.29 | 5.83 | 8.11 | 8.36
2x16 | neonfp16arith_mlal_lane_prfm | 5.26 | 5.78 | 8.10 | 8.19
3x16 | neon_mlal_lane | 5.68 | 6.55 | 8.96 | 9.37
3x16 | neon_mlal_lane_prfm | 5.68 | 6.47 | 8.92 | 9.26
3x16 | neonfp16arith_mlal_lane | 5.73 | 6.75 | 8.47 | 9.32
3x16 | neonfp16arith_mlal_lane_prfm | 5.85 | 6.66 | 8.74 | 9.04
4x16 | neon_mlal_lane | 6.20 | 6.92 | 9.47 | 9.64
4x16 | neon_mlal_lane_prfm | 6.21 | 6.89 | 9.42 | 9.66
4x16 | neonfp16arith_mlal_lane | 6.33 | 7.03 | 9.20 | 9.57
4x16 | neonfp16arith_mlal_lane_prfm | 6.34 | 7.03 | 9.21 | 9.56
6x16 | neon_mlal_lane | 5.36 | 5.97 | 7.90 | 8.23
6x16 | neon_mlal_lane_prfm | 5.28 | 5.71 | 7.73 | 7.84
6x16 | neonfp16arith_mlal_lane | 5.37 | 5.52 | 7.60 | 7.37
6x16 | neonfp16arith_mlal_lane_prfm | 5.34 | 5.63 | 7.59 | 7.48

S22-A710:
AVERAGE of OPS, benchmark: llm, m = 128. Column headers give n, k, and bl.

tile_size | kernel_type | n=11008, k=4096, bl=32 | n=11008, k=4096, bl=256 | n=11008, k=4096, qc4w | n=32000, k=4096, bl=32 | n=32000, k=4096, bl=256 | n=32000, k=4096, qc4w
---|---|---|---|---|---|---|---
1x16 | neon_mlal_lane | 16.56 | 17.82 | 18.09 | 16.37 | 17.79 | 18.10
1x16 | neon_mlal_lane_prfm | 17.02 | 17.66 | 18.26 | 17.03 | 17.70 | 18.29
1x16 | neonfp16arith_mlal_lane | 16.57 | 17.80 | 18.10 | 16.55 | 17.80 | 18.08
1x16 | neonfp16arith_mlal_lane_prfm | 17.03 | 17.66 | 18.25 | 17.01 | 17.69 | 18.29
2x16 | neon_mlal_lane | 16.73 | 17.58 | 17.72 | 16.75 | 17.58 | 17.75
2x16 | neon_mlal_lane_prfm | 16.64 | 17.41 | 17.76 | 16.68 | 17.41 | 17.77
2x16 | neonfp16arith_mlal_lane | 16.77 | 17.57 | 17.72 | 16.79 | 17.59 | 17.75
2x16 | neonfp16arith_mlal_lane_prfm | 16.83 | 17.61 | 17.86 | 16.81 | 17.63 | 17.89
3x16 | neon_mlal_lane | 16.92 | 17.79 | 17.91 | 16.93 | 17.82 | 17.94
3x16 | neon_mlal_lane_prfm | 16.89 | 17.65 | 17.76 | 16.92 | 17.67 | 17.78
3x16 | neonfp16arith_mlal_lane | 17.01 | 17.95 | 17.91 | 17.02 | 17.98 | 17.94
3x16 | neonfp16arith_mlal_lane_prfm | 17.16 | 17.98 | 17.84 | 17.17 | 18.01 | 17.87
4x16 | neon_mlal_lane | 18.22 | 19.16 | 19.41 | 18.24 | 19.18 | 19.44
4x16 | neon_mlal_lane_prfm | 18.23 | 19.24 | 19.48 | 18.24 | 19.25 | 19.50
4x16 | neonfp16arith_mlal_lane | 18.18 | 19.17 | 19.41 | 18.20 | 19.20 | 19.43
4x16 | neonfp16arith_mlal_lane_prfm | 18.23 | 19.26 | 19.32 | 18.24 | 19.29 | 19.35
6x16 | neon_mlal_lane | 16.42 | 17.62 | 17.96 | 16.43 | 17.62 | 17.97
6x16 | neon_mlal_lane_prfm | 16.41 | 17.75 | 17.97 | 16.60 | 17.83 | 17.98
6x16 | neonfp16arith_mlal_lane | 16.77 | 17.81 | 17.62 | 16.78 | 17.82 | 17.63
6x16 | neonfp16arith_mlal_lane_prfm | 16.79 | 18.05 | 17.92 | 16.97 | 18.14 | 17.97

AVERAGE of OPS, benchmarks: mobilenet_v3_small and resnet50, m = 3136 for both. Column headers give the benchmark, n, k, and bl.

tile_size | kernel_type | mobilenet_v3_small (n=16, k=16), bl=16 | mobilenet_v3_small (n=16, k=16), qc4w | resnet50 (n=256, k=64), bl=64 | resnet50 (n=256, k=64), qc4w
---|---|---|---|---|---
1x16 | neon_mlal_lane | 12.70 | 13.24 | 17.12 | 17.26
1x16 | neon_mlal_lane_prfm | 12.82 | 12.53 | 16.65 | 17.26
1x16 | neonfp16arith_mlal_lane | 12.77 | 13.22 | 17.15 | 17.42
1x16 | neonfp16arith_mlal_lane_prfm | 12.60 | 12.55 | 16.67 | 17.28
2x16 | neon_mlal_lane | 13.18 | 13.42 | 16.56 | 16.65
2x16 | neon_mlal_lane_prfm | 12.96 | 13.22 | 16.31 | 16.68
2x16 | neonfp16arith_mlal_lane | 12.86 | 13.31 | 16.42 | 16.57
2x16 | neonfp16arith_mlal_lane_prfm | 12.99 | 13.11 | 16.47 | 16.67
3x16 | neon_mlal_lane | 12.93 | 13.58 | 16.68 | 16.89
3x16 | neon_mlal_lane_prfm | 13.18 | 13.48 | 16.62 | 16.76
3x16 | neonfp16arith_mlal_lane | 13.18 | 13.23 | 16.99 | 16.82
3x16 | neonfp16arith_mlal_lane_prfm | 13.19 | 13.00 | 16.95 | 16.75
4x16 | neon_mlal_lane | 13.50 | 14.50 | 17.69 | 18.06
4x16 | neon_mlal_lane_prfm | 13.47 | 14.32 | 17.70 | 18.12
4x16 | neonfp16arith_mlal_lane | 13.63 | 14.03 | 17.77 | 17.93
4x16 | neonfp16arith_mlal_lane_prfm | 13.74 | 13.92 | 17.78 | 17.90
6x16 | neon_mlal_lane | 12.55 | 13.04 | 16.60 | 17.02
6x16 | neon_mlal_lane_prfm | 12.53 | 12.83 | 16.77 | 16.99
6x16 | neonfp16arith_mlal_lane | 12.76 | 12.89 | 16.88 | 16.63
6x16 | neonfp16arith_mlal_lane_prfm | 12.96 | 12.99 | 17.16 | 16.93

M1 Pro:
AVERAGE of OPS, benchmark: llm. Column headers give n, k, and bl.

tile_size | kernel_type | n=16, k=1024, bl=32 | n=16, k=1024, bl=256 | n=16, k=1024, bl=nan | n=128, k=1024, bl=32 | n=128, k=1024, bl=256 | n=128, k=1024, bl=nan | n=4096, k=1024, bl=32 | n=4096, k=1024, bl=256 | n=4096, k=1024, bl=nan | n=11008, k=4096, bl=32 | n=11008, k=4096, bl=256 | n=11008, k=4096, bl=nan | n=32000, k=4096, bl=32 | n=32000, k=4096, bl=256 | n=32000, k=4096, bl=nan
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
1x16 | neon_mlal_lane | 46.49 | 36.38 | 33.20 | 47.53 | 36.40 | 33.46 | 47.26 | 36.52 | 33.81 | 46.33 | 36.05 | 33.16 | 46.33 | 36.20 | 32.76
1x16 | neon_mlal_lane_prfm | 46.51 | 36.29 | 33.91 | 47.05 | 36.66 | 33.85 | 47.59 | 36.80 | 33.99 | 46.19 | 36.07 | 32.78 | 45.87 | 35.61 | 32.72
1x16 | neonfp16arith_mlal_lane | 46.79 | 36.20 | 34.13 | 46.99 | 36.62 | 34.43 | 47.75 | 36.27 | 34.48 | 47.14 | 36.56 | 33.60 | 46.66 | 36.35 | 33.31
1x16 | neonfp16arith_mlal_lane_prfm | 46.98 | 36.50 | 34.06 | 47.47 | 36.48 | 34.40 | 47.47 | 36.90 | 34.26 | 46.80 | 36.33 | 33.60 | 46.19 | 36.30 | 33.20
2x16 | neon_mlal_lane | 62.10 | 63.50 | 62.07 | 62.59 | 64.22 | 63.56 | 62.82 | 65.02 | 63.22 | 61.86 | 64.06 | 63.12 | 61.93 | 63.72 | 62.71
2x16 | neon_mlal_lane_prfm | 61.93 | 64.13 | 63.35 | 62.72 | 65.03 | 63.75 | 62.63 | 65.20 | 63.37 | 61.79 | 63.31 | 62.85 | 61.57 | 63.33 | 62.43
2x16 | neonfp16arith_mlal_lane | 62.18 | 64.35 | 64.17 | 62.81 | 64.96 | 64.41 | 63.20 | 65.27 | 64.84 | 62.39 | 64.58 | 64.44 | 62.13 | 63.96 | 64.08
2x16 | neonfp16arith_mlal_lane_prfm | 62.30 | 64.20 | 64.23 | 63.23 | 65.00 | 64.72 | 62.93 | 65.10 | 64.36 | 62.54 | 64.69 | 64.48 | 62.23 | 64.22 | 63.86
3x16 | neon_mlal_lane | 68.19 | 70.97 | 70.32 | 69.53 | 72.57 | 71.19 | 69.80 | 73.03 | 70.87 | 68.71 | 71.96 | 70.75 | 68.72 | 72.19 | 70.66
3x16 | neon_mlal_lane_prfm | 68.72 | 72.04 | 70.42 | 69.45 | 72.53 | 71.89 | 69.20 | 72.23 | 72.09 | 68.16 | 72.35 | 71.09 | 69.14 | 71.94 | 71.19
3x16 | neonfp16arith_mlal_lane | 68.99 | 71.04 | 71.31 | 69.06 | 72.50 | 72.37 | 69.76 | 72.61 | 72.25 | 69.24 | 72.18 | 72.19 | 69.11 | 72.27 | 72.02
3x16 | neonfp16arith_mlal_lane_prfm | 69.05 | 72.09 | 72.29 | 69.17 | 72.77 | 73.28 | 69.31 | 72.54 | 73.66 | 68.63 | 72.57 | 72.50 | 69.11 | 72.41 | 72.56
4x16 | neon_mlal_lane | 73.01 | 76.40 | 76.34 | 74.10 | 78.11 | 76.02 | 73.25 | 77.36 | 77.21 | 73.34 | 77.51 | 77.73 | 73.13 | 77.21 | 76.30
4x16 | neon_mlal_lane_prfm | 71.94 | 76.93 | 75.82 | 74.12 | 78.03 | 77.34 | 74.06 | 78.26 | 78.19 | 72.60 | 77.52 | 76.21 | 72.75 | 76.46 | 76.05
4x16 | neonfp16arith_mlal_lane | 72.88 | 76.46 | 77.81 | 74.07 | 78.01 | 78.68 | 74.15 | 78.02 | 78.70 | 73.61 | 78.06 | 78.43 | 73.79 | 77.50 | 78.27
4x16 | neonfp16arith_mlal_lane_prfm | 72.47 | 77.23 | 78.14 | 74.01 | 78.05 | 78.68 | 73.85 | 78.21 | 78.74 | 73.50 | 77.81 | 77.88 | 73.59 | 77.75 | 77.42
6x16 | neon_mlal_lane | 69.36 | 74.63 | 73.26 | 70.49 | 76.05 | 75.26 | 70.75 | 75.67 | 74.48 | 70.09 | 75.65 | 75.02 | 69.83 | 75.13 | 75.04
6x16 | neon_mlal_lane_prfm | 69.62 | 74.85 | 74.37 | 70.52 | 76.02 | 74.33 | 70.48 | 75.68 | 74.64 | 70.16 | 75.61 | 74.71 | 69.27 | 75.10 | 74.98
6x16 | neonfp16arith_mlal_lane | 68.99 | 72.90 | 75.09 | 69.85 | 75.53 | 76.66 | 69.81 | 76.01 | 75.93 | 69.86 | 76.03 | 76.03 | 69.88 | 75.83 | 76.19
6x16 | neonfp16arith_mlal_lane_prfm | 68.77 | 74.72 | 75.75 | 69.95 | 75.69 | 75.74 | 70.03 | 76.24 | 76.87 | 69.61 | 75.54 | 75.97 | 69.93 | 75.93 | 76.13

AVERAGE of OPS, benchmarks: mobilenet_v3_small and resnet50. Column headers give the benchmark, n, k, and bl.

tile_size | kernel_type | mobilenet_v3_small (n=16, k=16), bl=16 | mobilenet_v3_small (n=16, k=16), qc4w | resnet50 (n=256, k=64), bl=64 | resnet50 (n=256, k=64), qc4w
---|---|---|---|---|---
1x16 | neon_mlal_lane | 31.70 | 35.02 | 40.32 | 41.65
1x16 | neon_mlal_lane_prfm | 31.65 | 34.76 | 41.13 | 42.04
1x16 | neonfp16arith_mlal_lane | 32.45 | 35.23 | 40.69 | 41.57
1x16 | neonfp16arith_mlal_lane_prfm | 32.23 | 34.91 | 40.59 | 41.13
2x16 | neon_mlal_lane | 43.30 | 46.53 | 61.10 | 61.46
2x16 | neon_mlal_lane_prfm | 43.48 | 46.57 | 61.05 | 61.16
2x16 | neonfp16arith_mlal_lane | 44.51 | 47.56 | 61.22 | 61.96
2x16 | neonfp16arith_mlal_lane_prfm | 44.75 | 47.74 | 60.47 | 61.77
3x16 | neon_mlal_lane | 46.83 | 50.96 | 67.69 | 68.11
3x16 | neon_mlal_lane_prfm | 46.72 | 50.33 | 67.16 | 67.79
3x16 | neonfp16arith_mlal_lane | 48.29 | 51.39 | 68.20 | 67.86
3x16 | neonfp16arith_mlal_lane_prfm | 48.27 | 52.23 | 68.06 | 69.08
4x16 | neon_mlal_lane | 49.24 | 54.32 | 71.31 | 71.85
4x16 | neon_mlal_lane_prfm | 49.72 | 52.65 | 71.31 | 72.34
4x16 | neonfp16arith_mlal_lane | 51.00 | 55.01 | 71.71 | 73.42
4x16 | neonfp16arith_mlal_lane_prfm | 50.61 | 54.88 | 72.02 | 73.29
6x16 | neon_mlal_lane | 48.58 | 52.95 | 68.97 | 72.31
6x16 | neon_mlal_lane_prfm | 48.17 | 53.05 | 69.81 | 72.80
6x16 | neonfp16arith_mlal_lane | 50.60 | 55.86 | 72.41 | 73.35
6x16 | neonfp16arith_mlal_lane_prfm | 50.53 | 55.72 | 72.45 | 73.49