Add `LinalgFusionInterface` to support fusion for linalg_ext ops (added `scatter` and `reverse`)

Open IanWood1 opened this issue 1 year ago • 2 comments

LinalgFusionOpInterface allows for fusion of both Linalg and LinalgExt operations. The new interface provides access to methods essential for performing fusion, allowing existing fusion logic to be used with LinalgExt operations.

As noted in #17392, it probably makes sense to move this into the TilingInterface and make it somewhat more abstract.

Changes

  • LinalgFusionOpInterface: An interface that defines the methods needed by the fusion logic.
    • Inherits from DestinationStyleOpInterface
    • Implements methods to access indexing maps
  • Implementation for Linalg Ops: The interface is implemented for standard Linalg operations by forwarding to preexisting methods (e.g., getIndexingMaps()). No changes to the ops themselves.
  • Implementation for LinalgExt Ops: The interface is currently implemented only for iree_linalg_ext.scatter.
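As a rough illustration of the shape of this change, here is a toy C++ model that does not use MLIR at all. Every name in it (`FusionOpInterface`, `GenericOpFusionModel`, `ScatterOpFusionModel`, `isFusableProducer`, the `IndexingMap` alias) is hypothetical and is not the PR's actual API. The identity maps returned for the scatter-like op are placeholders chosen for the sketch, not a statement of what `iree_linalg_ext.scatter` reports. The point is only that fusion logic written against the interface works for both kinds of op, while the Linalg-like op is adapted by forwarding to a method it already has.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Toy stand-in for an affine map: entry i names the loop dimension that
// indexes operand dimension i. Real MLIR AffineMaps are far richer.
using IndexingMap = std::vector<int>;

// Hypothetical analogue of the new interface: the only thing the fusion
// logic is allowed to ask of an op.
struct FusionOpInterface {
  virtual ~FusionOpInterface() = default;
  virtual std::vector<IndexingMap> getIndexingMaps() const = 0;
};

// A "Linalg-like" op that already exposes its indexing maps.
struct GenericOp {
  std::vector<IndexingMap> maps;
  const std::vector<IndexingMap> &getIndexingMaps() const { return maps; }
};

// Adapter that forwards to the preexisting method; GenericOp is unchanged.
struct GenericOpFusionModel : FusionOpInterface {
  explicit GenericOpFusionModel(const GenericOp &op) : op(op) {}
  std::vector<IndexingMap> getIndexingMaps() const override {
    return op.getIndexingMaps();
  }
  const GenericOp &op;
};

// A "LinalgExt-like" op with no indexing-map accessor of its own.
struct ScatterOp {
  unsigned rank;
};

// Adapter that synthesizes maps for the op. ASSUMPTION: two identity maps
// are used purely as placeholders for this sketch.
struct ScatterOpFusionModel : FusionOpInterface {
  explicit ScatterOpFusionModel(const ScatterOp &op) : op(op) {}
  std::vector<IndexingMap> getIndexingMaps() const override {
    IndexingMap identity(op.rank);
    std::iota(identity.begin(), identity.end(), 0);
    return {identity, identity};
  }
  const ScatterOp &op;
};

// True when `map` uses each loop dimension 0..n-1 exactly once.
bool isPermutation(const IndexingMap &map) {
  std::vector<bool> seen(map.size(), false);
  for (int dim : map) {
    if (dim < 0 || dim >= static_cast<int>(map.size()) || seen[dim])
      return false;
    seen[dim] = true;
  }
  return true;
}

// Fusion-side check written only against the interface. By the convention
// of this sketch, the last map describes the op's result.
bool isFusableProducer(const FusionOpInterface &op) {
  std::vector<IndexingMap> maps = op.getIndexingMaps();
  assert(!maps.empty() && "op must expose at least one indexing map");
  return isPermutation(maps.back());
}
```

In this model, wrapping a `GenericOp` or a `ScatterOp` in its adapter and passing it to `isFusableProducer` goes through the same code path, which is the property the interface is meant to provide.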

IanWood1 · May 16 '24 23:05

Abbreviated Benchmark Summary

@ commit d207b849ce48d3b3de1824127bb390f4c0ac19c6 (vs. base 3a2617f9a8ed59aa3a3e7152b5a4e994a5923226)

Data-Tiling Comparison Table

Name No-DT (baseline) DT-Only DT-UK
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 223.921 (1.0X) N/A 105.841 (2.1X)
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 740.111 (1.0X) N/A 223.961 (3.3X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 6.968 (1.0X) N/A 8.517 (0.8X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 31.831 (1.0X) N/A 29.797 (1.1X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 35.757 (1.0X) N/A 34.437 (1.0X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 274.706 (1.0X) N/A 237.007 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.840 (1.0X) N/A 5.040 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.747 (1.0X) N/A 13.343 (2.0X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.192 (1.0X) N/A 8.788 (1.0X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 70.212 (1.0X) N/A 38.118 (1.8X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.186 (1.0X) N/A 8.577 (1.3X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 89.554 (1.0X) N/A 39.434 (2.3X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 12.269 (1.0X) N/A 13.059 (0.9X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 79.612 (1.0X) N/A 56.535 (1.4X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.183 (1.0X) N/A 62.187 (0.5X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 179.760 (1.0X) N/A 185.172 (1.0X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.952 (1.0X) N/A 62.200 (0.6X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 181.685 (1.0X) N/A 188.238 (1.0X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 66.348 (1.0X) N/A 62.848 (1.1X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 481.880 (1.0X) N/A 207.299 (2.3X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.956 (1.0X) N/A 4.567 (1.1X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.181 (1.0X) N/A 17.719 (1.5X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.772 (1.0X) N/A 4.924 (0.8X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.565 (1.0X) N/A 11.899 (1.0X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.902 (1.0X) N/A 5.399 (1.1X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.566 (1.0X) N/A 11.874 (1.8X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 2.927 (1.0X) N/A 2.809 (1.0X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 2.761 (1.0X) N/A 2.680 (1.0X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.485 (1.0X) N/A 9.877 (0.9X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.246 (1.0X) N/A 31.394 (1.1X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.822 (1.0X) N/A 0.608 (1.4X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.748 (1.0X) N/A 0.542 (1.4X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.170 (1.0X) N/A 5.151 (0.8X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 17.473 (1.0X) N/A 19.120 (0.9X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.588 (1.0X) N/A 7.584 (1.0X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 48.612 (1.0X) N/A 43.706 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 50.537 (1.0X) N/A 44.633 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 30.547 (1.0X) N/A 28.080 (1.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 92.488 (1.0X) N/A 21.890 (4.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 93.557 (1.0X) N/A 22.747 (4.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 52.230 (1.0X) N/A 22.307 (2.3X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 134.396 (1.0X) N/A 27.511 (4.9X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 123.593 (1.0X) N/A 29.843 (4.1X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 74.549 (1.0X) N/A 26.972 (2.8X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 698.205 (1.0X) N/A 350.241 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 707.551 (1.0X) N/A 362.992 (1.9X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 400.715 (1.0X) N/A 218.400 (1.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 1043.263 (1.0X) N/A 249.667 (4.2X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1044.339 (1.0X) N/A 245.687 (4.3X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 543.181 (1.0X) N/A 145.617 (3.7X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 2095.656 (1.0X) N/A 309.865 (6.8X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 2096.607 (1.0X) N/A 308.809 (6.8X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1131.655 (1.0X) N/A 183.412 (6.2X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 12.119 (1.0X) N/A 1.331 (9.1X)

Regressed Latencies 🚩

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
MobileBertSquad_fp16(tflite) [arm-valhall-vulkan_android31-vulkan_spirv][experimental-flags,fuse-padding,max-concurrency,demote-f32-to-f16] vulkan(none)[full-inference,default-flags] with default @ pixel-6-pro[gpu] 101.400 (vs. 92.754, 9.32%↑) 109.422 13.103
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu][experimental-flags,no-dt] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.181 (vs. 24.548, 6.65%↑) 26.120 0.346

Improved Latencies 🎉

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu][experimental-flags,no-dt] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 123.593 (vs. 139.576, 11.45%↓) 123.622 0.303
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu][default-flags,dt-uk] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 183.412 (vs. 195.651, 6.26%↓) 184.903 4.155
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu][default-flags,dt-uk] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 145.617 (vs. 154.850, 5.96%↓) 146.545 3.939

No improved or regressed compilation metrics 🏖️

For more information:

Source Workflow Run

github-actions[bot] · May 17 '24 00:05

Also @IanWood1 maybe have a better title for the PR.

MaheshRavishankar · May 23 '24 16:05