Implement Cumulative reduction (max, min, sum, prod) forward with small last dim
This PR is a continuation of PR #3182. I accidentally closed the older PR and then made changes to the working branch, so I can no longer reopen it. That PR did not have many comments, so I hope this doesn't interrupt your review process. Again, sorry for the inconvenience.
- Added the cumulative reduction forward operation and kernel with a solver, supporting the binary operators max, min, sum, and prod. This operation is equivalent to `cummax`, `cummin`, `cumsum`, and `cumprod` in PyTorch.
- Added driver test and gtest for cumulative reduction.
- New API is guarded by MIOPEN_BETA_API macro.
- Compared to ROCm PyTorch, there is a performance improvement when the operation is performed on a `dim` with size smaller than or equal to `256` and the `stride` value at that `dim` is equal to 1 for all of the input, output, and indices tensors. For that reason, the `IsApplicable` constraint makes sure that the operation only works in the above case.
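
To make the constraint and the operation semantics above concrete, here is a minimal Python sketch. It is not the MIOpen solver code: `is_applicable`, `contiguous_strides`, and `cumulative_reduce` are hypothetical helper names, and only the size limit of 256 and the stride-1 requirement come from the description above. The reference scan covers values only and omits the indices output.

```python
from itertools import accumulate
import operator

MAX_DIM_SIZE = 256  # size limit stated in the PR description


def contiguous_strides(lengths):
    """Strides (in elements) of a densely packed row-major tensor."""
    strides = [1] * len(lengths)
    for i in range(len(lengths) - 2, -1, -1):
        strides[i] = strides[i + 1] * lengths[i + 1]
    return tuple(strides)


def is_applicable(lengths, dim, input_strides, output_strides, indices_strides):
    """Hypothetical mirror of the described constraint, not MIOpen's actual check."""
    dim = dim % len(lengths)  # allow negative dims such as -1
    if lengths[dim] > MAX_DIM_SIZE:
        return False
    # The stride at `dim` must be 1 for input, output, and indices alike.
    return all(s[dim] == 1 for s in (input_strides, output_strides, indices_strides))


# Binary operators named in the PR, mapped to plain Python callables.
OPS = {"max": max, "min": min, "sum": operator.add, "prod": operator.mul}


def cumulative_reduce(row, op_name):
    """Reference inclusive scan over one 1-D row (values only)."""
    return list(accumulate(row, OPS[op_name]))
```

For example, a contiguous `[512 64 112 112]` tensor reduced along `dim = -1` passes this check, while the same tensor reduced along `dim = 1` does not, because its stride at that `dim` is `112 * 112` rather than 1.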
float16
| op_name | dtype | size | dim | contiguous | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|---|
| CumMax | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 79103622 | 10290800 | 7.69 |
| CumMax | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 39319091 | 2490330 | 15.79 |
| CumMax | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 78599721 | 4982140 | 15.78 |
| CumMax | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 39227767 | 2479240 | 15.82 |
| CumMax | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 78414528 | 4955720 | 15.82 |
| CumMax | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 39164283 | 2427920 | 16.13 |
| CumMax | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 80268168 | 4854160 | 16.54 |
| CumMax | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 39191305 | 2401980 | 16.32 |
| CumMax | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78271414 | 4805250 | 16.29 |
| CumMax | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 11277661 | 1463220 | 7.71 |
| CumMax | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 156666821 | 10686200 | 14.66 |
| CumMax | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 22540060 | 2920460 | 7.72 |
| CumMin | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 79032894 | 10293300 | 7.68 |
| CumMin | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 39290595 | 2491030 | 15.77 |
| CumMin | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 78578550 | 4982730 | 15.77 |
| CumMin | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 39189412 | 2478940 | 15.81 |
| CumMin | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 78419674 | 4956120 | 15.82 |
| CumMin | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 39156197 | 2426850 | 16.13 |
| CumMin | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 78311994 | 4855330 | 16.13 |
| CumMin | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 39105638 | 2400610 | 16.29 |
| CumMin | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78254683 | 4805610 | 16.28 |
| CumMin | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 11269600 | 1461600 | 7.71 |
| CumMin | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 156521111 | 10696300 | 14.63 |
| CumMin | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 22551889 | 5641600 | 4.00 |
| CumSum | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 36839240 | 6739680 | 5.47 |
| CumSum | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 18283694 | 2321070 | 7.88 |
| CumSum | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 36585132 | 4639960 | 7.88 |
| CumSum | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 18230703 | 2307310 | 7.90 |
| CumSum | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 36477501 | 4612030 | 7.91 |
| CumSum | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 18207967 | 2298780 | 7.92 |
| CumSum | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 36433086 | 4594060 | 7.93 |
| CumSum | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 18215727 | 2291620 | 7.95 |
| CumSum | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36442782 | 4580620 | 7.96 |
| CumSum | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 5255286 | 956699 | 5.49 |
| CumSum | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 72925500 | 9161610 | 7.96 |
| CumSum | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 10510668 | 1924150 | 5.46 |
| CumProd | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 36853144 | 6734100 | 5.47 |
| CumProd | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 18310781 | 2320740 | 7.89 |
| CumProd | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 36623723 | 4640240 | 7.89 |
| CumProd | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 18271694 | 2309390 | 7.91 |
| CumProd | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 36523629 | 4616960 | 7.91 |
| CumProd | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 18221247 | 2301290 | 7.92 |
| CumProd | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 36498109 | 4601770 | 7.93 |
| CumProd | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 18246623 | 2295060 | 7.95 |
| CumProd | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 39393701 | 4588880 | 8.58 |
| CumProd | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 5260982 | 956966 | 5.50 |
| CumProd | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 73033611 | 10276700 | 7.11 |
| CumProd | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 10518988 | 1910770 | 5.51 |
float32
| op_name | dtype | size | dim | contiguous | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|---|
| CumMax | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 79353556 | 10510300 | 7.55 |
| CumMax | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 39444502 | 2528340 | 15.60 |
| CumMax | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 78924619 | 5057250 | 15.61 |
| CumMax | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 39394950 | 2517180 | 15.65 |
| CumMax | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 78769181 | 5027420 | 15.67 |
| CumMax | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 39279320 | 2443630 | 16.07 |
| CumMax | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 80072059 | 4885720 | 16.39 |
| CumMax | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 39238905 | 2423280 | 16.19 |
| CumMax | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78490449 | 4814360 | 16.30 |
| CumMax | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 11320257 | 1490630 | 7.59 |
| CumMax | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 157026754 | 9636110 | 16.30 |
| CumMax | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 22649554 | 2982970 | 7.59 |
| CumMin | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 79317382 | 10511900 | 7.55 |
| CumMin | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 39419030 | 2529820 | 15.58 |
| CumMin | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 78850445 | 9170440 | 8.60 |
| CumMin | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 39393495 | 2515360 | 15.66 |
| CumMin | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 78737166 | 5027230 | 15.66 |
| CumMin | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 39270408 | 2443900 | 16.07 |
| CumMin | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 79465092 | 4885980 | 16.26 |
| CumMin | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 39264681 | 2410990 | 16.29 |
| CumMin | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78513042 | 4815270 | 16.31 |
| CumMin | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 11321649 | 1490860 | 7.59 |
| CumMin | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 157041875 | 9633820 | 16.30 |
| CumMin | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 22661778 | 5730720 | 3.95 |
| CumSum | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 37420899 | 7051980 | 5.31 |
| CumSum | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 18553115 | 2330090 | 7.96 |
| CumSum | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 37096775 | 4656530 | 7.97 |
| CumSum | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 18498636 | 2312900 | 8.00 |
| CumSum | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 37008008 | 4623340 | 8.00 |
| CumSum | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 18427773 | 2301890 | 8.01 |
| CumSum | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 36886474 | 4601850 | 8.02 |
| CumSum | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 18399326 | 2293910 | 8.02 |
| CumSum | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36863786 | 4586130 | 8.04 |
| CumSum | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 5337701 | 998352 | 5.35 |
| CumSum | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 75153089 | 9171500 | 8.19 |
| CumSum | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 10686874 | 1993500 | 5.36 |
| CumProd | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 37492178 | 7043960 | 5.32 |
| CumProd | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 18602251 | 2328130 | 7.99 |
| CumProd | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 37180790 | 4653930 | 7.99 |
| CumProd | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 18552732 | 2312610 | 8.02 |
| CumProd | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 37102295 | 4625170 | 8.02 |
| CumProd | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 18471901 | 2303490 | 8.02 |
| CumProd | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 36980297 | 4605450 | 8.03 |
| CumProd | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 18449117 | 2295490 | 8.04 |
| CumProd | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36929706 | 4589030 | 8.05 |
| CumProd | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 5350325 | 996876 | 5.37 |
| CumProd | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 73828228 | 9180310 | 8.04 |
| CumProd | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 10692522 | 1992210 | 5.37 |
bfloat16
| op_name | dtype | size | dim | contiguous | model | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---|---|---|---|---|
| CumMax | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 82001795 | 10583800 | 7.75 |
| CumMax | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 40779253 | 2538400 | 16.06 |
| CumMax | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 81604184 | 5080060 | 16.06 |
| CumMax | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 40773765 | 2517520 | 16.20 |
| CumMax | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 81536393 | 5032490 | 16.20 |
| CumMax | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 40759269 | 2462850 | 16.55 |
| CumMax | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 81497354 | 4925450 | 16.55 |
| CumMax | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 41580409 | 2433980 | 17.08 |
| CumMax | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 81399196 | 4872310 | 16.71 |
| CumMax | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 11702748 | 1502110 | 7.79 |
| CumMax | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 162846550 | 9719290 | 16.75 |
| CumMax | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 23391864 | 2999200 | 7.80 |
| CumMin | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 82027507 | 10510000 | 7.80 |
| CumMin | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 40770069 | 2529640 | 16.12 |
| CumMin | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 81606825 | 5061550 | 16.12 |
| CumMin | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 40762245 | 2513820 | 16.22 |
| CumMin | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 81501883 | 5028670 | 16.21 |
| CumMin | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 40744486 | 2462740 | 16.54 |
| CumMin | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 81500475 | 4921650 | 16.56 |
| CumMin | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 40697830 | 2433980 | 16.72 |
| CumMin | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 81402956 | 4870530 | 16.71 |
| CumMin | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 11700876 | 1492480 | 7.84 |
| CumMin | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 162799578 | 9718710 | 16.75 |
| CumMin | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 23387721 | 2980940 | 7.85 |
| CumSum | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 46814849 | 6889390 | 6.80 |
| CumSum | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 23282362 | 2320860 | 10.03 |
| CumSum | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 46555589 | 6526530 | 7.13 |
| CumSum | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 23230827 | 2307880 | 10.07 |
| CumSum | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 46477910 | 4613740 | 10.07 |
| CumSum | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 23206284 | 2299330 | 10.09 |
| CumSum | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 46414775 | 4595450 | 10.10 |
| CumSum | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 23198811 | 2292080 | 10.12 |
| CumSum | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 46428582 | 4581610 | 10.13 |
| CumSum | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 6686083 | 978317 | 6.83 |
| CumSum | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 92810670 | 9164860 | 10.13 |
| CumSum | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 13372485 | 1954060 | 6.84 |
| CumProd | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 46990639 | 6894460 | 6.82 |
| CumProd | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 23371545 | 2323120 | 10.06 |
| CumProd | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 46773218 | 4644510 | 10.07 |
| CumProd | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 23333338 | 2310600 | 10.10 |
| CumProd | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 46674003 | 4619370 | 10.10 |
| CumProd | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 23306058 | 2302690 | 10.12 |
| CumProd | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 46625844 | 4603780 | 10.13 |
| CumProd | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 23304010 | 2295150 | 10.15 |
| CumProd | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 46605092 | 4588370 | 10.16 |
| CumProd | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 6709842 | 979597 | 6.85 |
| CumProd | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 93215385 | 9178690 | 10.16 |
| CumProd | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 13421396 | 1954820 | 6.87 |
Average over all cases:
| type | average |
|---|---|
| float16 | 10.42 |
| float32 | 10.43 |
| bfloat16 | 11.32 |
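
As a reading aid for the tables, the `Improvement` column appears to be the `ROCm pytorch` value divided by the `MIOpen HIP` value, rounded to two decimals. The sketch below checks that on three float16 rows copied from the tables. It assumes both columns are timings where smaller is better (the unit is not stated here), and `improvement` is a hypothetical helper name. It does not attempt to reproduce the per-dtype averages.

```python
# (op_name, size, ROCm pytorch, MIOpen HIP, reported Improvement)
rows = [
    ("CumMax", "[512 64 112 112]", 79103622, 10290800, 7.69),
    ("CumMax", "[512 64 56 56]", 39319091, 2490330, 15.79),
    ("CumSum", "[512 64 112 112]", 36839240, 6739680, 5.47),
]


def improvement(baseline, candidate):
    """Ratio of the baseline measurement to the candidate measurement."""
    return round(baseline / candidate, 2)


for op_name, size, rocm, miopen, reported in rows:
    # Each recomputed ratio should match the value reported in the table.
    print(op_name, size, improvement(rocm, miopen), reported)
```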
MIOpen is moving to the new monorepo setup and all older unmerged PRs are being closed. Please re-open this as part of the new repo if these changes are still needed.