iree
Add `LinalgFusionInterface` to support fusion for linalg_ext ops (added `scatter` and `reverse`)
LinalgFusionOpInterface allows for fusion of both Linalg and LinalgExt operations. The new interface provides access to methods essential for performing fusion, allowing existing fusion logic to be used with LinalgExt operations.
As noted in #17392, it probably makes sense to move this into the TilingInterface and make it a bit more abstract.
Changes
- LinalgFusionOpInterface: An interface that defines methods for fusion operations.
  - Inherits from `DestinationStyleOpInterface`
  - Implements methods to access indexing maps
- Implementation for Linalg Ops: The interface is implemented for standard Linalg operations by forwarding to preexisting methods (e.g. `getIndexingMaps()`). No changes to the ops themselves.
- Implementation for LinalgExt Ops: The interface is currently only implemented for `iree_linalg_ext.scatter`.
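
The forwarding idea described above can be illustrated with a small toy model. This is only a sketch in Python, not the MLIR/C++ code from this PR: `FusionOpInterface`, `ToyLinalgOp`, `ToyScatterOp`, and `result_map_is_permutation` are hypothetical names, an indexing map is modeled as a plain tuple of loop indices, and the identity maps returned by the toy scatter op are an illustrative assumption rather than what the real `iree_linalg_ext.scatter` implementation returns.

```python
from dataclasses import dataclass
from typing import List, Protocol, Tuple

# Toy stand-in for an affine indexing map: (0, 2) means (d0, d1, d2) -> (d0, d2).
IndexingMap = Tuple[int, ...]


class FusionOpInterface(Protocol):
    """Hypothetical analogue of the fusion interface: what fusion logic may query."""

    def get_num_loops(self) -> int: ...
    def get_indexing_maps(self) -> List[IndexingMap]: ...


@dataclass
class ToyLinalgOp:
    """Stands in for a Linalg op that already stores its indexing maps."""
    num_loops: int
    indexing_maps: List[IndexingMap]

    def get_num_loops(self) -> int:
        return self.num_loops

    def get_indexing_maps(self) -> List[IndexingMap]:
        # Forward to data the op already has; the op itself is unchanged.
        return list(self.indexing_maps)


@dataclass
class ToyScatterOp:
    """Stands in for a LinalgExt op with no stored indexing maps."""
    rank: int

    def get_num_loops(self) -> int:
        return self.rank

    def get_indexing_maps(self) -> List[IndexingMap]:
        # Assumption for illustration only: synthesize identity maps on demand.
        identity = tuple(range(self.rank))
        return [identity, identity]


def result_map_is_permutation(op: FusionOpInterface) -> bool:
    """Toy fusion precondition, written once against the interface."""
    result_map = op.get_indexing_maps()[-1]
    return sorted(result_map) == list(range(op.get_num_loops()))


transpose = ToyLinalgOp(num_loops=2, indexing_maps=[(1, 0), (0, 1)])
matmul = ToyLinalgOp(num_loops=3, indexing_maps=[(0, 2), (2, 1), (0, 1)])
scatter = ToyScatterOp(rank=2)

for name, op in [("transpose", transpose), ("matmul", matmul), ("scatter", scatter)]:
    print(name, result_map_is_permutation(op))
```

The point of the sketch is that `result_map_is_permutation` never checks which kind of op it receives, which mirrors how existing fusion logic can be reused once LinalgExt ops expose the same accessors.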
Abbreviated Benchmark Summary
@ commit d207b849ce48d3b3de1824127bb390f4c0ac19c6 (vs. base 3a2617f9a8ed59aa3a3e7152b5a4e994a5923226)
Data-Tiling Comparison Table
| Name | No-DT (baseline) | DT-Only | DT-UK |
|---|---|---|---|
| BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 223.921 (1.0X) | N/A | 105.841 (2.1X) |
| BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 740.111 (1.0X) | N/A | 223.961 (3.3X) |
| DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 6.968 (1.0X) | N/A | 8.517 (0.8X) |
| DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 31.831 (1.0X) | N/A | 29.797 (1.1X) |
| EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 35.757 (1.0X) | N/A | 34.437 (1.0X) |
| EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 274.706 (1.0X) | N/A | 237.007 (1.2X) |
| EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 5.840 (1.0X) | N/A | 5.040 (1.2X) |
| EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 26.747 (1.0X) | N/A | 13.343 (2.0X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 9.192 (1.0X) | N/A | 8.788 (1.0X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 70.212 (1.0X) | N/A | 38.118 (1.8X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 11.186 (1.0X) | N/A | 8.577 (1.3X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 89.554 (1.0X) | N/A | 39.434 (2.3X) |
| MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 12.269 (1.0X) | N/A | 13.059 (0.9X) |
| MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 79.612 (1.0X) | N/A | 56.535 (1.4X) |
| MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 34.183 (1.0X) | N/A | 62.187 (0.5X) |
| MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 179.760 (1.0X) | N/A | 185.172 (1.0X) |
| MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 34.952 (1.0X) | N/A | 62.200 (0.6X) |
| MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 181.685 (1.0X) | N/A | 188.238 (1.0X) |
| MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 66.348 (1.0X) | N/A | 62.848 (1.1X) |
| MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 481.880 (1.0X) | N/A | 207.299 (2.3X) |
| MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 4.956 (1.0X) | N/A | 4.567 (1.1X) |
| MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 26.181 (1.0X) | N/A | 17.719 (1.5X) |
| MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 3.772 (1.0X) | N/A | 4.924 (0.8X) |
| MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 11.565 (1.0X) | N/A | 11.899 (1.0X) |
| MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 5.902 (1.0X) | N/A | 5.399 (1.1X) |
| MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 21.566 (1.0X) | N/A | 11.874 (1.8X) |
| MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 2.927 (1.0X) | N/A | 2.809 (1.0X) |
| MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] | 2.761 (1.0X) | N/A | 2.680 (1.0X) |
| MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 8.485 (1.0X) | N/A | 9.877 (0.9X) |
| MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 34.246 (1.0X) | N/A | 31.394 (1.1X) |
| PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 0.822 (1.0X) | N/A | 0.608 (1.4X) |
| PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] | 0.748 (1.0X) | N/A | 0.542 (1.4X) |
| PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 4.170 (1.0X) | N/A | 5.151 (0.8X) |
| PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 17.473 (1.0X) | N/A | 19.120 (0.9X) |
| matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] | 7.588 (1.0X) | N/A | 7.584 (1.0X) |
| DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 48.612 (1.0X) | N/A | 43.706 (1.1X) |
| DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 50.537 (1.0X) | N/A | 44.633 (1.1X) |
| DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 30.547 (1.0X) | N/A | 28.080 (1.1X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 92.488 (1.0X) | N/A | 21.890 (4.2X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 93.557 (1.0X) | N/A | 22.747 (4.1X) |
| GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 52.230 (1.0X) | N/A | 22.307 (2.3X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 134.396 (1.0X) | N/A | 27.511 (4.9X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 123.593 (1.0X) | N/A | 29.843 (4.1X) |
| GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 74.549 (1.0X) | N/A | 26.972 (2.8X) |
| MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 698.205 (1.0X) | N/A | 350.241 (2.0X) |
| MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 707.551 (1.0X) | N/A | 362.992 (1.9X) |
| MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 400.715 (1.0X) | N/A | 218.400 (1.8X) |
| MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 1043.263 (1.0X) | N/A | 249.667 (4.2X) |
| MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 1044.339 (1.0X) | N/A | 245.687 (4.3X) |
| MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 543.181 (1.0X) | N/A | 145.617 (3.7X) |
| Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 2095.656 (1.0X) | N/A | 309.865 (6.8X) |
| Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 2096.607 (1.0X) | N/A | 308.809 (6.8X) |
| Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 1131.655 (1.0X) | N/A | 183.412 (6.2X) |
| matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] | 12.119 (1.0X) | N/A | 1.331 (9.1X) |
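
The `(N.NX)` factors in the table above appear to be the No-DT latency divided by the latency in that column, so values below 1.0X are slowdowns relative to the baseline. The following sketch reproduces a few entries under that assumption; the formula is inferred from the numbers, not taken from the benchmark tooling, and `format_speedup` is a hypothetical helper.

```python
def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """Ratio of baseline latency to candidate latency (higher is faster)."""
    return baseline_ms / candidate_ms


def format_speedup(baseline_ms: float, candidate_ms: float) -> str:
    """Format the ratio the way the table shows it, e.g. '2.1X'."""
    return f"{speedup(baseline_ms, candidate_ms):.1f}X"


# (No-DT, DT-UK) latencies copied from rows of the table.
rows = {
    "BertForMaskedLMTF, x86_64, 30-thread": (223.921, 105.841),
    "MobileBertSquad_fp16, x86_64, 15-thread": (34.183, 62.187),
    "matmul_256x256x2048, armv8.2-a, local_sync": (12.119, 1.331),
}

for name, (no_dt, dt_uk) in rows.items():
    print(f"{name}: {format_speedup(no_dt, dt_uk)}")
```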
Regressed Latencies 🚩
| Benchmark Name | Average Latency (ms) | Median Latency (ms) | Latency Standard Deviation (ms) |
|---|---|---|---|
| MobileBertSquad\_fp16(tflite) [arm-valhall-vulkan\_android31-vulkan\_spirv][experimental-flags,fuse-padding,max-concurrency,demote-f32-to-f16] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] | 101.400 (vs. 92.754, 9.32%↑) | 109.422 | 13.103 |
| MobileNetV1\_fp32(tflite) [x86\_64-cascadelake-linux\_gnu-llvm\_cpu][experimental-flags,no-dt] local\_task(embedded\_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] | 26.181 (vs. 24.548, 6.65%↑) | 26.120 | 0.346 |
Improved Latencies 🎉
| Benchmark Name | Average Latency (ms) | Median Latency (ms) | Latency Standard Deviation (ms) |
|---|---|---|---|
| GPT2\_117M\_TF\_1X4XI32(stablehlo) [armv8.2-a-generic-linux\_android29-llvm\_cpu][experimental-flags,no-dt] local\_task(embedded\_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 123.593 (vs. 139.576, 11.45%↓) | 123.622 | 0.303 |
| Vit\_int8(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 183.412 (vs. 195.651, 6.26%↓) | 184.903 | 4.155 |
| MobileBertSquad\_int8(tflite) [armv8.2-a-generic-linux\_android29-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] | 145.617 (vs. 154.850, 5.96%↓) | 146.545 | 3.939 |
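
The percentages in the regressed and improved tables are consistent with a relative change of the average latency against the base commit. A minimal sketch, assuming that formula (it is inferred from the reported numbers, and `format_change` is a hypothetical helper):

```python
def percent_change(current_ms: float, base_ms: float) -> float:
    """Relative change of the current latency versus the base, in percent."""
    return (current_ms - base_ms) / base_ms * 100.0


def format_change(current_ms: float, base_ms: float) -> str:
    """Format like the tables: magnitude plus an up or down arrow."""
    change = percent_change(current_ms, base_ms)
    arrow = "↑" if change > 0 else "↓"
    return f"{abs(change):.2f}%{arrow}"


# (current, base) average latencies copied from the tables.
print(format_change(101.400, 92.754))   # MobileBertSquad_fp16 on Vulkan
print(format_change(123.593, 139.576))  # GPT2_117M_TF_1X4XI32, no-dt, 1-thread
```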
No improved or regressed compilation metrics 🏖️
Also, @IanWood1, maybe use a better title for the PR.