
Allow multiple outputs for the MLIR + Pointwise fusions

Open umangyadav opened this issue 1 year ago • 10 comments

At present, when doing the MLIR + pointwise fusion, the pass does not check whether the "conv/gemm" instruction has only one use.

Therefore, in a contrived case, this can lead to a pessimization where the "dot/conv" is recomputed.

This PR fixes that.

Currently MLIR compilation is broken for multiple outputs; the fix for that will come from rocMLIR. This PR would not need to change after the rocMLIR fix. https://github.com/ROCm/rocMLIR-internal/issues/1546
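
For illustration, here is a minimal C++ sketch of the idea (function names are hypothetical, not the actual fuse_mlir code); MIGraphX's `instruction::outputs()` returns the readers of an instruction:

```cpp
// Illustrative sketch only; not the actual fuse_mlir implementation.
#include <migraphx/instruction.hpp>

using migraphx::instruction_ref;

// Before: fusing a pointwise op into a preceding gemm/conv is only safe
// without multiple outputs when the gemm/conv has a single reader;
// otherwise the gemm/conv result gets recomputed for the other readers.
inline bool has_single_use(instruction_ref gemm_or_conv)
{
    return gemm_or_conv->outputs().size() == 1;
}

// After: when the gemm/conv has additional readers, emit its result as an
// extra output of the fused module instead of recomputing it.
inline bool needs_extra_output(instruction_ref gemm_or_conv)
{
    return gemm_or_conv->outputs().size() > 1;
}
```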

umangyadav avatar Jul 23 '24 15:07 umangyadav

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.26%. Comparing base (82e534c) to head (d16de2c). Report is 153 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3299   +/-   ##
========================================
  Coverage    92.26%   92.26%           
========================================
  Files          499      499           
  Lines        20020    20020           
========================================
  Hits         18471    18471           
  Misses        1549     1549           


codecov[bot] avatar Jul 23 '24 16:07 codecov[bot]

Will this handle cases like this?

┌─┐   
│a│   
└┬┘   
┌▽──┐ 
│b  │ 
└┬─┬┘ 
 │┌▽┐ 
 ││c│ 
 │└┬┘ 
 │┌▽─┐
 ││d │
 │└┬─┘
┌▽─▽┐ 
│e  │ 
└───┘ 

Especially if c or d is not fusible with mlir.

pfultz2 avatar Jul 23 '24 19:07 pfultz2

> Will this handle cases like this?
>
> ┌─┐
> │a│
> └┬┘
> ┌▽──┐
> │b  │
> └┬─┬┘
>  │┌▽┐
>  ││c│
>  │└┬┘
>  │┌▽─┐
>  ││d │
>  │└┬─┘
> ┌▽─▽┐
> │e  │
> └───┘
>
> Especially if c or d is not fusible with mlir.

I don't think that case will be handled well, either with the existing logic on the develop branch or with this PR.

It is implicitly assumed that only ONE of the inputs to the pointwise op is produced by the GEMM/conv being fused. The rest must be produced by other ops.
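
For concreteness, a hedged sketch of what that assumption amounts to (illustrative only; the pass does not currently assert this):

```cpp
// Illustrative sketch; fuse_mlir does not currently check this.
#include <algorithm>
#include <migraphx/instruction.hpp>

using migraphx::instruction_ref;

// The matchers implicitly assume exactly one input of the pointwise op
// comes from the gemm/conv being fused; the rest come from other ops.
inline bool single_fused_input(instruction_ref pointwise,
                               instruction_ref gemm_or_conv)
{
    const auto& ins = pointwise->inputs();
    return std::count(ins.begin(), ins.end(), gemm_or_conv) == 1;
}
```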

umangyadav avatar Jul 23 '24 19:07 umangyadav

> I don't think that case will be handled well, either with the existing logic on the develop branch or with this PR.
>
> It is implicitly assumed that only ONE of the inputs to the pointwise op is produced by the GEMM/conv being fused. The rest must be produced by other ops.

What if b is a convolution and c and e are pointwise ops, but d is not (like pooling or something else)? Currently we don't check whether b is used once, so it would create two convolutions.

However, as I understand it, this PR would try to create multiple outputs instead. So would it fuse both c and e with the convolution? Or would it just fuse c and output the convolution to be used later?

pfultz2 avatar Jul 24 '24 15:07 pfultz2

> So would it fuse both c and e with the convolution? Or would it just fuse c and output the convolution to be used later?

Following the same example, where b is a convolution, c and e are pointwise, but d is a pooling:

Whether b + e or b + c is matched first depends on how the graph is traversed.

If b + e is matched first, it would produce an incorrect result or crash. That case is not handled, because, as I mentioned earlier, it is implicitly assumed that only one of the inputs to the pointwise op is produced by the convolution; the rest must be produced by other ops.

If b + c is matched and fused, it would work as expected: b would become one of the multiple outputs and would later be consumed by e. This PR handles this case well.

This PR does not assert or handle the implicit assumption. That case needs to be handled in general for many other matchers inside the fuse_mlir pass.
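
To make the example concrete, here is a hedged sketch of that graph in MIGraphX's C++ API (shapes and attribute values are illustrative):

```cpp
// b = convolution with two readers (c and e); c = pointwise; d = pooling
// (not MLIR-fusible here); e = pointwise reading both b and d.
#include <migraphx/make_op.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/program.hpp>

migraphx::program build_example()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {1, 8, 16, 16}};
    migraphx::shape ws{migraphx::shape::float_type, {8, 8, 3, 3}};

    auto a = mm->add_parameter("a", s);
    auto w = mm->add_parameter("w", ws);
    auto b = mm->add_instruction(
        migraphx::make_op("convolution", {{"padding", {1, 1}}}), a, w);
    auto c = mm->add_instruction(migraphx::make_op("relu"), b);
    auto d = mm->add_instruction(
        migraphx::make_op("pooling",
                          {{"mode", migraphx::op::pooling_mode::max},
                           {"padding", {1, 1}},
                           {"stride", {1, 1}},
                           {"lengths", {3, 3}}}),
        c);
    auto e = mm->add_instruction(migraphx::make_op("add"), b, d);
    mm->add_return({e});
    return p;
}
// With this PR, fusing b + c keeps b as an extra output of the fused
// module, so e can still read it without recomputing the convolution.
```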

umangyadav avatar Jul 25 '24 15:07 umangyadav

> If b + c is matched and fused, it would work as expected: b would become one of the multiple outputs and would later be consumed by e. This PR handles this case well.

Are you able to add a test case for this scenario so we don't break this in the future?

pfultz2 avatar Jul 25 '24 16:07 pfultz2

> Are you able to add a test case for this scenario so we don't break this in the future?

Both the unit tests and the verify test that I added are of this type.
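
For reference, the shape of such a test (a hedged sketch, not the actual test from the PR) could look like:

```cpp
// A dot whose result is consumed both by a fusible pointwise op and by a
// later use; with this PR the fused module returns the dot as an extra
// output instead of recomputing it. Illustrative sketch only.
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>

migraphx::program dot_multi_use()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::float_type, {4, 4}};

    auto x = mm->add_parameter("x", s);
    auto y = mm->add_parameter("y", s);
    auto z = mm->add_parameter("z", s);

    auto dot = mm->add_instruction(migraphx::make_op("dot"), x, y);
    auto add = mm->add_instruction(migraphx::make_op("add"), dot, z);
    // Both the dot and the fused pointwise result are live at the return.
    mm->add_return({dot, add});
    return p;
}
```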

umangyadav avatar Jul 29 '24 12:07 umangyadav

The verify test is failing here; the fix for that is in #3316.

umangyadav avatar Jul 29 '24 13:07 umangyadav

| Test | Batch | Rate new (7e83db) | Rate old (403ee8) | Diff | Compare |
| --- | --- | --- | --- | --- | --- |
| torchvision-resnet50 | 64 | 1,758.38 | 1,766.16 | -0.44% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 4,169.25 | 4,181.07 | -0.28% | :white_check_mark: |
| torchvision-densenet121 | 32 | 1,469.43 | 1,477.73 | -0.56% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 2,548.56 | 2,555.89 | -0.29% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 888.90 | 892.55 | -0.41% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 1,489.52 | 1,493.81 | -0.29% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 412.53 | 413.99 | -0.35% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 422.67 | 423.86 | -0.28% | :white_check_mark: |
| slim-mobilenet | 64 | 4,045.38 | 4,062.10 | -0.41% | :white_check_mark: |
| slim-nasnetalarge | 64 | 113.59 | 101.27 | 12.17% | :high_brightness: |
| slim-resnet50v2 | 64 | 1,887.47 | 1,686.40 | 11.92% | :high_brightness: |
| bert-mrpc-onnx | 8 | 612.29 | 614.96 | -0.43% | :white_check_mark: |
| bert-mrpc-tf | 1 | 172.48 | 172.13 | 0.20% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 323.47 | 322.73 | 0.23% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 290.33 | 294.86 | -1.53% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 475.05 | 475.38 | -0.07% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 249.39 | 249.88 | -0.19% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 205.14 | 206.13 | -0.48% | :white_check_mark: |
| onnx-taau-downsample | 1 | 205.57 | 206.21 | -0.31% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 22.98 | 23.04 | -0.29% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 43.66 | 43.78 | -0.29% | :white_check_mark: |
| agentmodel | 1 | 7,567.39 | 6,274.16 | 20.61% | :high_brightness: |
| unet_fp16 | 2 | 34.21 | 34.27 | -0.19% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 570.88 | 595.26 | -4.10% | :red_circle: |
| resnet50v1_int8 | 1 | 589.42 | 576.83 | 2.18% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 644.06 | 647.36 | -0.51% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 197.95 | 198.89 | -0.47% | :white_check_mark: |
| bert_large_fp16 | 1 | 116.92 | 117.05 | -0.11% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 1,219.21 | 1,225.90 | -0.55% | :white_check_mark: |
| yolov5s | 1 | nan | 301.73 | nan% | :x: |
| tinyllama | 1 | 23.23 | 23.32 | -0.37% | :white_check_mark: |
| vicuna-fastchat | 1 | 132.02 | 133.38 | -1.02% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 243.24 | 244.57 | -0.54% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 255.76 | 256.29 | -0.21% | :white_check_mark: |

This build is not recommended to merge :red_circle:

migraphx-bot avatar Jul 31 '24 15:07 migraphx-bot


:white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
:white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
:white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
:white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
:white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
:white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
:white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
:white_check_mark: unet: PASSED: MIGraphX meets tolerance
:white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
:white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
:red_circle: bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
:white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
:x: yolov5s: ERROR - check error output
:white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
:white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
:white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
:white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance

migraphx-bot avatar Jul 31 '24 15:07 migraphx-bot