
[CPU] Enable vector flattening for vectors narrower than the vector register

Open hanhanW opened this issue 1 year ago • 2 comments

Different transformations, such as pack/unpack vectorization and mmt4d vectorization using i8mm, can produce multi-dimensional vector operations whose contiguous (innermost) dimension does not fully utilize the vector register. This PR enables a transformation that flattens those n-D vectors so that the contiguous dimension can fully utilize the vector register.
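To make the register-utilization argument concrete, here is a minimal shape-level sketch. It assumes a contiguous, row-major vector with no padding (which is what makes collapsing legal) and a fixed register width; the helper names are hypothetical and this is not the IREE/MLIR pass itself. For simplicity it collapses every dimension, whereas a real pattern may collapse only the trailing contiguous ones.

```python
from math import prod


def flatten_if_narrow(shape, elem_bits, reg_bits):
    """Hypothetical helper: collapse an n-D vector shape to 1-D when its
    innermost dimension occupies fewer bits than one vector register.

    Assumes the vector is contiguous in memory, so the collapse is legal.
    """
    if len(shape) <= 1:
        return list(shape)
    inner_bits = shape[-1] * elem_bits
    if inner_bits >= reg_bits:
        return list(shape)  # innermost dim already fills a register
    return [prod(shape)]  # one long contiguous 1-D vector


def register_utilization(shape, elem_bits, reg_bits):
    """Fraction of register bits used when the innermost dimension is
    mapped onto whole registers (ceil division for the register count)."""
    inner_bits = shape[-1] * elem_bits
    regs_per_row = -(-inner_bits // reg_bits)
    return inner_bits / (regs_per_row * reg_bits)


# vector<4x8xi8> on an assumed 128-bit register: each row is only 64 bits.
print(register_utilization([4, 8], 8, 128))           # half of each register
flat = flatten_if_narrow([4, 8], 8, 128)               # collapsed to [32]
print(flat, register_utilization(flat, 8, 128))        # registers fully used
```

Flattening trades the n-D structure for a single contiguous dimension; the payoff is that every load, store, or elementwise op on the flattened vector uses full-width registers instead of half-empty ones.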

Since dropping unit dimensions from vectors is a special case of vector flattening, and the vector flattening patterns already include the drop-unit-dimension patterns, I'm adding vector flattening to the LLVMCPUDropVectorUnitDims pass and renaming the pass to reflect its more general purpose.
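The "special case" relationship can be checked at the shape level with a small sketch. The functions below are hypothetical illustrations, not the actual MLIR patterns: dropping a unit dimension is a collapse of that dimension into a neighbor, so whenever at most one dimension is non-unit, dropping unit dims and fully flattening yield the same 1-D shape, while flattening goes further for shapes with several non-unit dims.

```python
from math import prod


def drop_unit_dims(shape):
    """Remove all size-1 dims (keep a single [1] if every dim is unit)."""
    kept = [d for d in shape if d != 1]
    return kept if kept else [1]


def flatten(shape):
    """Collapse a contiguous n-D vector shape into a 1-D shape."""
    return [prod(shape)]


def coincide(shape):
    """True when dropping unit dims already produces the flattened shape."""
    return drop_unit_dims(shape) == flatten(shape)


print(drop_unit_dims([1, 1, 8]), flatten([1, 1, 8]))  # both give [8]
print(drop_unit_dims([1, 4, 8]), flatten([1, 4, 8]))  # [4, 8] vs [32]
```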

Co-authored-by: Diego Caballero [email protected]

hanhanW, May 30 '24

Abbreviated Benchmark Summary

@ commit 1c6290be779aadf4b3ff2336b15d4d80d588d237 (vs. base 6c45befbf9efa47e52471e612fedded32f05f0ee)

Data-Tiling Comparison Table

Name No-DT (baseline) DT-Only DT-UK
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 748.086 (1.0X) N/A 224.766 (3.3X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 7.016 (1.0X) N/A 8.538 (0.8X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 36.061 (1.0X) N/A 34.365 (1.0X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.878 (1.0X) N/A 5.142 (1.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.237 (1.0X) N/A 8.551 (1.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.086 (1.0X) N/A 8.929 (1.2X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.972 (1.0X) N/A 13.977 (0.9X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 33.675 (1.0X) N/A 62.332 (0.5X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.565 (1.0X) N/A 62.467 (0.6X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 69.758 (1.0X) N/A 66.099 (1.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.832 (1.0X) N/A 4.617 (1.0X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.812 (1.0X) N/A 4.939 (0.8X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.841 (1.0X) N/A 5.513 (1.1X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.210 (1.0X) N/A 3.143 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.501 (1.0X) N/A 9.941 (0.9X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.774 (1.0X) N/A 0.610 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.241 (1.0X) N/A 5.309 (0.8X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.573 (1.0X) N/A 7.568 (1.0X)
matmul_256x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 6.658 (1.0X) N/A 1.805 (3.7X)
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 231.661 (1.0X) N/A 108.879 (2.1X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 32.302 (1.0X) N/A 30.196 (1.1X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 277.204 (1.0X) N/A 231.528 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 27.188 (1.0X) N/A 13.977 (1.9X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 70.402 (1.0X) N/A 40.286 (1.7X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 90.083 (1.0X) N/A 42.126 (2.1X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 79.129 (1.0X) N/A 57.323 (1.4X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 178.728 (1.0X) N/A 186.628 (1.0X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 183.208 (1.0X) N/A 191.237 (1.0X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 525.332 (1.0X) N/A 232.037 (2.3X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 25.179 (1.0X) N/A 18.278 (1.4X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.880 (1.0X) N/A 11.714 (1.0X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.439 (1.0X) N/A 12.181 (1.8X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 3.099 (1.0X) N/A 3.028 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.978 (1.0X) N/A 31.804 (1.1X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.703 (1.0X) N/A 0.544 (1.3X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 18.243 (1.0X) N/A 19.710 (0.9X)
matmul_1x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.054 (1.0X) N/A 0.054 (1.0X)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.043 (1.0X) N/A 0.021 (2.0X)
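The bracketed factors in the DT-UK column can be reproduced from the raw numbers. A minimal sketch, assuming the reported values are latencies (lower is better) and that the factor is baseline divided by DT-UK, which matches the rows checked below:

```python
def speedup(baseline, candidate):
    """Speedup of candidate vs. baseline for lower-is-better metrics.
    Values above 1.0 mean the candidate is faster."""
    return baseline / candidate


# (No-DT baseline, DT-UK) pairs copied from the table above.
rows = {
    "BertLargeTF 30-thread": (748.086, 224.766),
    "DeepLabV3_fp32 8-thread": (7.016, 8.538),
    "MobileBertSquad_fp16 15-thread": (33.675, 62.332),
    "matmul_256x256x2048_i8_i8_i32": (6.658, 1.805),
}
for name, (no_dt, dt_uk) in rows.items():
    print(f"{name}: {speedup(no_dt, dt_uk):.1f}X")
```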

No improved or regressed benchmarks 🏖️

Regressed Compilation Times 🚩

Benchmark Name Compilation Time (ms)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][experimental-flags,dt-only,compile-stats] 150683 (vs. 53532, 181.48%↑)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats] 144332 (vs. 62643, 130.40%↑)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats] 43037 (vs. 21294, 102.11%↑)
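The percentages in these regression tables are consistent with a plain relative change against the base commit. A minimal sketch of that arithmetic (the function name is hypothetical); the same formula also reproduces the ↓ percentages in the dispatch-size tables, where the result is negative:

```python
def percent_change(new, old):
    """Relative change of new vs. old in percent (positive = increase)."""
    return (new - old) / old * 100.0


# Compilation times (ms): this PR vs. base.
print(f"{percent_change(150683, 53532):.2f}%")  # MobileNetV2_int8, dt-only
print(f"{percent_change(43037, 21294):.2f}%")   # PersonDetect_int8, dt-uk
# Dispatch size (bytes) on riscv_64, an improvement:
print(f"{percent_change(2701584, 3543232):.2f}%")
```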

Regressed Total Dispatch Sizes 🚩

Benchmark Name Total Dispatch Size (bytes)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats] 740544 (vs. 260256, 184.54%↑)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats] 9389168 (vs. 3413872, 175.03%↑)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][experimental-flags,dt-only,compile-stats] 9393328 (vs. 3423776, 174.36%↑)

[Top 3 out of 21 results shown]

Improved Total Dispatch Sizes 🎉

Benchmark Name Total Dispatch Size (bytes)
MobileBertSquad_int8(tflite) [riscv_64-generic-linux_gnu-llvm_cpu][default-flags,compile-stats] 2701584 (vs. 3543232, 23.75%↓)
matmul_1x256x2048_i8_i8_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu][experimental-flags,dt-only,compile-stats] 2624 (vs. 2976, 11.83%↓)
MobileBertSquad_int8(tflite) [riscv_32-generic-linux_gnu-llvm_cpu][default-flags,compile-stats] 3387988 (vs. 3828820, 11.51%↓)

[Top 3 out of 6 results shown]

Regressed Total Artifact Sizes 🚩

Benchmark Name Total Artifact Size (bytes)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats] 525573 (vs. 361861, 45.24%↑)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][experimental-flags,dt-only,compile-stats] 526725 (vs. 368069, 43.10%↑)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][default-flags,dt-uk,compile-stats] 35794181 (vs. 29818885, 20.04%↑)

[Top 3 out of 10 results shown]


github-actions[bot], May 31 '24

We are generating bad IR (19K ops!) for the dispatch, so I'm not going to file an LLVM issue. I identified at least two underlying issues (https://github.com/iree-org/iree/issues/17593 and https://github.com/iree-org/iree/issues/17594); let's revisit this once they are implemented.

hanhanW, Jun 06 '24

(closing as stale)

benvanik, Apr 30 '25