AMDMIGraphX

Better simplification of shape transformation operators such as reshape/transpose/broadcast

Open pfultz2 opened this issue 1 year ago • 9 comments

Currently, MIGraphX only simplifies repeated shape transformations of the same operator, such as consecutive transposes or consecutive reshapes.

This PR, however, simplifies across reshape/transpose/broadcast, producing a much simpler set of transformations even across our current unit tests.

It also canonicalizes many shape transformations so that, where possible, they are expressed with the same set of operators, in a form that best preserves broadcasting and layout. It uses squeeze and unsqueeze where possible to preserve layout, and it tries to end with a broadcast. So, starting from a tensor of dimensions {2, 32, 1}, the sequence multibroadcast[out_lens={2, 32, 256}] -> reshape[dims={2, 32, 16, 16}] now becomes unsqueeze[axes={3}] -> multibroadcast[out_lens={2, 32, 16, 16}].
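To make that example concrete, here is a minimal sketch of the before/after graphs built with MIGraphX's internal C++ API, in the style of the unit tests (the header paths and program construction are my assumptions, not code from this PR):

```cpp
#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/make_op.hpp>

// Before: broadcast the trailing 1 up to 256, then reshape it into 16x16.
migraphx::program before()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter(
        "x", migraphx::shape{migraphx::shape::float_type, {2, 32, 1}});
    auto mb = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 256}}}), x);
    mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 32, 16, 16}}}), mb);
    return p;
}

// After: equivalent canonical form; add a trailing axis, then broadcast once.
migraphx::program after()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter(
        "x", migraphx::shape{migraphx::shape::float_type, {2, 32, 1}});
    auto unsq = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {3}}}), x);
    mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", {2, 32, 16, 16}}}), unsq);
    return p;
}
```

The canonical form is preferable because the unsqueeze is a pure metadata change on the small input, and the single trailing broadcast keeps the output's broadcast strides explicit instead of hiding them behind a reshape.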

These simplifications are all done with a data structure that I called shape_transform_descriptor (perhaps this isn't the best name for it, so I'm open to ideas for a better one). This class records the transformations being done, applies simplifications, and then generates the operators needed to produce such a transformation.
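As a rough illustration of that description (this is a hypothetical interface sketch based only on the sentence above; the real class in src/shape_transform_descriptor.cpp may use different names and signatures):

```cpp
#include <cstddef>
#include <vector>
#include <migraphx/operation.hpp>

// Hypothetical sketch of the record -> simplify -> generate flow.
struct shape_transform_descriptor_sketch
{
    // Record a chain of shape-transforming ops (reshape/transpose/
    // broadcast/...) applied to an input with the given dimensions.
    bool record(const std::vector<std::size_t>& input_dims,
                const std::vector<migraphx::operation>& ops);

    // Fuse and canonicalize the recorded transformations.
    void simplify();

    // Emit a minimal operator sequence that performs the same
    // overall transformation.
    std::vector<migraphx::operation> generate() const;
};
```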

pfultz2 avatar May 18 '24 03:05 pfultz2

> For the reshape operator: do we want it to do the operations that are not covered by the other operators?

What do you mean? reshape can only do reshape; it can't do transpose or broadcast.
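A quick standalone worked example of the distinction, assuming row-major storage (plain C++, not MIGraphX code): reshape only reinterprets the dims of the flat data, while transpose actually reorders elements.

```cpp
#include <cstdio>

int main()
{
    // A 2x3 tensor stored row-major: [[0, 1, 2], [3, 4, 5]]
    int data[2][3] = {{0, 1, 2}, {3, 4, 5}};
    const int* flat = &data[0][0];

    // reshape to 3x2 keeps the flat order: [[0, 1], [2, 3], [4, 5]]
    std::printf("reshape:   ");
    for(int i = 0; i < 6; i++)
        std::printf("%d ", flat[i]);

    // transpose to 3x2 reorders elements: [[0, 3], [1, 4], [2, 5]]
    std::printf("\ntranspose: ");
    for(int j = 0; j < 3; j++)
        for(int i = 0; i < 2; i++)
            std::printf("%d ", data[i][j]);
    std::printf("\n");
}
```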

pfultz2 avatar May 18 '24 15:05 pfultz2

The unit test failure will most likely be fixed by #3188.

pfultz2 avatar Jun 15 '24 00:06 pfultz2

There are some models that are failing in torch benchmarks with this PR. I am saving off the MXR files on our NAS and I'll list them here: /mnt/nas_share/migraphx/models/torch_benchmarks/paul_shape_transform_pr

  • crossvit_9_240
  • fbnetc_100
  • mnasnet_100
  • mobilenetv2_100
  • mobilenetv3_large_100
  • mobilevit_s
  • spnasnet_100
  • tf_efficientnet_b0
  • tinynet_a

I only saw 2 unique errors, so most likely they are all failing due to the same bug(s).

shivadbhavsar avatar Jun 20 '24 19:06 shivadbhavsar

> There are some models that are failing in torch benchmarks with this PR. […] I only saw 2 unique errors, so most likely they are all failing due to the same bug(s).

These errors happen in simplify_algebra. They look like bugs there and not related to this PR, but this PR probably rewrites the graph in a way that exposes them. We should probably open separate tickets to fix those issues. These are the snippets I got that might help describe the errors and which class they happen in:

find_concat_op:
@509 = unsqueeze[axes={4},steps={}](@505) -> float_type, {128, 1, 1, 1, 1}, {1, 1, 1, 1, 1}
@511 = unsqueeze[axes={4},steps={}](@506) -> float_type, {1, 3, 1, 1, 1}, {3, 1, 1, 1, 1}
@513 = unsqueeze[axes={4},steps={}](@507) -> float_type, {1, 1, 224, 1, 1}, {224, 224, 1, 1, 1}
@515 = unsqueeze[axes={4},steps={}](@508) -> float_type, {1, 1, 1, 224, 1}, {224, 224, 224, 1, 1}
@510 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@509) -> float_type, {128, 3, 224, 224, 1}, {1, 0, 0, 0, 1}
@512 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@511) -> float_type, {128, 3, 224, 224, 1}, {0, 1, 0, 0, 1}
@514 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@513) -> float_type, {128, 3, 224, 224, 1}, {0, 0, 1, 0, 1}
@516 = multibroadcast[out_lens={128, 3, 224, 224, 1},out_dyn_dims={}](@515) -> float_type, {128, 3, 224, 224, 1}, {0, 0, 0, 1, 1}
@517 = concat[axis=4](@510,@512,@514,@516) -> float_type, {128, 3, 224, 224, 4}, {602112, 200704, 896, 4, 1}

find_splits:
@760 = convolution[padding={0, 0, 0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](@759,@758) -> half_type, {128, 384, 28, 28}, {301056, 784, 28, 1}
@762 = convolution[padding={0, 0, 0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](@757,@287) -> half_type, {128, 192, 28, 28}, {784, 100352, 28, 1}
@763 = slice[axes={1},starts={0},ends={192}](@760) -> half_type, {128, 192, 28, 28}, {301056, 784, 28, 1}
@764 = add(@762,@763) -> half_type, {128, 192, 28, 28}, {784, 100352, 28, 1}
@814 = slice[axes={1},starts={192},ends={384}](@760) -> half_type, {128, 192, 28, 28}, {301056, 784, 28, 1}
@815 = convolution[padding={0, 0, 0, 0},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](@809,@282) -> half_type, {128, 192, 28, 28}, {150528, 784, 28, 1}
@816 = add(@814,@815) -> half_type, {128, 192, 28, 28}, {150528, 784, 28, 1}

pfultz2 avatar Jun 20 '24 23:06 pfultz2

Codecov Report

Attention: Patch coverage is 92.32456% with 35 lines in your changes missing coverage. Please review.

Project coverage is 92.17%. Comparing base (05b2ff4) to head (8836cde). Report is 173 commits behind head on develop.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/shape_transform_descriptor.cpp | 91.75% | 32 Missing :warning: |
| src/include/migraphx/algorithm.hpp | 90.00% | 3 Missing :warning: |
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3104      +/-   ##
===========================================
- Coverage    92.26%   92.17%   -0.10%     
===========================================
  Files          500      503       +3     
  Lines        20057    20444     +387     
===========================================
+ Hits         18506    18844     +338     
- Misses        1551     1600      +49     

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.

codecov[bot] avatar Jun 22 '24 02:06 codecov[bot]

Needs a follow-up from Shiv

Relevant PRs for these bug fixes have been merged now; the requested-changes status can be removed.

shivadbhavsar avatar Jul 15 '24 16:07 shivadbhavsar

@pfultz2 needs to fix the CI failures

causten avatar Aug 01 '24 16:08 causten

@pfultz2 the Windows build is failing. Please take a look

umangyadav avatar Aug 09 '24 12:08 umangyadav

Windows CI is passing now.

pfultz2 avatar Aug 19 '24 15:08 pfultz2

| Test | Batch | Rate new (8836cd) | Rate old (05b2ff) | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,233.35 | 3,234.71 | -0.04% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,864.13 | 6,876.70 | -0.18% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,427.52 | 2,429.30 | -0.07% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 4,061.20 | 4,071.07 | -0.24% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 1,604.49 | 1,636.07 | -1.93% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 2,737.71 | 2,738.91 | -0.04% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 771.09 | 769.66 | 0.19% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 806.64 | 807.66 | -0.13% | :white_check_mark: |
| slim-mobilenet | 64 | 7,432.46 | 7,429.50 | 0.04% | :white_check_mark: |
| slim-nasnetalarge | 64 | 206.98 | 207.44 | -0.22% | :white_check_mark: |
| slim-resnet50v2 | 64 | 3,335.05 | 3,341.16 | -0.18% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 1,150.43 | 1,148.23 | 0.19% | :white_check_mark: |
| bert-mrpc-tf | 1 | 310.71 | 307.31 | 1.11% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 418.86 | 422.10 | -0.77% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 381.66 | 389.61 | -2.04% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 788.61 | 799.66 | -1.38% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 429.36 | 394.71 | 8.78% | :high_brightness: |
| cadene-resnext101_1 | 1 | 379.49 | 380.73 | -0.33% | :white_check_mark: |
| onnx-taau-downsample | 1 | 344.23 | 344.10 | 0.04% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 35.02 | 35.03 | -0.04% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 57.30 | 57.31 | -0.02% | :white_check_mark: |
| agentmodel | 1 | 9,623.53 | 8,014.30 | 20.08% | :high_brightness: |
| unet_fp16 | 2 | 57.81 | 57.84 | -0.06% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 974.08 | 1,000.25 | -2.62% | :white_check_mark: |
| resnet50v1_int8 | 1 | 929.88 | 933.76 | -0.42% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 1,142.28 | 1,141.87 | 0.04% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 350.36 | 351.72 | -0.39% | :white_check_mark: |
| bert_large_fp16 | 1 | 210.52 | 208.57 | 0.93% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 2,152.79 | 2,151.12 | 0.08% | :white_check_mark: |
| yolov5s | 1 | 502.43 | 507.58 | -1.01% | :white_check_mark: |
| tinyllama | 1 | 43.43 | 43.42 | 0.03% | :white_check_mark: |
| vicuna-fastchat | 1 | 168.21 | 176.91 | -4.92% | :red_circle: |
| whisper-tiny-encoder | 1 | 411.09 | 410.91 | 0.04% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 424.33 | 430.50 | -1.43% | :white_check_mark: |

This build is not recommended to merge :red_circle:

migraphx-bot avatar Aug 19 '24 18:08 migraphx-bot


:white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
:white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
:white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
:white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
:white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
:white_check_mark: unet: PASSED: MIGraphX meets tolerance
:white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
:white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
:red_circle: bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
:white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
:white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
:white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
:white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
:white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance

migraphx-bot avatar Aug 19 '24 18:08 migraphx-bot