Prevent collapsing batch dims in dot ops with constants
This simplifies many reshape -> dot -> reshape patterns that are not handled by the find_reshape_reshape_dot pass (i.e., in GEMMs where one input is a constant).
This also simplifies the reshape found in #2736.
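As a rough illustration of why this rewrite is safe, the sketch below checks that collapsing the batch dimension before a dot with a constant gives the same result as keeping the batch dimension and applying the constant to every batch slice. This is plain Python on nested lists, not MIGraphX code; the helper names (`matmul`, `collapsed_dot`, `batched_dot`) and the shapes are assumptions made up for the example.

```python
def matmul(a, b):
    """Plain 2-D matrix product on nested lists: [R, K] x [K, N] -> [R, N]."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def collapsed_dot(x, w):
    """reshape [B, M, K] -> [B*M, K], dot with constant w, reshape to [B, M, N]."""
    batch, rows = len(x), len(x[0])
    flat = [row for mat in x for row in mat]  # collapse the batch dim
    out = matmul(flat, w)
    return [out[b * rows:(b + 1) * rows] for b in range(batch)]  # restore it


def batched_dot(x, w):
    """Keep the batch dim and multiply every [M, K] slice by the same constant w."""
    return [matmul(mat, w) for mat in x]


# Hypothetical shapes: B=3, M=2, K=2, N=3. Integers keep the comparison exact.
x = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
w = [[1, 0, 2], [0, 1, 3]]

assert collapsed_dot(x, w) == batched_dot(x, w)
```

Because both forms agree, the surrounding reshapes carry no information of their own when the second input is a constant, which is what lets them be dropped.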
Codecov Report
Attention: Patch coverage is 97.72727% with 1 line in your changes missing coverage. Please review.
Project coverage is 91.93%. Comparing base (30cab64) to head (92b2246). Report is 150 commits behind head on develop.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/simplify_reshapes.cpp | 97.72% | 1 Missing :warning: |
Additional details and impacted files
| Coverage Diff | develop | #2823 | +/- |
|---|---|---|---|
| Coverage | 91.92% | 91.93% | |
| Files | 489 | 489 | |
| Lines | 19275 | 19301 | +26 |
| Hits | 17719 | 17744 | +25 |
| Misses | 1556 | 1557 | +1 |
SDXL perf results for reference:
Torch-MIGraphX (end to end): Before PR: 2850 ms, With PR: 2801 ms
ONNX Unet (4x attn trim): Before PR: 5.54 ms, With PR: 5.52 ms
As expected, it doesn't affect the ONNX version much because there is an extra convert in the middle. Once the convert is handled, the time drops to 5.47 ms.
| Test | Batch | Rate new 048bcd | Rate old dc028d | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 1,703.96 | 1,489.70 | 14.38% | :high_brightness: |
| torchvision-resnet50_fp16 | 64 | 3,796.06 | 1,346.10 | 182.00% | :high_brightness: |
| torchvision-densenet121 | 32 | 1,445.65 | 1,440.50 | 0.36% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 2,424.59 | 2,416.72 | 0.33% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 878.68 | 881.27 | -0.29% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 1,408.15 | 1,406.66 | 0.11% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 406.42 | 404.25 | 0.54% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 411.54 | 410.23 | 0.32% | :white_check_mark: |
| slim-mobilenet | 64 | 3,805.08 | 3,794.28 | 0.28% | :white_check_mark: |
| slim-nasnetalarge | 64 | 96.56 | 94.95 | 1.69% | :white_check_mark: |
| slim-resnet50v2 | 64 | 1,643.38 | 1,620.87 | 1.39% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 586.19 | 591.10 | -0.83% | :white_check_mark: |
| bert-mrpc-tf | 1 | 288.53 | 289.30 | -0.27% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 336.52 | 378.53 | -11.10% | :red_circle: |
| pytorch-examples-wlang-lstm | 1 | 303.46 | 266.28 | 13.96% | :high_brightness: |
| torchvision-resnet50_1 | 1 | 440.60 | 369.37 | 19.28% | :high_brightness: |
| cadene-dpn92_1 | 1 | 244.39 | 233.66 | 4.59% | :high_brightness: |
| cadene-resnext101_1 | 1 | 187.18 | 189.16 | -1.04% | :white_check_mark: |
| onnx-taau-downsample | 1 | 203.30 | 183.13 | 11.02% | :high_brightness: |
| dlrm-criteoterabyte | 1 | 22.19 | 21.99 | 0.92% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 41.47 | 41.43 | 0.10% | :white_check_mark: |
| agentmodel | 1 | 6,060.17 | 6,337.70 | -4.38% | :red_circle: |
| unet_fp16 | 2 | 33.34 | 33.63 | -0.85% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 566.17 | 521.53 | 8.56% | :high_brightness: |
| resnet50v1_int8 | 1 | 462.93 | 452.53 | 2.30% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 617.49 | 620.67 | -0.51% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 192.68 | 193.85 | -0.61% | :white_check_mark: |
| bert_large_fp16 | 1 | 103.66 | 103.88 | -0.21% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 1,150.51 | 1,187.83 | -3.14% | :red_circle: |
| yolov5s | 1 | 297.67 | 297.39 | 0.09% | :white_check_mark: |
| tinyllama | 1 | 23.21 | 23.34 | -0.53% | :white_check_mark: |
| vicuna-fastchat | 1 | 133.22 | 132.19 | 0.78% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 240.05 | 240.52 | -0.19% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 244.55 | 245.42 | -0.35% | :white_check_mark: |
This build is not recommended to merge :red_circle:
:red_circle:bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output