
Add support for GridSample operator

Open gyulaz-htec opened this issue 11 months ago • 8 comments

Add support for the GridSample ONNX operator. The following feature set is currently supported:

  • inputs: 4-D inputs
  • align_corners: both 0 and 1
  • mode: linear and nearest
  • padding_mode: zeros, border and reflection

Fixes: https://github.com/ROCm/AMDMIGraphX/issues/2923
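For readers unfamiliar with the attributes listed above, here is a minimal pure-Python sketch of how `align_corners` and `padding_mode` affect a single linear (bilinear) sample. It follows the ONNX GridSample definition as I understand it and is not the code from this PR; the helper names `denormalize`, `pixel_at` and `sample_linear` are made up for illustration, and `reflection` padding and `nearest` mode are left out to keep it short.

```python
import math


def denormalize(coord, size, align_corners):
    """Map a normalized grid coordinate in [-1, 1] to a fractional pixel index."""
    if align_corners:
        # -1 and 1 point at the centers of the first and last pixel.
        return (coord + 1.0) / 2.0 * (size - 1)
    # -1 and 1 point at the outer edges of the first and last pixel.
    return ((coord + 1.0) * size - 1.0) / 2.0


def pixel_at(img, row, col, padding_mode):
    """Read one pixel from a 2-D nested list, applying the padding rule."""
    h, w = len(img), len(img[0])
    if padding_mode == "border":
        # Clamp the index so out-of-range reads return the nearest edge pixel.
        row = min(max(row, 0), h - 1)
        col = min(max(col, 0), w - 1)
    elif not (0 <= row < h and 0 <= col < w):
        # "zeros" padding: everything outside the image contributes 0.
        return 0.0
    return float(img[row][col])


def sample_linear(img, gx, gy, align_corners=0, padding_mode="zeros"):
    """Bilinear sample of a single-channel image at normalized (gx, gy)."""
    h, w = len(img), len(img[0])
    x = denormalize(gx, w, align_corners)
    y = denormalize(gy, h, align_corners)
    x0, y0 = math.floor(x), math.floor(y)
    fx, fy = x - x0, y - y0  # fractional offsets inside the pixel cell
    top = (1 - fx) * pixel_at(img, y0, x0, padding_mode) \
        + fx * pixel_at(img, y0, x0 + 1, padding_mode)
    bottom = (1 - fx) * pixel_at(img, y0 + 1, x0, padding_mode) \
        + fx * pixel_at(img, y0 + 1, x0 + 1, padding_mode)
    return (1 - fy) * top + fy * bottom


img = [[1, 2],
       [3, 4]]
center = sample_linear(img, 0.0, 0.0, align_corners=1)
corner_zeros = sample_linear(img, -1.0, -1.0, align_corners=0, padding_mode="zeros")
corner_border = sample_linear(img, -1.0, -1.0, align_corners=0, padding_mode="border")
```

With `align_corners=0` the corner coordinate (-1, -1) lands half a pixel outside the image, so the padding mode decides what the out-of-range neighbours contribute; with `align_corners=1` the same coordinate lands exactly on the first pixel.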

gyulaz-htec avatar Mar 20 '24 09:03 gyulaz-htec

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 91.90%. Comparing base (bceef13) to head (a1cf519).

:exclamation: Current head a1cf519 differs from pull request most recent head 8bb3d24

Please upload reports for the commit 8bb3d24 to get more accurate results.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #2909      +/-   ##
===========================================
+ Coverage    91.82%   91.90%   +0.08%     
===========================================
  Files          486      487       +1     
  Lines        18991    19192     +201     
===========================================
+ Hits         17438    17639     +201     
  Misses        1553     1553              

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.

codecov[bot] avatar Mar 20 '24 10:03 codecov[bot]

| Test | Batch | Rate new (b6172d) | Rate old (5fcf86) | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 1,751.55 | 1,751.74 | -0.01% | :white_check_mark: |
| torchvision-resnet50_fp16 | 64 | 4,084.47 | 4,084.34 | 0.00% | :white_check_mark: |
| torchvision-densenet121 | 32 | 1,467.78 | 1,467.39 | 0.03% | :white_check_mark: |
| torchvision-densenet121_fp16 | 32 | 2,529.08 | 2,525.45 | 0.14% | :white_check_mark: |
| torchvision-inceptionv3 | 32 | 889.72 | 889.64 | 0.01% | :white_check_mark: |
| torchvision-inceptionv3_fp16 | 32 | 1,483.73 | 1,483.57 | 0.01% | :white_check_mark: |
| cadene-inceptionv4 | 16 | 412.38 | 412.40 | -0.01% | :white_check_mark: |
| cadene-resnext64x4 | 16 | 419.71 | 419.50 | 0.05% | :white_check_mark: |
| slim-mobilenet | 64 | 4,007.09 | 4,006.71 | 0.01% | :white_check_mark: |
| slim-nasnetalarge | 64 | 101.02 | 101.01 | 0.01% | :white_check_mark: |
| slim-resnet50v2 | 64 | 1,680.79 | 1,680.60 | 0.01% | :white_check_mark: |
| bert-mrpc-onnx | 8 | 615.83 | 618.22 | -0.39% | :white_check_mark: |
| bert-mrpc-tf | 1 | 277.98 | 279.81 | -0.65% | :white_check_mark: |
| pytorch-examples-wlang-gru | 1 | 320.54 | 319.57 | 0.30% | :white_check_mark: |
| pytorch-examples-wlang-lstm | 1 | 288.76 | 289.36 | -0.21% | :white_check_mark: |
| torchvision-resnet50_1 | 1 | 471.26 | 471.89 | -0.13% | :white_check_mark: |
| cadene-dpn92_1 | 1 | 247.18 | 247.09 | 0.04% | :white_check_mark: |
| cadene-resnext101_1 | 1 | 203.92 | 204.23 | -0.15% | :white_check_mark: |
| onnx-taau-downsample | 1 | 206.48 | 206.24 | 0.11% | :white_check_mark: |
| dlrm-criteoterabyte | 1 | 22.91 | 22.90 | 0.04% | :white_check_mark: |
| dlrm-criteoterabyte_fp16 | 1 | 42.72 | 42.73 | -0.02% | :white_check_mark: |
| agentmodel | 1 | 6,369.83 | 6,323.67 | 0.73% | :white_check_mark: |
| unet_fp16 | 2 | 34.18 | 34.21 | -0.07% | :white_check_mark: |
| resnet50v1_fp16 | 1 | 574.81 | 589.36 | -2.47% | :white_check_mark: |
| resnet50v1_int8 | 1 | 578.26 | 573.50 | 0.83% | :white_check_mark: |
| bert_base_cased_fp16 | 64 | 646.28 | 646.33 | -0.01% | :white_check_mark: |
| bert_large_uncased_fp16 | 32 | 199.02 | 198.99 | 0.02% | :white_check_mark: |
| bert_large_fp16 | 1 | 117.21 | 117.54 | -0.27% | :white_check_mark: |
| distilgpt2_fp16 | 16 | 1,212.43 | 1,211.40 | 0.09% | :white_check_mark: |
| yolov5s | 1 | 301.12 | 301.25 | -0.04% | :white_check_mark: |
| tinyllama | 1 | 23.33 | 23.34 | -0.03% | :white_check_mark: |
| vicuna-fastchat | 1 | 133.85 | 133.65 | 0.15% | :white_check_mark: |
| whisper-tiny-encoder | 1 | 244.42 | 244.39 | 0.01% | :white_check_mark: |
| whisper-tiny-decoder | 1 | 256.29 | 256.66 | -0.14% | :white_check_mark: |

This build is OK for merge :white_check_mark:

migraphx-bot avatar Mar 20 '24 10:03 migraphx-bot


     :white_check_mark: bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert-mrpc-tf: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
     :white_check_mark: pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
     :white_check_mark: torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-dpn92_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: cadene-resnext101_1: PASSED: MIGraphX meets tolerance
     :white_check_mark: dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
     :white_check_mark: agentmodel: PASSED: MIGraphX meets tolerance
     :white_check_mark: unet: PASSED: MIGraphX meets tolerance
     :white_check_mark: resnet50v1: PASSED: MIGraphX meets tolerance
     :white_check_mark: bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
     :red_circle: bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
     :white_check_mark: bert_large: PASSED: MIGraphX meets tolerance
     :white_check_mark: yolov5s: PASSED: MIGraphX meets tolerance
     :white_check_mark: tinyllama: PASSED: MIGraphX meets tolerance
     :white_check_mark: vicuna-fastchat: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
     :white_check_mark: distilgpt2_fp16: PASSED: MIGraphX meets tolerance

migraphx-bot avatar Mar 20 '24 10:03 migraphx-bot

@causten The PR is ready for review. ~~This should not be merged yet, because I've found an accuracy issue with the fuse_pointwise pass on the GPU target when computing GridSample results (with linear interpolation). I collected my findings on that issue in https://github.com/ROCm/AMDMIGraphX/issues/2923~~

gyulaz-htec avatar Mar 25 '24 10:03 gyulaz-htec

I had to exclude the `logical_and` and `where` operators from pointwise fusion, so the GPU tests are passing now. That fixes https://github.com/ROCm/AMDMIGraphX/issues/2923.

gyulaz-htec avatar Mar 28 '24 12:03 gyulaz-htec

@TedThemistokleous @pfultz2 I've addressed the majority of your comments, except the "Look into alternative to avoid O(n^4) updates and ops." part.

gyulaz-htec avatar Apr 11 '24 14:04 gyulaz-htec

@TedThemistokleous @pfultz2 I've managed to move the op creation and updates out of the O(n^4) part, so the heavy lifting no longer happens there. The remaining code there only generates literals for gathernd indices.
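To illustrate what "generating literals for gathernd indices" can look like, here is a rough Python sketch. It assumes ONNX GatherND semantics with `batch_dims=0` and a 4-D N×C×H×W input; the helpers `gather_nd` and `batch_channel_prefixes` are hypothetical names for this example, not functions from parse_gridsample.cpp. The idea is that the (batch, channel) part of each index is known at parse time and can be emitted as a constant, while the (y, x) part comes from the grid at run time.

```python
from itertools import product


def gather_nd(data, indices):
    """GatherND with batch_dims=0 on nested lists: each index tuple picks one element."""
    out = []
    for idx in indices:
        item = data
        for i in idx:
            item = item[i]
        out.append(item)
    return out


def batch_channel_prefixes(n, c, h_out, w_out):
    """Static (batch, channel) prefix for every output element, in row-major order.

    This is the nested loop over all four output dimensions; it only builds
    constants and creates no graph operations.
    """
    return [(b, ch) for b, ch, _, _ in product(range(n), range(c), range(h_out), range(w_out))]


# Input of shape 1x2x2x2 and an output of shape 1x2x1x1.
data = [[[[1, 2], [3, 4]],
         [[5, 6], [7, 8]]]]
prefixes = batch_channel_prefixes(1, 2, 1, 1)

# Pretend the grid resolved to pixel (y=1, x=0) for the single output position.
sampled_yx = (1, 0)
indices = [prefix + sampled_yx for prefix in prefixes]
gathered = gather_nd(data, indices)
```

Here `gathered` holds one value per channel, read from the same spatial position in each channel.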

gyulaz-htec avatar Apr 15 '24 15:04 gyulaz-htec

@pfultz2 I've removed the fuse pointwise related changes from the PR in favor of https://github.com/ROCm/AMDMIGraphX/pull/3054

gyulaz-htec avatar May 08 '24 09:05 gyulaz-htec

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.05%. Comparing base (5fcf86e) to head (b6172d3). Report is 154 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #2909      +/-   ##
===========================================
+ Coverage    91.97%   92.05%   +0.08%     
===========================================
  Files          489      490       +1     
  Lines        19398    19599     +201     
===========================================
+ Hits         17841    18042     +201     
  Misses        1557     1557              

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.

codecov-commenter avatar May 24 '24 05:05 codecov-commenter

@gyulaz-htec you don't need to hit the "update branch" button. Once the reviews are complete I'll handle it.

causten avatar Jun 11 '24 20:06 causten

I would like to see additional documentation and possibly some refactoring, since parse_gridsample.cpp is quite large. Not blocking for merge.

kahmed10 avatar Jun 19 '24 04:06 kahmed10