iree icon indicating copy to clipboard operation
iree copied to clipboard

[Codegen] Add support for vectorizing tensor.pack ops with masking.

Open hanhanW opened this issue 1 year ago • 2 comments

It introduces inferVectorSizesFromIR(Value val) method, so every op can infer the input vector size through the chain. This is important for fusion cases (e.g., generic + pack) because it assumed extract_slice op being a producer op. It improves the input vector size inference generally.

hanhanW avatar Mar 01 '24 18:03 hanhanW

Abbreviated Benchmark Summary

@ commit 14047dde9b63497c3bb46bf599a7596f9945efbd (vs. base 573ff1ff02347266ed747dd316cefaeb4c710396)

Data-Tiling Comparison Table

Click to show
Name No-DT (baseline) DT-Only DT-UK
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 216.140 (1.0X) N/A 106.034 (2.0X)
BertLargePTBatch1(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 655.015 (1.0X) N/A 225.109 (2.9X)
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 701.476 (1.0X) N/A 222.570 (3.2X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 31.596 (1.0X) N/A 33.329 (0.9X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 6.953 (1.0X) N/A 8.572 (0.8X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 272.646 (1.0X) N/A 227.781 (1.2X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 35.064 (1.0X) N/A 33.573 (1.0X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.003 (1.0X) N/A 15.566 (1.7X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.568 (1.0X) N/A 5.280 (1.1X)
Falcon7bGptqPT(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 20400.041 (1.0X) N/A 4141.295 (4.9X)
Falcon7bInt4GptqPT(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 20410.537 (1.0X) N/A 4053.921 (5.0X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 69.791 (1.0X) N/A 39.266 (1.8X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.109 (1.0X) N/A 8.545 (1.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 88.519 (1.0X) N/A 41.098 (2.2X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 10.725 (1.0X) N/A 8.020 (1.3X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 81.515 (1.0X) N/A 58.656 (1.4X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 12.266 (1.0X) N/A 12.898 (1.0X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 178.796 (1.0X) N/A 186.773 (1.0X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 33.197 (1.0X) N/A 58.774 (0.6X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 181.224 (1.0X) N/A 190.552 (1.0X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 33.541 (1.0X) N/A 59.552 (0.6X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 481.539 (1.0X) N/A 215.223 (2.2X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 60.670 (1.0X) N/A 65.658 (0.9X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 27.486 (1.0X) N/A 19.845 (1.4X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.482 (1.0X) N/A 5.048 (1.1X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.036 (1.0X) N/A 12.695 (0.9X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.661 (1.0X) N/A 4.932 (0.7X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.098 (1.0X) N/A 14.202 (1.5X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.702 (1.0X) N/A 5.578 (1.0X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 2.852 (1.0X) N/A 3.217 (0.9X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 2.836 (1.0X) N/A 3.274 (0.9X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.354 (1.0X) N/A 32.138 (1.1X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.311 (1.0X) N/A 9.468 (0.9X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.734 (1.0X) N/A 0.614 (1.2X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.799 (1.0X) N/A 0.674 (1.2X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 18.038 (1.0X) N/A 20.885 (0.9X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.161 (1.0X) N/A 5.290 (0.8X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.065 (1.0X) N/A 0.065 (1.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.042 (1.0X) N/A 0.021 (2.0X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.435 (1.0X) N/A 7.441 (1.0X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 6.742 (1.0X) N/A 1.977 (3.4X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 51.440 (1.0X) N/A 86.020 (0.6X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 54.051 (1.0X) N/A 85.419 (0.6X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 33.669 (1.0X) N/A 53.429 (0.6X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 92.225 (1.0X) N/A 22.239 (4.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 92.689 (1.0X) N/A 22.517 (4.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 53.116 (1.0X) N/A 22.035 (2.4X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 139.417 (1.0X) N/A 28.329 (4.9X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 133.055 (1.0X) N/A 30.905 (4.3X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 74.650 (1.0X) N/A 26.637 (2.8X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 758.776 (1.0X) N/A 452.874 (1.7X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 756.923 (1.0X) N/A 405.381 (1.9X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 446.125 (1.0X) N/A 246.400 (1.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 1098.169 (1.0X) N/A 257.919 (4.3X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1102.532 (1.0X) N/A 260.858 (4.2X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 598.704 (1.0X) N/A 153.164 (3.9X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 2177.568 (1.0X) N/A 309.368 (7.0X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 2174.275 (1.0X) N/A 313.893 (6.9X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1227.048 (1.0X) N/A 189.735 (6.5X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.081 (1.0X) N/A 0.016 (5.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 0.073 (1.0X) N/A 0.017 (4.3X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 12.014 (1.0X) N/A 1.429 (8.4X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 16.546 (1.0X) N/A 1.245 (13.3X)

Regressed Latencies 🚩

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
MobileBertSquad\_fp32(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_sync(embedded\_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 452.874 (vs. 374.014, 21.08%↑) 451.174 4.913
DeepLabV3\_fp32(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 53.429 (vs. 49.226, 8.54%↑) 53.979 2.584
MobileBertSquad\_fp32(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][experimental-flags,no-dt] local\_sync(embedded\_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 758.776 (vs. 714.709, 6.17%↑) 764.421 15.542

[Top 3 out of 5 results showed]

Improved Latencies 🎉

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
MobileBertSquad\_int8(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 153.164 (vs. 162.861, 5.95%↓) 154.220 4.399

No improved or regressed compilation metrics 🏖️

For more information:

Source Workflow Run

github-actions[bot] avatar Mar 04 '24 23:03 github-actions[bot]

This is an NFC for e2e tests because all the pack ops are decomposed before vectorization.

hanhanW avatar Mar 05 '24 01:03 hanhanW