MIOpen
Implement Cumulative reduction (max, min, sum, prod) forward with small last dim
- Added the cumulative reduction forward operation, kernel, and solver, supporting the binary operators max, min, sum, and prod. The operation is equivalent to `cummax`, `cummin`, `cumsum`, and `cumprod` in PyTorch.
- Added driver test and gtest for cumulative reduction.
- The new API is guarded by the `MIOPEN_BETA_API` macro.
- Compared to ROCm PyTorch, there is a performance improvement when the operation is performed on a `dim` whose size is less than or equal to `256` and whose `stride` is 1 in the input, output, and indices tensors. For that reason, the `IsApplicable` constraint ensures that the solver is only used in this case.
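As a minimal sketch of the semantics described above, the C++ below gives a CPU reference for a cumulative reduction along the last, contiguous dimension, plus a predicate mirroring the `IsApplicable` constraint (reduced `dim` of length at most 256, stride 1 in input, output, and indices). `CumOp`, `CumReduceLastDim`, `TensorDesc`, and `IsApplicableSketch` are hypothetical names for illustration, not MIOpen's API. Ties in max/min resolve to the later index here and NaN handling is omitted; verify both against PyTorch before relying on them.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <vector>

// Hypothetical operator enum (illustration only, not MIOpen's API).
enum class CumOp { Max, Min, Sum, Prod };

// CPU reference: inclusive cumulative reduction over `rows` contiguous rows of
// length `len` (i.e. the last dim, stride 1). For Max/Min, `indices` records
// the position of the selected element, like torch.cummax / torch.cummin;
// for Sum/Prod it is left as zeros.
template <typename T>
void CumReduceLastDim(const std::vector<T>& input,
                      std::vector<T>& output,
                      std::vector<int64_t>& indices,
                      std::size_t rows,
                      std::size_t len,
                      CumOp op)
{
    assert(input.size() == rows * len);
    output.assign(rows * len, T{});
    indices.assign(rows * len, 0);
    if(len == 0)
        return;
    for(std::size_t r = 0; r < rows; ++r)
    {
        const std::size_t base = r * len;
        T acc       = input[base]; // running result starts at the first element
        int64_t idx = 0;
        for(std::size_t i = 0; i < len; ++i)
        {
            const T x = input[base + i];
            if(i > 0)
            {
                switch(op)
                {
                case CumOp::Sum: acc += x; break;
                case CumOp::Prod: acc *= x; break;
                case CumOp::Max: // ">=": ties move to the later index
                    if(x >= acc) { acc = x; idx = static_cast<int64_t>(i); }
                    break;
                case CumOp::Min: // "<=": ties move to the later index
                    if(x <= acc) { acc = x; idx = static_cast<int64_t>(i); }
                    break;
                }
            }
            output[base + i]  = acc;
            indices[base + i] = idx;
        }
    }
}

// Hypothetical tensor description (not MIOpen's TensorDescriptor).
struct TensorDesc
{
    std::vector<std::size_t> lengths;
    std::vector<std::size_t> strides;
};

// Sketch of the applicability rule: the reduced dim must have length <= 256
// and stride 1 in the input, output and indices tensors.
bool IsApplicableSketch(const TensorDesc& input,
                        const TensorDesc& output,
                        const TensorDesc& indices,
                        int dim)
{
    const int ndims = static_cast<int>(input.lengths.size());
    if(dim < 0)
        dim += ndims; // accept PyTorch-style negative dims such as -1
    if(dim < 0 || dim >= ndims)
        return false;
    const auto d = static_cast<std::size_t>(dim);
    if(input.lengths[d] > 256)
        return false;
    for(const TensorDesc* t : {&input, &output, &indices})
    {
        if(t->strides.size() != input.lengths.size() || t->strides[d] != 1)
            return false;
    }
    return true;
}
```

For example, a cumulative max over the row `3, 1, 4` yields values `3, 3, 4` with indices `0, 0, 2`, while a last dim of length 257 fails the applicability check.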
float16
| op_name | dtype | size | dim | contiguous | model | direction | ROCm PyTorch | MIOpen HIP | Improvement (×) |
|---|---|---|---|---|---|---|---|---|---|
| CumMax | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 79103622 | 10290800 | 7.69 |
| CumMax | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 39319091 | 2490330 | 15.79 |
| CumMax | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 78599721 | 4982140 | 15.78 |
| CumMax | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 39227767 | 2479240 | 15.82 |
| CumMax | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 78414528 | 4955720 | 15.82 |
| CumMax | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 39164283 | 2427920 | 16.13 |
| CumMax | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 80268168 | 4854160 | 16.54 |
| CumMax | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 39191305 | 2401980 | 16.32 |
| CumMax | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78271414 | 4805250 | 16.29 |
| CumMax | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 11277661 | 1463220 | 7.71 |
| CumMax | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 156666821 | 10686200 | 14.66 |
| CumMax | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 22540060 | 2920460 | 7.72 |
| CumMin | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 79032894 | 10293300 | 7.68 |
| CumMin | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 39290595 | 2491030 | 15.77 |
| CumMin | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 78578550 | 4982730 | 15.77 |
| CumMin | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 39189412 | 2478940 | 15.81 |
| CumMin | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 78419674 | 4956120 | 15.82 |
| CumMin | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 39156197 | 2426850 | 16.13 |
| CumMin | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 78311994 | 4855330 | 16.13 |
| CumMin | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 39105638 | 2400610 | 16.29 |
| CumMin | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78254683 | 4805610 | 16.28 |
| CumMin | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 11269600 | 1461600 | 7.71 |
| CumMin | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 156521111 | 10696300 | 14.63 |
| CumMin | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 22551889 | 5641600 | 4.00 |
| CumSum | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 36839240 | 6739680 | 5.47 |
| CumSum | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 18283694 | 2321070 | 7.88 |
| CumSum | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 36585132 | 4639960 | 7.88 |
| CumSum | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 18230703 | 2307310 | 7.90 |
| CumSum | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 36477501 | 4612030 | 7.91 |
| CumSum | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 18207967 | 2298780 | 7.92 |
| CumSum | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 36433086 | 4594060 | 7.93 |
| CumSum | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 18215727 | 2291620 | 7.95 |
| CumSum | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36442782 | 4580620 | 7.96 |
| CumSum | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 5255286 | 956699 | 5.49 |
| CumSum | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 72925500 | 9161610 | 7.96 |
| CumSum | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 10510668 | 1924150 | 5.46 |
| CumProd | float16 | [512 64 112 112] | -1 | TRUE | random | fwd | 36853144 | 6734100 | 5.47 |
| CumProd | float16 | [512 64 56 56] | -1 | TRUE | random | fwd | 18310781 | 2320740 | 7.89 |
| CumProd | float16 | [512 128 56 56] | -1 | TRUE | random | fwd | 36623723 | 4640240 | 7.89 |
| CumProd | float16 | [512 128 28 28] | -1 | TRUE | random | fwd | 18271694 | 2309390 | 7.91 |
| CumProd | float16 | [512 256 28 28] | -1 | TRUE | random | fwd | 36523629 | 4616960 | 7.91 |
| CumProd | float16 | [512 256 14 14] | -1 | TRUE | random | fwd | 18221247 | 2301290 | 7.92 |
| CumProd | float16 | [512 512 14 14] | -1 | TRUE | random | fwd | 36498109 | 4601770 | 7.93 |
| CumProd | float16 | [512 512 7 7] | -1 | TRUE | random | fwd | 18246623 | 2295060 | 7.95 |
| CumProd | float16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 39393701 | 4588880 | 8.58 |
| CumProd | float16 | [512 1024 100] | -1 | TRUE | random | fwd | 5260982 | 956966 | 5.50 |
| CumProd | float16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 73033611 | 10276700 | 7.11 |
| CumProd | float16 | [1024 1024 100] | -1 | TRUE | random | fwd | 10518988 | 1910770 | 5.51 |
float32
| op_name | dtype | size | dim | contiguous | model | direction | ROCm PyTorch | MIOpen HIP | Improvement (×) |
|---|---|---|---|---|---|---|---|---|---|
| CumMax | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 79353556 | 10510300 | 7.55 |
| CumMax | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 39444502 | 2528340 | 15.60 |
| CumMax | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 78924619 | 5057250 | 15.61 |
| CumMax | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 39394950 | 2517180 | 15.65 |
| CumMax | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 78769181 | 5027420 | 15.67 |
| CumMax | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 39279320 | 2443630 | 16.07 |
| CumMax | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 80072059 | 4885720 | 16.39 |
| CumMax | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 39238905 | 2423280 | 16.19 |
| CumMax | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78490449 | 4814360 | 16.30 |
| CumMax | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 11320257 | 1490630 | 7.59 |
| CumMax | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 157026754 | 9636110 | 16.30 |
| CumMax | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 22649554 | 2982970 | 7.59 |
| CumMin | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 79317382 | 10511900 | 7.55 |
| CumMin | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 39419030 | 2529820 | 15.58 |
| CumMin | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 78850445 | 9170440 | 8.60 |
| CumMin | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 39393495 | 2515360 | 15.66 |
| CumMin | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 78737166 | 5027230 | 15.66 |
| CumMin | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 39270408 | 2443900 | 16.07 |
| CumMin | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 79465092 | 4885980 | 16.26 |
| CumMin | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 39264681 | 2410990 | 16.29 |
| CumMin | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 78513042 | 4815270 | 16.31 |
| CumMin | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 11321649 | 1490860 | 7.59 |
| CumMin | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 157041875 | 9633820 | 16.30 |
| CumMin | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 22661778 | 5730720 | 3.95 |
| CumSum | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 37420899 | 7051980 | 5.31 |
| CumSum | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 18553115 | 2330090 | 7.96 |
| CumSum | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 37096775 | 4656530 | 7.97 |
| CumSum | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 18498636 | 2312900 | 8.00 |
| CumSum | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 37008008 | 4623340 | 8.00 |
| CumSum | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 18427773 | 2301890 | 8.01 |
| CumSum | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 36886474 | 4601850 | 8.02 |
| CumSum | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 18399326 | 2293910 | 8.02 |
| CumSum | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36863786 | 4586130 | 8.04 |
| CumSum | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 5337701 | 998352 | 5.35 |
| CumSum | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 75153089 | 9171500 | 8.19 |
| CumSum | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 10686874 | 1993500 | 5.36 |
| CumProd | float32 | [512 64 112 112] | -1 | TRUE | random | fwd | 37492178 | 7043960 | 5.32 |
| CumProd | float32 | [512 64 56 56] | -1 | TRUE | random | fwd | 18602251 | 2328130 | 7.99 |
| CumProd | float32 | [512 128 56 56] | -1 | TRUE | random | fwd | 37180790 | 4653930 | 7.99 |
| CumProd | float32 | [512 128 28 28] | -1 | TRUE | random | fwd | 18552732 | 2312610 | 8.02 |
| CumProd | float32 | [512 256 28 28] | -1 | TRUE | random | fwd | 37102295 | 4625170 | 8.02 |
| CumProd | float32 | [512 256 14 14] | -1 | TRUE | random | fwd | 18471901 | 2303490 | 8.02 |
| CumProd | float32 | [512 512 14 14] | -1 | TRUE | random | fwd | 36980297 | 4605450 | 8.03 |
| CumProd | float32 | [512 512 7 7] | -1 | TRUE | random | fwd | 18449117 | 2295490 | 8.04 |
| CumProd | float32 | [512 1024 7 7] | -1 | TRUE | random | fwd | 36929706 | 4589030 | 8.05 |
| CumProd | float32 | [512 1024 100] | -1 | TRUE | random | fwd | 5350325 | 996876 | 5.37 |
| CumProd | float32 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 73828228 | 9180310 | 8.04 |
| CumProd | float32 | [1024 1024 100] | -1 | TRUE | random | fwd | 10692522 | 1992210 | 5.37 |
bfloat16
| op_name | dtype | size | dim | contiguous | model | direction | ROCm PyTorch | MIOpen HIP | Improvement (×) |
|---|---|---|---|---|---|---|---|---|---|
| CumMax | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 82001795 | 10583800 | 7.75 |
| CumMax | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 40779253 | 2538400 | 16.06 |
| CumMax | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 81604184 | 5080060 | 16.06 |
| CumMax | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 40773765 | 2517520 | 16.20 |
| CumMax | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 81536393 | 5032490 | 16.20 |
| CumMax | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 40759269 | 2462850 | 16.55 |
| CumMax | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 81497354 | 4925450 | 16.55 |
| CumMax | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 41580409 | 2433980 | 17.08 |
| CumMax | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 81399196 | 4872310 | 16.71 |
| CumMax | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 11702748 | 1502110 | 7.79 |
| CumMax | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 162846550 | 9719290 | 16.75 |
| CumMax | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 23391864 | 2999200 | 7.80 |
| CumMin | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 82027507 | 10510000 | 7.80 |
| CumMin | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 40770069 | 2529640 | 16.12 |
| CumMin | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 81606825 | 5061550 | 16.12 |
| CumMin | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 40762245 | 2513820 | 16.22 |
| CumMin | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 81501883 | 5028670 | 16.21 |
| CumMin | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 40744486 | 2462740 | 16.54 |
| CumMin | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 81500475 | 4921650 | 16.56 |
| CumMin | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 40697830 | 2433980 | 16.72 |
| CumMin | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 81402956 | 4870530 | 16.71 |
| CumMin | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 11700876 | 1492480 | 7.84 |
| CumMin | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 162799578 | 9718710 | 16.75 |
| CumMin | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 23387721 | 2980940 | 7.85 |
| CumSum | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 46814849 | 6889390 | 6.80 |
| CumSum | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 23282362 | 2320860 | 10.03 |
| CumSum | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 46555589 | 6526530 | 7.13 |
| CumSum | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 23230827 | 2307880 | 10.07 |
| CumSum | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 46477910 | 4613740 | 10.07 |
| CumSum | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 23206284 | 2299330 | 10.09 |
| CumSum | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 46414775 | 4595450 | 10.10 |
| CumSum | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 23198811 | 2292080 | 10.12 |
| CumSum | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 46428582 | 4581610 | 10.13 |
| CumSum | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 6686083 | 978317 | 6.83 |
| CumSum | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 92810670 | 9164860 | 10.13 |
| CumSum | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 13372485 | 1954060 | 6.84 |
| CumProd | bfloat16 | [512 64 112 112] | -1 | TRUE | random | fwd | 46990639 | 6894460 | 6.82 |
| CumProd | bfloat16 | [512 64 56 56] | -1 | TRUE | random | fwd | 23371545 | 2323120 | 10.06 |
| CumProd | bfloat16 | [512 128 56 56] | -1 | TRUE | random | fwd | 46773218 | 4644510 | 10.07 |
| CumProd | bfloat16 | [512 128 28 28] | -1 | TRUE | random | fwd | 23333338 | 2310600 | 10.10 |
| CumProd | bfloat16 | [512 256 28 28] | -1 | TRUE | random | fwd | 46674003 | 4619370 | 10.10 |
| CumProd | bfloat16 | [512 256 14 14] | -1 | TRUE | random | fwd | 23306058 | 2302690 | 10.12 |
| CumProd | bfloat16 | [512 512 14 14] | -1 | TRUE | random | fwd | 46625844 | 4603780 | 10.13 |
| CumProd | bfloat16 | [512 512 7 7] | -1 | TRUE | random | fwd | 23304010 | 2295150 | 10.15 |
| CumProd | bfloat16 | [512 1024 7 7] | -1 | TRUE | random | fwd | 46605092 | 4588370 | 10.16 |
| CumProd | bfloat16 | [512 1024 100] | -1 | TRUE | random | fwd | 6709842 | 979597 | 6.85 |
| CumProd | bfloat16 | [1024 1024 7 7] | -1 | TRUE | random | fwd | 93215385 | 9178690 | 10.16 |
| CumProd | bfloat16 | [1024 1024 100] | -1 | TRUE | random | fwd | 13421396 | 1954820 | 6.87 |
Average improvement over all cases:
| dtype | average improvement (×) |
|---|---|
| float16 | 10.42 |
| float32 | 10.43 |
| bfloat16 | 11.32 |
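The Improvement column appears to be the ROCm PyTorch time divided by the MIOpen HIP time (for example, 79103622 / 10290800 ≈ 7.69 in the first float16 row). A small C++ helper reproducing a table entry under that assumption is sketched below; `Improvement` and `Round2` are hypothetical names, and both timing columns are assumed to share a unit so that it cancels.

```cpp
#include <cassert>
#include <cmath>

// Speedup of MIOpen HIP over ROCm PyTorch for one table row. Both timings are
// assumed to use the same unit, so the unit cancels in the ratio.
double Improvement(double rocm_pytorch_time, double miopen_hip_time)
{
    assert(miopen_hip_time > 0.0);
    return rocm_pytorch_time / miopen_hip_time;
}

// Round to two decimals, the precision used in the Improvement column.
double Round2(double x) { return std::round(x * 100.0) / 100.0; }
```

For instance, `Round2(Improvement(79103622, 10290800))` gives 7.69, matching the first float16 row.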