MIOpen
Implement UnsortedSegmentSum
- Added UnsortedSegmentSum forward and backward passes.
- Added a driver test and a gtest for UnsortedSegmentSum.
- The new API is guarded by the `MIOPEN_BETA_API` macro.
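
For reviewers unfamiliar with the operator, here is a minimal CPU sketch of its semantics. It assumes the operator follows TensorFlow's `tf.math.unsorted_segment_sum` convention: row `i` of the input is added into output row `segment_ids[i]`, and negative ids are dropped. The function names, the flattened `[rows, inner]` layout, the 1-D `segment_ids` over the leading dimension, and the choice to skip (rather than reject) ids `>= num_segments` are all illustrative assumptions, not the MIOpen API added in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference forward pass.
// data: segment_ids.size() rows of `inner` contiguous elements.
// Returns num_segments rows of `inner` elements; row s is the sum of all
// data rows i with segment_ids[i] == s (zero if no row maps to s).
std::vector<float> unsorted_segment_sum_fwd(const std::vector<float>& data,
                                            const std::vector<int32_t>& segment_ids,
                                            int32_t num_segments,
                                            std::size_t inner)
{
    assert(data.size() == segment_ids.size() * inner);
    std::vector<float> output(static_cast<std::size_t>(num_segments) * inner, 0.0f);
    for(std::size_t i = 0; i < segment_ids.size(); ++i)
    {
        const int32_t seg = segment_ids[i];
        // Assumption: out-of-range ids (negative or >= num_segments) are skipped.
        if(seg < 0 || seg >= num_segments)
            continue;
        for(std::size_t j = 0; j < inner; ++j)
            output[static_cast<std::size_t>(seg) * inner + j] += data[i * inner + j];
    }
    return output;
}

// Hypothetical reference backward pass.
// d(output[seg][j]) / d(data[i][j]) is 1 when segment_ids[i] == seg, so the
// input gradient is a gather of the output gradient; dropped rows get 0.
std::vector<float> unsorted_segment_sum_bwd(const std::vector<float>& grad_output,
                                            const std::vector<int32_t>& segment_ids,
                                            int32_t num_segments,
                                            std::size_t inner)
{
    assert(grad_output.size() == static_cast<std::size_t>(num_segments) * inner);
    std::vector<float> grad_data(segment_ids.size() * inner, 0.0f);
    for(std::size_t i = 0; i < segment_ids.size(); ++i)
    {
        const int32_t seg = segment_ids[i];
        if(seg < 0 || seg >= num_segments)
            continue;
        for(std::size_t j = 0; j < inner; ++j)
            grad_data[i * inner + j] = grad_output[static_cast<std::size_t>(seg) * inner + j];
    }
    return grad_data;
}
```

Because several input rows can map to the same output row, a parallel forward pass has to resolve write conflicts (for example with atomic adds or a per-segment reduction), whereas the backward pass is a gather in which every input-gradient element is written exactly once.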
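
Average improvement over ROCm across all cases (each per-case value is `rocm_kernel_avg / kernel_duration`; values above 1 mean the new kernel is faster):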
UnsortedSegmentSum
| Type | Forward | Backward |
|---|---|---|
| float16 | 1.77 | 1.58 |
| float32 | 1.64 | 1.53 |
| bfloat16 | 1.88 | 1.92 |
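
Each "improvement over rocm" entry in the per-shape tables equals `rocm_kernel_avg / kernel_duration` (for example, 16864 / 9208 ≈ 1.83 for the first FWD-FP16 row), and the averages above appear to be the plain arithmetic mean of that column. The sketch below reproduces that arithmetic; `BenchRow`, `improvement`, and `mean_improvement` are hypothetical helpers, and the three rows used to exercise it are taken from the FWD-FP16 table for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One benchmark row: reference (ROCm) average and new kernel duration,
// both in the same (unspecified) time unit.
struct BenchRow
{
    double rocm_kernel_avg;
    double kernel_duration;
};

// Per-row "improvement over rocm": values > 1 mean the new kernel is faster.
double improvement(const BenchRow& row)
{
    assert(row.kernel_duration > 0.0);
    return row.rocm_kernel_avg / row.kernel_duration;
}

// Arithmetic mean of the per-row ratios (not the ratio of summed durations),
// which is how the summary table appears to be computed.
double mean_improvement(const std::vector<BenchRow>& rows)
{
    assert(!rows.empty());
    double sum = 0.0;
    for(const BenchRow& r : rows)
        sum += improvement(r);
    return sum / static_cast<double>(rows.size());
}
```

An arithmetic mean of ratios gives every shape equal weight, so the small shapes with roughly 4x speedups pull the averages up; the per-shape tables show that several FP32 forward cases fall below 1.0.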
FWD-FP16
| op_name | dtype | input_size | direction | rocm_kernel_avg | kernel_duration | improvement over rocm |
|---|---|---|---|---|---|---|
| UnsortedSegmentSum | float16 | [50 100] | fwd | 16864 | 9208 | 1.831450912 |
| UnsortedSegmentSum | float16 | [100 50] | fwd | 21536 | 15376 | 1.40062435 |
| UnsortedSegmentSum | float16 | [100 100] | fwd | 20848 | 14096 | 1.479001135 |
| UnsortedSegmentSum | float16 | [100 300] | fwd | 22448 | 14790 | 1.517782285 |
| UnsortedSegmentSum | float16 | [300 100] | fwd | 43296 | 36833 | 1.175467651 |
| UnsortedSegmentSum | float16 | [200 300] | fwd | 36848 | 30256 | 1.217874141 |
| UnsortedSegmentSum | float16 | [205 350] | fwd | 39280 | 30949 | 1.269184788 |
| UnsortedSegmentSum | float16 | [350 105] | fwd | 54240 | 50806 | 1.067590442 |
| UnsortedSegmentSum | float16 | [405 200] | fwd | 66256 | 58930 | 1.124316986 |
| UnsortedSegmentSum | float16 | [10 10 10] | fwd | 24368 | 6257 | 3.89451814 |
| UnsortedSegmentSum | float16 | [10 10 30] | fwd | 26128 | 6346 | 4.117239206 |
| UnsortedSegmentSum | float16 | [10 30 10] | fwd | 25600 | 6470 | 3.956723338 |
| UnsortedSegmentSum | float16 | [30 10 10] | fwd | 14352 | 8266 | 1.736269054 |
| UnsortedSegmentSum | float16 | [30 30 30] | fwd | 16464 | 9350 | 1.760855615 |
| UnsortedSegmentSum | float16 | [50 100 50] | fwd | 25504 | 19376 | 1.316267547 |
| UnsortedSegmentSum | float16 | [100 50 100] | fwd | 35344 | 29990 | 1.178526175 |
| UnsortedSegmentSum | float16 | [100 100 100] | fwd | 37520 | 34416 | 1.090190609 |
| UnsortedSegmentSum | float16 | [100 100 300] | fwd | 72688 | 68886 | 1.055192637 |
| UnsortedSegmentSum | float16 | [300 100 100] | fwd | 96432 | 90983 | 1.059890309 |
| UnsortedSegmentSum | float16 | [10 10 10 10] | fwd | 27408 | 7235 | 3.788251555 |
| UnsortedSegmentSum | float16 | [10 10 10 30] | fwd | 27296 | 6826 | 3.998828011 |
| UnsortedSegmentSum | float16 | [30 10 10 10] | fwd | 16912 | 9475 | 1.784907652 |
| UnsortedSegmentSum | float16 | [30 30 30 30] | fwd | 26560 | 22807 | 1.164554742 |
| UnsortedSegmentSum | float16 | [50 100 50 100] | fwd | 492846 | 450899 | 1.093029703 |
| UnsortedSegmentSum | float16 | [100 50 100 50] | fwd | 482990 | 450116 | 1.073034507 |
| UnsortedSegmentSum | float16 | [100 100 100 100] | fwd | 2049910 | 1790710 | 1.144747056 |
| UnsortedSegmentSum | float16 | [100 100 300 100] | fwd | 6334755 | 5348770 | 1.184338642 |
| UnsortedSegmentSum | float16 | [300 100 100 100] | fwd | 6149267 | 5351950 | 1.148976915 |
FWD-FP32
| op_name | dtype | input_size | direction | rocm_kernel_avg | kernel_duration | improvement over rocm |
|---|---|---|---|---|---|---|
| UnsortedSegmentSum | float32 | [50 100] | fwd | 12240 | 6862 | 1.78373652 |
| UnsortedSegmentSum | float32 | [100 50] | fwd | 11120 | 8693 | 1.279190153 |
| UnsortedSegmentSum | float32 | [100 100] | fwd | 11808 | 8621 | 1.369678692 |
| UnsortedSegmentSum | float32 | [100 300] | fwd | 13056 | 8710 | 1.498966705 |
| UnsortedSegmentSum | float32 | [300 100] | fwd | 11744 | 14185 | 0.827916814 |
| UnsortedSegmentSum | float32 | [200 300] | fwd | 13120 | 12657 | 1.036580548 |
| UnsortedSegmentSum | float32 | [205 350] | fwd | 13488 | 12852 | 1.049486461 |
| UnsortedSegmentSum | float32 | [350 105] | fwd | 12976 | 17225 | 0.753323657 |
| UnsortedSegmentSum | float32 | [405 200] | fwd | 12960 | 18754 | 0.691052575 |
| UnsortedSegmentSum | float32 | [10 10 10] | fwd | 21136 | 5742 | 3.680947405 |
| UnsortedSegmentSum | float32 | [10 10 30] | fwd | 21072 | 5884 | 3.581237254 |
| UnsortedSegmentSum | float32 | [10 30 10] | fwd | 20880 | 5919 | 3.527622909 |
| UnsortedSegmentSum | float32 | [30 10 10] | fwd | 11920 | 6399 | 1.862791061 |
| UnsortedSegmentSum | float32 | [30 30 30] | fwd | 13664 | 6861 | 1.991546422 |
| UnsortedSegmentSum | float32 | [50 100 50] | fwd | 15488 | 9546 | 1.622459669 |
| UnsortedSegmentSum | float32 | [100 50 100] | fwd | 11872 | 13226 | 0.897625888 |
| UnsortedSegmentSum | float32 | [100 100 100] | fwd | 16752 | 17492 | 0.957694946 |
| UnsortedSegmentSum | float32 | [100 100 300] | fwd | 36272 | 36211 | 1.001684571 |
| UnsortedSegmentSum | float32 | [300 100 100] | fwd | 36080 | 38682 | 0.932733571 |
| UnsortedSegmentSum | float32 | [10 10 10 10] | fwd | 20976 | 5937 | 3.533097524 |
| UnsortedSegmentSum | float32 | [10 10 10 30] | fwd | 23056 | 5884 | 3.918422842 |
| UnsortedSegmentSum | float32 | [30 10 10 10] | fwd | 13360 | 7004 | 1.907481439 |
| UnsortedSegmentSum | float32 | [30 30 30 30] | fwd | 14576 | 14861 | 0.980822287 |
| UnsortedSegmentSum | float32 | [50 100 50 100] | fwd | 350702 | 298282 | 1.175739736 |
| UnsortedSegmentSum | float32 | [100 50 100 50] | fwd | 311775 | 300735 | 1.03671006 |
| UnsortedSegmentSum | float32 | [100 100 100 100] | fwd | 1133147 | 1183410 | 0.957526977 |
| UnsortedSegmentSum | float32 | [100 100 300 100] | fwd | 3500016 | 3563430 | 0.982204225 |
| UnsortedSegmentSum | float32 | [300 100 100 100] | fwd | 3424240 | 3551890 | 0.964061387 |
FWD-BFP16
| op_name | dtype | input_size | direction | rocm_kernel_avg | kernel_duration | improvement over rocm |
|---|---|---|---|---|---|---|
| UnsortedSegmentSum | bfloat16 | [50 100] | fwd | 16800 | 8835 | 1.901528014 |
| UnsortedSegmentSum | bfloat16 | [100 50] | fwd | 23039 | 14346 | 1.605952879 |
| UnsortedSegmentSum | bfloat16 | [100 100] | fwd | 20640 | 13208 | 1.562689279 |
| UnsortedSegmentSum | bfloat16 | [100 300] | fwd | 21120 | 13813 | 1.528994426 |
| UnsortedSegmentSum | bfloat16 | [300 100] | fwd | 42559 | 34967 | 1.217118998 |
| UnsortedSegmentSum | bfloat16 | [200 300] | fwd | 36959 | 29670 | 1.245669026 |
| UnsortedSegmentSum | bfloat16 | [205 350] | fwd | 44160 | 28870 | 1.529615518 |
| UnsortedSegmentSum | bfloat16 | [350 105] | fwd | 55839 | 46345 | 1.204854893 |
| UnsortedSegmentSum | bfloat16 | [405 200] | fwd | 70400 | 55749 | 1.26280292 |
| UnsortedSegmentSum | bfloat16 | [10 10 10] | fwd | 25920 | 6221 | 4.166532712 |
| UnsortedSegmentSum | bfloat16 | [10 10 30] | fwd | 27360 | 6257 | 4.372702573 |
| UnsortedSegmentSum | bfloat16 | [10 30 10] | fwd | 29760 | 6275 | 4.742629482 |
| UnsortedSegmentSum | bfloat16 | [30 10 10] | fwd | 12640 | 7928 | 1.594349142 |
| UnsortedSegmentSum | bfloat16 | [30 30 30] | fwd | 15680 | 9031 | 1.736241834 |
| UnsortedSegmentSum | bfloat16 | [50 100 50] | fwd | 23200 | 19110 | 1.214024071 |
| UnsortedSegmentSum | bfloat16 | [100 50 100] | fwd | 34879 | 30079 | 1.159579773 |
| UnsortedSegmentSum | bfloat16 | [100 100 100] | fwd | 45600 | 33972 | 1.342281879 |
| UnsortedSegmentSum | bfloat16 | [100 100 300] | fwd | 83359 | 68193 | 1.222398193 |
| UnsortedSegmentSum | bfloat16 | [300 100 100] | fwd | 109439 | 88513 | 1.236417249 |
| UnsortedSegmentSum | bfloat16 | [10 10 10 10] | fwd | 27200 | 7146 | 3.806325217 |
| UnsortedSegmentSum | bfloat16 | [10 10 10 30] | fwd | 27840 | 6648 | 4.187725632 |
| UnsortedSegmentSum | bfloat16 | [30 10 10 10] | fwd | 14080 | 9137 | 1.540987195 |
| UnsortedSegmentSum | bfloat16 | [30 30 30 30] | fwd | 30560 | 22630 | 1.350419797 |
| UnsortedSegmentSum | bfloat16 | [50 100 50 100] | fwd | 527355 | 447774 | 1.177725817 |
| UnsortedSegmentSum | bfloat16 | [100 50 100 50] | fwd | 517436 | 446103 | 1.159902534 |
| UnsortedSegmentSum | bfloat16 | [100 100 100 100] | fwd | 2161101 | 1775240 | 1.21735709 |
| UnsortedSegmentSum | bfloat16 | [100 100 300 100] | fwd | 6671944 | 5302680 | 1.258221126 |
| UnsortedSegmentSum | bfloat16 | [300 100 100 100] | fwd | 6511306 | 5300670 | 1.228393014 |
BWD-FP16
| op_name | dtype | input_size | direction | rocm_kernel_avg | kernel_duration | improvement over rocm |
|---|---|---|---|---|---|---|
| UnsortedSegmentSum | float16 | [50 100] | bwd | 5792 | 4088 | 1.416829746 |
| UnsortedSegmentSum | float16 | [100 50] | bwd | 5440 | 3786 | 1.436872689 |
| UnsortedSegmentSum | float16 | [100 100] | bwd | 5744 | 4088 | 1.405088063 |
| UnsortedSegmentSum | float16 | [100 300] | bwd | 6704 | 4639 | 1.445139039 |
| UnsortedSegmentSum | float16 | [300 100] | bwd | 5760 | 4035 | 1.427509294 |
| UnsortedSegmentSum | float16 | [200 300] | bwd | 6752 | 4533 | 1.489521288 |
| UnsortedSegmentSum | float16 | [205 350] | bwd | 6832 | 4817 | 1.418310152 |
| UnsortedSegmentSum | float16 | [350 105] | bwd | 5824 | 4248 | 1.370998117 |
| UnsortedSegmentSum | float16 | [405 200] | bwd | 6224 | 4746 | 1.311420143 |
| UnsortedSegmentSum | float16 | [10 10 10] | bwd | 10768 | 4070 | 2.645700246 |
| UnsortedSegmentSum | float16 | [10 10 30] | bwd | 11344 | 4408 | 2.573502722 |
| UnsortedSegmentSum | float16 | [10 30 10] | bwd | 11152 | 4372 | 2.550777676 |
| UnsortedSegmentSum | float16 | [30 10 10] | bwd | 5824 | 4035 | 1.443370508 |
| UnsortedSegmentSum | float16 | [30 30 30] | bwd | 7472 | 5421 | 1.378343479 |
| UnsortedSegmentSum | float16 | [50 100 50] | bwd | 9216 | 6915 | 1.332754881 |
| UnsortedSegmentSum | float16 | [100 50 100] | bwd | 11408 | 8265 | 1.380278282 |
| UnsortedSegmentSum | float16 | [100 100 100] | bwd | 16544 | 11555 | 1.431761142 |
| UnsortedSegmentSum | float16 | [100 100 300] | bwd | 37040 | 25919 | 1.429067479 |
| UnsortedSegmentSum | float16 | [300 100 100] | bwd | 37008 | 25350 | 1.459881657 |
| UnsortedSegmentSum | float16 | [10 10 10 10] | bwd | 11440 | 5048 | 2.266244057 |
| UnsortedSegmentSum | float16 | [10 10 10 30] | bwd | 11360 | 5475 | 2.074885845 |
| UnsortedSegmentSum | float16 | [30 10 10 10] | bwd | 7568 | 5404 | 1.400444115 |
| UnsortedSegmentSum | float16 | [30 30 30 30] | bwd | 15456 | 11003 | 1.404707807 |
| UnsortedSegmentSum | float16 | [50 100 50 100] | bwd | 293759 | 219671 | 1.337268005 |
| UnsortedSegmentSum | float16 | [100 50 100 50] | bwd | 289439 | 211068 | 1.371306877 |
| UnsortedSegmentSum | float16 | [100 100 100 100] | bwd | 1215786 | 906082 | 1.341805709 |
| UnsortedSegmentSum | float16 | [100 100 300 100] | bwd | 3717855 | 2829140 | 1.314129029 |
| UnsortedSegmentSum | float16 | [300 100 100 100] | bwd | 3652063 | 2697700 | 1.353769137 |
BWD-FP32
| op_name | dtype | input_size | direction | rocm_kernel_avg | kernel_duration | improvement over rocm |
|---|---|---|---|---|---|---|
| UnsortedSegmentSum | float32 | [50 100] | bwd | 6176 | 4444 | 1.389738974 |
| UnsortedSegmentSum | float32 | [100 50] | bwd | 5840 | 4035 | 1.447335812 |
| UnsortedSegmentSum | float32 | [100 100] | bwd | 6384 | 4373 | 1.459867368 |
| UnsortedSegmentSum | float32 | [100 300] | bwd | 7248 | 5333 | 1.359084943 |
| UnsortedSegmentSum | float32 | [300 100] | bwd | 6144 | 4622 | 1.329294678 |
| UnsortedSegmentSum | float32 | [200 300] | bwd | 7152 | 5137 | 1.392252287 |
| UnsortedSegmentSum | float32 | [205 350] | bwd | 7184 | 5102 | 1.408075265 |
| UnsortedSegmentSum | float32 | [350 105] | bwd | 6080 | 4159 | 1.461889877 |
| UnsortedSegmentSum | float32 | [405 200] | bwd | 6816 | 4995 | 1.364564565 |
| UnsortedSegmentSum | float32 | [10 10 10] | bwd | 11664 | 4479 | 2.604152713 |
| UnsortedSegmentSum | float32 | [10 10 30] | bwd | 11776 | 4906 | 2.400326131 |
| UnsortedSegmentSum | float32 | [10 30 10] | bwd | 11360 | 4959 | 2.290784432 |
| UnsortedSegmentSum | float32 | [30 10 10] | bwd | 6240 | 4462 | 1.39847602 |
| UnsortedSegmentSum | float32 | [30 30 30] | bwd | 7536 | 5528 | 1.363241679 |
| UnsortedSegmentSum | float32 | [50 100 50] | bwd | 9296 | 7022 | 1.323839362 |
| UnsortedSegmentSum | float32 | [100 50 100] | bwd | 11760 | 8533 | 1.378178835 |
| UnsortedSegmentSum | float32 | [100 100 100] | bwd | 16880 | 11910 | 1.41729639 |
| UnsortedSegmentSum | float32 | [100 100 300] | bwd | 38448 | 26985 | 1.424791551 |
| UnsortedSegmentSum | float32 | [300 100 100] | bwd | 37280 | 26594 | 1.401819959 |
| UnsortedSegmentSum | float32 | [10 10 10 10] | bwd | 11408 | 5333 | 2.139133696 |
| UnsortedSegmentSum | float32 | [10 10 10 30] | bwd | 11840 | 5599 | 2.114663333 |
| UnsortedSegmentSum | float32 | [30 10 10 10] | bwd | 7456 | 5511 | 1.352930503 |
| UnsortedSegmentSum | float32 | [30 30 30 30] | bwd | 15936 | 11235 | 1.418424566 |
| UnsortedSegmentSum | float32 | [50 100 50 100] | bwd | 357614 | 249270 | 1.434645164 |
| UnsortedSegmentSum | float32 | [100 50 100 50] | bwd | 338830 | 248364 | 1.364247637 |
| UnsortedSegmentSum | float32 | [100 100 100 100] | bwd | 1169179 | 1026289 | 1.139229788 |
| UnsortedSegmentSum | float32 | [100 100 300 100] | bwd | 3556864 | 3080350 | 1.154694759 |
| UnsortedSegmentSum | float32 | [300 100 100 100] | bwd | 3518207 | 3016550 | 1.16630157 |
BWD-BFP16
| op_name | dtype | input_size | direction | rocm_kernel_avg | kernel_duration | improvement over rocm |
|---|---|---|---|---|---|---|
| UnsortedSegmentSum | bfloat16 | [50 100] | bwd | 6400 | 3928 | 1.629327902 |
| UnsortedSegmentSum | bfloat16 | [100 50] | bwd | 5920 | 3733 | 1.58585588 |
| UnsortedSegmentSum | bfloat16 | [100 100] | bwd | 12479 | 3928 | 3.176934827 |
| UnsortedSegmentSum | bfloat16 | [100 300] | bwd | 13120 | 4711 | 2.784971344 |
| UnsortedSegmentSum | bfloat16 | [300 100] | bwd | 12640 | 4088 | 3.091976517 |
| UnsortedSegmentSum | bfloat16 | [200 300] | bwd | 7200 | 4586 | 1.569995639 |
| UnsortedSegmentSum | bfloat16 | [205 350] | bwd | 7840 | 4853 | 1.61549557 |
| UnsortedSegmentSum | bfloat16 | [350 105] | bwd | 6720 | 4195 | 1.601907032 |
| UnsortedSegmentSum | bfloat16 | [405 200] | bwd | 7039 | 4782 | 1.471978252 |
| UnsortedSegmentSum | bfloat16 | [10 10 10] | bwd | 12479 | 4177 | 2.987550874 |
| UnsortedSegmentSum | bfloat16 | [10 10 30] | bwd | 16640 | 4462 | 3.729269386 |
| UnsortedSegmentSum | bfloat16 | [10 30 10] | bwd | 17440 | 4479 | 3.893726278 |
| UnsortedSegmentSum | bfloat16 | [30 10 10] | bwd | 6560 | 4017 | 1.633059497 |
| UnsortedSegmentSum | bfloat16 | [30 30 30] | bwd | 12960 | 5351 | 2.421977201 |
| UnsortedSegmentSum | bfloat16 | [50 100 50] | bwd | 9440 | 6897 | 1.368711034 |
| UnsortedSegmentSum | bfloat16 | [100 50 100] | bwd | 11680 | 8177 | 1.428396723 |
| UnsortedSegmentSum | bfloat16 | [100 100 100] | bwd | 16800 | 11555 | 1.453916054 |
| UnsortedSegmentSum | bfloat16 | [100 100 300] | bwd | 36640 | 26168 | 1.40018343 |
| UnsortedSegmentSum | bfloat16 | [300 100 100] | bwd | 35840 | 25314 | 1.415817334 |
| UnsortedSegmentSum | bfloat16 | [10 10 10 10] | bwd | 10400 | 5119 | 2.031646806 |
| UnsortedSegmentSum | bfloat16 | [10 10 10 30] | bwd | 10239 | 5351 | 1.913474117 |
| UnsortedSegmentSum | bfloat16 | [30 10 10 10] | bwd | 7520 | 5315 | 1.414863594 |
| UnsortedSegmentSum | bfloat16 | [30 30 30 30] | bwd | 15840 | 10915 | 1.451213926 |
| UnsortedSegmentSum | bfloat16 | [50 100 50 100] | bwd | 290077 | 219976 | 1.318675674 |
| UnsortedSegmentSum | bfloat16 | [100 50 100 50] | bwd | 285437 | 211265 | 1.35108513 |
| UnsortedSegmentSum | bfloat16 | [100 100 100 100] | bwd | 1193110 | 906268 | 1.316509024 |
| UnsortedSegmentSum | bfloat16 | [100 100 300 100] | bwd | 3684289 | 2827510 | 1.303015374 |
| UnsortedSegmentSum | bfloat16 | [300 100 100 100] | bwd | 3590690 | 2695390 | 1.332159725 |