Improvements to inner_broadcast

Open pfultz2 opened this issue 9 months ago • 4 comments

This fixes a bug in inner_broadcast and extends it to handle more cases. In the simple case, where the inputs use the same broadcast and have the same dimensions, it just inserts the broadcast after the operation.
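To illustrate the simple case, here is a minimal standalone sketch. It does not use MIGraphX; `broadcast_rows`, `add`, `broadcast_then_add`, and `add_then_broadcast` are hypothetical helpers that only show why moving the broadcast after an elementwise operation gives the same result while the operation itself touches fewer elements.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a broadcast: repeat a length-n vector `rows`
// times, producing a row-major rows x n buffer.
std::vector<int> broadcast_rows(const std::vector<int>& v, std::size_t rows)
{
    std::vector<int> out;
    out.reserve(rows * v.size());
    for(std::size_t r = 0; r < rows; ++r)
        out.insert(out.end(), v.begin(), v.end());
    return out;
}

// Hypothetical elementwise add of two equally sized buffers.
std::vector<int> add(const std::vector<int>& a, const std::vector<int>& b)
{
    std::vector<int> out(a.size());
    for(std::size_t i = 0; i < a.size(); ++i)
        out[i] = a[i] + b[i];
    return out;
}

// Broadcast both inputs first, then add rows * n elements.
std::vector<int>
broadcast_then_add(const std::vector<int>& x, const std::vector<int>& y, std::size_t rows)
{
    return add(broadcast_rows(x, rows), broadcast_rows(y, rows));
}

// Add the n unbroadcast elements first, then broadcast the result once.
std::vector<int>
add_then_broadcast(const std::vector<int>& x, const std::vector<int>& y, std::size_t rows)
{
    return broadcast_rows(add(x, y), rows);
}
```

With `x = {1, 2, 3}`, `y = {10, 20, 30}` and `rows = 2`, both functions return `{11, 22, 33, 11, 22, 33}`.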

However, when the inputs have different dimensions and broadcasts, it finds which axes are broadcast and which are not, and inserts squeezes and unsqueezes as necessary. Previously we would just use squeeze with empty axes, which removes every dimension of size 1; now we explicitly specify which axes to remove.
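The difference between the two squeeze behaviours can be sketched on plain dimension vectors. This is an illustration only, not the code from this PR: `squeeze_all` and `squeeze_axes` are hypothetical helpers, and the example shape is assumed.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

using dims = std::vector<std::size_t>;

// Squeeze with empty axes: drop every dimension of size 1.
dims squeeze_all(const dims& lens)
{
    dims out;
    std::copy_if(lens.begin(), lens.end(), std::back_inserter(out), [](std::size_t d) {
        return d != 1;
    });
    return out;
}

// Squeeze with explicit axes: drop only the listed axes and keep every
// other dimension, including other dimensions of size 1.
// Assumes each listed axis is in range and has size 1.
dims squeeze_axes(const dims& lens, const dims& axes)
{
    dims out;
    for(std::size_t i = 0; i < lens.size(); ++i)
    {
        if(std::find(axes.begin(), axes.end(), i) == axes.end())
            out.push_back(lens[i]);
    }
    return out;
}
```

For a shape `{1, 1, 4}` where only axis 0 should go, `squeeze_all` returns `{4}` and silently drops axis 1 as well, while `squeeze_axes` with axes `{0}` returns `{1, 4}`.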

pfultz2 avatar Apr 25 '24 22:04 pfultz2

| Test | Batch | Rate new b5db84 | Rate old 06eef0 | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 2,789.74 | 2,790.67 | -0.03% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 6,207.99 | 6,206.60 | 0.02% | :white_check_mark: |
| torchvision-densenet121 | 32 | 2,075.90 | 2,093.23 | -0.83% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 3,610.04 | 3,618.84 | -0.24% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 1,597.44 | 1,596.91 | 0.03% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 2,558.49 | 2,556.16 | 0.09% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 716.26 | 716.53 | -0.04% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 678.02 | 678.05 | -0.00% | :white_check_mark: |
| slim-mobilenet | 64 | 5,815.86 | 5,820.72 | -0.08% | :white_check_mark: |
| slim-nasnetalarge | 64 | 154.24 | 154.28 | -0.03% | :white_check_mark: |
| slim-resnet50v2 | 64 | 2,578.85 | 2,580.94 | -0.08% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 969.61 | 969.84 | -0.02% | :white_check_mark: |
| bert-mrpc-tf | 1 | 408.13 | 412.65 | -1.09% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 395.18 | 392.80 | 0.61% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 369.26 | 371.12 | -0.50% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 604.97 | 600.81 | 0.69% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 386.64 | 384.63 | 0.52% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 323.01 | 327.21 | -1.28% | :white_check_mark: |
| onnx-taau-downsample | 1 | 306.95 | 306.71 | 0.08% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 28.54 | 28.56 | -0.09% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 47.15 | 47.18 | -0.06% | :white_check_mark: |
| agentmodel | 1 | 7,875.19 | 7,343.44 | 7.24% | :high_brightness: |
| unet_fp16 | 2 | 57.60 | 57.44 | 0.29% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 896.31 | 896.02 | 0.03% | :white_check_mark: |
| resnet50v1_int8 | 1 | 784.44 | 801.08 | -2.08% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 1,022.25 | 1,022.20 | 0.00% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 299.15 | 299.11 | 0.01% | :white_check_mark: |
| bert_large_fp16 | 1 | 158.70 | 156.32 | 1.52% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 1,834.13 | 1,831.35 | 0.15% | :white_check_mark: |
| yolov5s | 1 | 469.85 | 474.82 | -1.05% | :white_check_mark: |
| tinyllama | 1 | 33.01 | 33.01 | 0.01% | :white_check_mark: |
| vicuna-fastchat | 1 | 159.67 | 159.87 | -0.12% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 352.14 | 352.38 | -0.07% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 401.24 | 396.47 | 1.20% | :white_check_mark: |

Check results before merge :high_brightness:

migraphx-bot avatar Apr 26 '24 00:04 migraphx-bot


     :white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
     :white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
     :white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
     :white_check_mark: unet: PASSED: MIGraphX meets tolerance
     :white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
     :red_circle: bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
     :white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
     :white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
     :white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
     :white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance

migraphx-bot avatar Apr 26 '24 00:04 migraphx-bot

> This PR brings up a gripe I've had with broadcast vs. multibroadcast. It's quite unwieldy to have these two slightly different broadcasting instructions that sound like they would do the same thing but actually work quite differently. I think it's better to heavily prefer one over the other (either one). That way we have consistent shape broadcasting rules.

Yeah, I do think it's annoying to have two different operators. Originally, broadcast was used to broadcast a single dimension, and multibroadcast was used when broadcasting a tensor with multiple dimensions. I didn't realize until later that broadcast could be used for the same thing as multibroadcast; had I known that back then, I probably would have rejected the multibroadcast operator. I would love to get rid of it completely, but there are still a lot of places that use it and rely on it (like fuse_reduce, which only looks for multibroadcast), so it would be a big refactoring.
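As a rough illustration of how the two styles can describe the same result, the sketch below computes output strides under two conventions: placing the input dimensions at a given axis, and right-aligning them NumPy-style. This reflects my understanding of the two conventions rather than the operators' actual shape code; `axis_broadcast_strides` and `right_aligned_strides` are hypothetical names, and the operators' real checks and edge cases may differ.

```cpp
#include <cstddef>
#include <vector>

using dims = std::vector<std::size_t>;

// Axis-anchored convention (assumed): the input strides are copied into the
// output starting at `axis`; every other output dimension gets stride 0.
// Assumes axis + in_strides.size() <= out_rank.
dims axis_broadcast_strides(const dims& in_strides, std::size_t axis, std::size_t out_rank)
{
    dims out(out_rank, 0);
    for(std::size_t i = 0; i < in_strides.size(); ++i)
        out[axis + i] = in_strides[i];
    return out;
}

// Right-aligned convention (assumed, NumPy-style): input dimensions line up
// with the trailing output dimensions. A dimension keeps its stride when the
// sizes match and gets stride 0 when it is being expanded.
// Assumes in_lens.size() <= out_lens.size() and compatible sizes.
dims right_aligned_strides(const dims& in_lens, const dims& in_strides, const dims& out_lens)
{
    dims out(out_lens.size(), 0);
    const std::size_t offset = out_lens.size() - in_lens.size();
    for(std::size_t i = 0; i < in_lens.size(); ++i)
    {
        if(in_lens[i] == out_lens[offset + i])
            out[offset + i] = in_strides[i];
    }
    return out;
}
```

For an output of shape `{2, 3, 4}`, a length-3 input placed at axis 1 and a `{3, 1}` input that is right-aligned both yield strides `{0, 1, 0}`, which is the sense in which one style can stand in for the other.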

pfultz2 avatar Apr 30 '24 02:04 pfultz2

Codecov Report

Attention: Patch coverage is 98.00000%, with 2 lines in your changes missing coverage. Please review.

Project coverage is 91.81%. Comparing base (06eef05) to head (e82db90).

:exclamation: Current head e82db90 differs from pull request most recent head b5db848. Consider uploading reports for the commit b5db848 to get more accurate results

| Files | Patch % | Lines |
|---|---|---|
| src/simplify_algebra.cpp | 98.00% | 2 Missing :warning: |
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3002      +/-   ##
===========================================
+ Coverage    91.80%   91.81%   +0.01%     
===========================================
  Files          486      486              
  Lines        18867    18929      +62     
===========================================
+ Hits         17320    17379      +59     
- Misses        1547     1550       +3     

codecov[bot] avatar May 01 '24 00:05 codecov[bot]

Added tests for better coverage.

pfultz2 avatar May 10 '24 19:05 pfultz2