Better simplification of shape transformation operators such as reshape/transpose/broadcast
Currently, in migraphx, we only simplify shape transformations of the same operator, such as repeated transposes or reshapes. This PR, however, simplifies across reshape/transpose/broadcast, which produces a much simpler set of transformations even across our current unit tests.
It also canonicalizes many shape transformations so they are done with the same set of operators where possible, or in the way that best preserves broadcasting and layout: it will use squeeze and unsqueeze where possible to preserve layout, and it will try to end with a broadcast. So, starting from dimensions {2, 32, 1}, multibroadcast[out_lens={2, 32, 256}] -> reshape[dims={2, 32, 16, 16}] now becomes unsqueeze[axes={3}] -> multibroadcast[out_lens={2, 32, 16, 16}].
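For example, the before/after patterns above could be built with the migraphx C++ API as follows (a minimal sketch; the function names `before_pass` and `after_pass` are just for illustration):

```cpp
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>

// Builds the original pattern: a multibroadcast followed by a reshape.
migraphx::program before_pass()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {2, 32, 1}};
    auto x  = mm->add_parameter("x", s);
    auto mb = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 256}}}), x);
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 32, 16, 16}}}), mb);
    return p;
}

// Builds the canonical form: unsqueeze first, then end with a broadcast,
// which keeps the broadcast (stride-0) dimension intact.
migraphx::program after_pass()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {2, 32, 1}};
    auto x  = mm->add_parameter("x", s);
    auto us = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), x);
    mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 16, 16}}}), us);
    return p;
}
```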
These simplifications are all done with a data structure I called shape_transform_descriptor (perhaps this isn't the best name for it, so I am open to ideas for a better one). This class records the transformations being applied, simplifies them, and then generates the operators needed to produce the same overall transformation.
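As a rough sketch of the intended flow (the method names `create`, `simplify`, and `generate` here are my assumptions from the description above, not necessarily the real interface):

```cpp
// Hypothetical usage sketch -- names and signatures are assumptions
// based on the PR description, not the actual API.
#include <migraphx/shape_transform_descriptor.hpp> // assumed header location
#include <migraphx/operation.hpp>
#include <cstddef>
#include <vector>

std::vector<migraphx::operation>
simplify_shape_transforms(const std::vector<std::size_t>& input_dims,
                          const std::vector<migraphx::operation>& ops)
{
    // Record each reshape/transpose/broadcast applied to a tensor of input_dims.
    auto desc = migraphx::shape_transform_descriptor::create(input_dims, ops);
    // Fold the recorded transformations into a canonical, minimal form.
    desc.simplify();
    // Emit the operators (squeeze/unsqueeze/transpose/broadcast/...) that
    // realize the same overall transformation.
    return desc.generate();
}
```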
For the reshape operator: do we want it to do the operations that are not covered by the other operators?
What do you mean? reshape can only do a reshape; it can't do a transpose or a broadcast. For example, reshaping a {2, 3} tensor to {3, 2} keeps the elements in the same linear order, whereas a transpose reorders them.
The unit test failure will most likely be fixed by #3188.
There are some models that are failing in the torch benchmarks with this PR. I am saving off these mxr files to our NAS and I'll list them here:
/mnt/nas_share/migraphx/models/torch_benchmarks/paul_shape_transform_pr
- crossvit_9_240
- fbnetc_100
- mnasnet_100
- mobilenetv2_100
- mobilenetv3_large_100
- mobilevit_s
- spnasnet_100
- tf_efficientnet_b0
- tinynet_a
I only saw 2 unique errors, so most likely they are all failing due to the same bug(s).
These errors happen in simplify_algebra. They look like bugs there, not related to this PR, but this PR probably rewrites the graph in a way that exposes them. We should probably open separate tickets to fix those issues. These are the snippets I got that might help describe the errors and which class they happen in:
find_concat_op:
```
@509 = unsqueeze[axes={4},steps={}](@505) -> float_type, {128, 1, 1, 1, 1}, {1, 1, 1, 1, 1}
@511 = unsqueeze[axes={4},steps={}](@506) -> float_type, {1, 3, 1, 1, 1}, {3, 1, 1, 1, 1}
@513 = unsqueeze[axes={4},steps={}](@507) -> float_type, {1, 1, 224, 1, 1}, {224, 224, 1, 1, 1}
@515 = unsqueeze[axes={4},steps={}](@508) -> float_type, {1, 1, 1, 224, 1}, {224, 224, 224, 1, 1}
@510 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@509) -> float_type, {128, 3, 224, 224, 1}, {1, 0, 0, 0, 1}
@512 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@511) -> float_type, {128, 3, 224, 224, 1}, {0, 1, 0, 0, 1}
@514 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@513) -> float_type, {128, 3, 224, 224, 1}, {0, 0, 1, 0, 1}
@516 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@515) -> float_type, {128, 3, 224, 224, 1}, {0, 0, 0, 1, 1}
@517 = concat[axis=4](@510,@512,@514,@516) -> float_type, {128, 3, 224, 224, 4}, {602112, 200704, 896, 4, 1}
```
find_splits:
```
@760 = convolution[padding={0, 0, 0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](@759,@758) -> half_type, {128, 384, 28, 28}, {301056, 784, 28, 1}
@762 = convolution[padding={0, 0, 0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](@757,@287) -> half_type, {128, 192, 28, 28}, {784, 100352, 28, 1}
@763 = slice[axes={1},starts={0},ends={192}](@760) -> half_type, {128, 192, 28, 28}, {301056, 784, 28, 1}
@764 = add(@762,@763) -> half_type, {128, 192, 28, 28}, {784, 100352, 28, 1}
@814 = slice[axes={1},starts={192},ends={384}](@760) -> half_type, {128, 192, 28, 28}, {301056, 784, 28, 1}
@815 = convolution[padding={0, 0, 0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](@809,@282) -> half_type, {128, 192, 28, 28}, {150528, 784, 28, 1}
@816 = add(@814,@815) -> half_type, {128, 192, 28, 28}, {150528, 784, 28, 1}
```
Codecov Report
Attention: Patch coverage is 92.32456% with 35 lines in your changes missing coverage. Please review.
Project coverage is 92.17%. Comparing base (05b2ff4) to head (8836cde). Report is 173 commits behind head on develop.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/shape_transform_descriptor.cpp | 91.75% | 32 Missing :warning: |
| src/include/migraphx/algorithm.hpp | 90.00% | 3 Missing :warning: |
Additional details and impacted files
```diff
@@            Coverage Diff            @@
##           develop    #3104    +/-   ##
===========================================
- Coverage    92.26%   92.17%    -0.10%
===========================================
  Files          500      503        +3
  Lines        20057    20444      +387
===========================================
+ Hits         18506    18844      +338
- Misses        1551     1600       +49
```
Needs a follow-up from Shiv.
The relevant PRs for these bug fixes have been merged now; the requested-changes status can be removed.
@pfultz2 The CI failures need to be fixed.
@pfultz2 The Windows build is failing. Please take a look.
Windows CI is passing now.
| Test | Batch | Rate new 8836cd | Rate old 05b2ff | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,233.35 | 3,234.71 | -0.04% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,864.13 | 6,876.70 | -0.18% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,427.52 | 2,429.30 | -0.07% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 4,061.20 | 4,071.07 | -0.24% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 1,604.49 | 1,636.07 | -1.93% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 2,737.71 | 2,738.91 | -0.04% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 771.09 | 769.66 | 0.19% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 806.64 | 807.66 | -0.13% | :white_check_mark: |
| slim-mobilenet | 64 | 7,432.46 | 7,429.50 | 0.04% | :white_check_mark: |
| slim-nasnetalarge | 64 | 206.98 | 207.44 | -0.22% | :white_check_mark: |
| slim-resnet50v2 | 64 | 3,335.05 | 3,341.16 | -0.18% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 1,150.43 | 1,148.23 | 0.19% | :white_check_mark: |
| bert-mrpc-tf | 1 | 310.71 | 307.31 | 1.11% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 418.86 | 422.10 | -0.77% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 381.66 | 389.61 | -2.04% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 788.61 | 799.66 | -1.38% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 429.36 | 394.71 | 8.78% | :high_brightness: |
| cadene-resnext101_1 | 1 | 379.49 | 380.73 | -0.33% | :white_check_mark: |
| onnx-taau-downsample | 1 | 344.23 | 344.10 | 0.04% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 35.02 | 35.03 | -0.04% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 57.30 | 57.31 | -0.02% | :white_check_mark: |
| agentmodel | 1 | 9,623.53 | 8,014.30 | 20.08% | :high_brightness: |
| unet_fp16 | 2 | 57.81 | 57.84 | -0.06% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 974.08 | 1,000.25 | -2.62% | :white_check_mark: |
| resnet50v1_int8 | 1 | 929.88 | 933.76 | -0.42% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 1,142.28 | 1,141.87 | 0.04% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 350.36 | 351.72 | -0.39% | :white_check_mark: |
| bert_large_fp16 | 1 | 210.52 | 208.57 | 0.93% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 2,152.79 | 2,151.12 | 0.08% | :white_check_mark: |
| yolov5s | 1 | 502.43 | 507.58 | -1.01% | :white_check_mark: |
| tinyllama | 1 | 43.43 | 43.42 | 0.03% | :white_check_mark: |
| vicuna-fastchat | 1 | 168.21 | 176.91 | -4.92% | :red_circle: |
| whisper-tiny-encoder | 1 | 411.09 | 410.91 | 0.04% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 424.33 | 430.50 | -1.43% | :white_check_mark: |
This build is not recommended to merge :red_circle:
:red_circle:bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output