Implement SigmoidFocalLoss operation
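This PR implements the torchvision.ops.sigmoid_focal_loss operation. There are no constraints here: MIOpen is faster than ROCm in almost all benchmarked cases; the only exceptions are a few non-contiguous [25 1000 1000] float16 and bfloat16 cases (see the detailed benchmark below).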
- [x] Added SigmoidFocalLoss operation with forward and backward kernels.
- [x] Added driver test and gtest for 4 kernels.
- [x] Compared with ROCm.
Average improvement over ROCm
Reduced kernels:
| type | fwd | bwd |
|---|---|---|
| float32 | 2.74 | 4.94 |
| float16 | 2.41 | 4.40 |
| bfloat16 | 2.58 | 4.70 |
Unreduced kernels:
| type | fwd | bwd |
|---|---|---|
| float32 | 5.09 | 3.86 |
| float16 | 4.83 | 3.46 |
| bfloat16 | 5.10 | 3.92 |
Detailed benchmark
Float32
| dtype | size | is_contiguous | reduction | direction | ROCm | MIOpen | improvement |
|---|---|---|---|---|---|---|---|
| float32 | [20 30] | TRUE | none | fwd | 81551 | 13582 | 6.004343985 |
| float32 | [20 30] | TRUE | none | bwd | 77295 | 18933 | 4.08255427 |
| float32 | [20 30] | TRUE | sum | fwd | 91327 | 34649 | 2.635775924 |
| float32 | [20 30] | TRUE | sum | bwd | 77695 | 17742 | 4.379156803 |
| float32 | [20 30] | TRUE | mean | fwd | 90847 | 35912 | 2.529711517 |
| float32 | [20 30] | TRUE | mean | bwd | 80416 | 19164 | 4.196201211 |
| float32 | [20 30] | FALSE | none | fwd | 75296 | 13211 | 5.699492847 |
| float32 | [20 30] | FALSE | none | bwd | 75040 | 18883 | 3.973944818 |
| float32 | [20 30] | FALSE | sum | fwd | 88623 | 35545 | 2.493262062 |
| float32 | [20 30] | FALSE | sum | bwd | 75359 | 18208 | 4.138785149 |
| float32 | [20 30] | FALSE | mean | fwd | 87808 | 35313 | 2.486563022 |
| float32 | [20 30] | FALSE | mean | bwd | 135519 | 19630 | 6.903667855 |
| float32 | [5 10 10] | TRUE | none | fwd | 76575 | 12784 | 5.989909262 |
| float32 | [5 10 10] | TRUE | none | bwd | 77695 | 18225 | 4.263100137 |
| float32 | [5 10 10] | TRUE | sum | fwd | 83167 | 35562 | 2.338647995 |
| float32 | [5 10 10] | TRUE | sum | bwd | 77231 | 17141 | 4.505629777 |
| float32 | [5 10 10] | TRUE | mean | fwd | 88910 | 37323 | 2.38217721 |
| float32 | [5 10 10] | TRUE | mean | bwd | 83183 | 19292 | 4.311787269 |
| float32 | [5 10 10] | FALSE | none | fwd | 80991 | 13585 | 5.961796099 |
| float32 | [5 10 10] | FALSE | none | bwd | 90192 | 18723 | 4.817176734 |
| float32 | [5 10 10] | FALSE | sum | fwd | 95487 | 37251 | 2.563340581 |
| float32 | [5 10 10] | FALSE | sum | bwd | 81727 | 17319 | 4.718921416 |
| float32 | [5 10 10] | FALSE | mean | fwd | 85215 | 36825 | 2.314052953 |
| float32 | [5 10 10] | FALSE | mean | bwd | 144479 | 17229 | 8.385803007 |
| float32 | [2 5 10 10] | TRUE | none | fwd | 87183 | 13549 | 6.434644623 |
| float32 | [2 5 10 10] | TRUE | none | bwd | 80527 | 19488 | 4.132132594 |
| float32 | [2 5 10 10] | TRUE | sum | fwd | 87247 | 34940 | 2.497052089 |
| float32 | [2 5 10 10] | TRUE | sum | bwd | 75136 | 19790 | 3.796664982 |
| float32 | [2 5 10 10] | TRUE | mean | fwd | 84383 | 37340 | 2.259855383 |
| float32 | [2 5 10 10] | TRUE | mean | bwd | 75648 | 20217 | 3.741801454 |
| float32 | [2 5 10 10] | FALSE | none | fwd | 84047 | 14011 | 5.998643923 |
| float32 | [2 5 10 10] | FALSE | none | bwd | 79391 | 20003 | 3.968954657 |
| float32 | [2 5 10 10] | FALSE | sum | fwd | 94831 | 40594 | 2.33608415 |
| float32 | [2 5 10 10] | FALSE | sum | bwd | 79824 | 20412 | 3.9106408 |
| float32 | [2 5 10 10] | FALSE | mean | fwd | 90783 | 34122 | 2.660541586 |
| float32 | [2 5 10 10] | FALSE | mean | bwd | 150190 | 20430 | 7.351443955 |
| float32 | [25 300] | TRUE | none | fwd | 79807 | 9566 | 8.3427765 |
| float32 | [25 300] | TRUE | none | bwd | 79967 | 12215 | 6.546623005 |
| float32 | [25 300] | TRUE | sum | fwd | 86399 | 35882 | 2.407864668 |
| float32 | [25 300] | TRUE | sum | bwd | 80415 | 12535 | 6.415237335 |
| float32 | [25 300] | TRUE | mean | fwd | 88624 | 34495 | 2.56918394 |
| float32 | [25 300] | TRUE | mean | bwd | 81519 | 13246 | 6.154235241 |
| float32 | [25 300] | FALSE | none | fwd | 79567 | 9637 | 8.256407596 |
| float32 | [25 300] | FALSE | none | bwd | 79184 | 12286 | 6.445059417 |
| float32 | [25 300] | FALSE | sum | fwd | 85311 | 32699 | 2.608978868 |
| float32 | [25 300] | FALSE | sum | bwd | 79583 | 12962 | 6.139716093 |
| float32 | [25 300] | FALSE | mean | fwd | 92607 | 31614 | 2.929303473 |
| float32 | [25 300] | FALSE | mean | bwd | 136607 | 13229 | 10.32632852 |
| float32 | [25 100 100] | TRUE | none | fwd | 106431 | 16411 | 6.485345195 |
| float32 | [25 100 100] | TRUE | none | bwd | 122607 | 24288 | 5.048048419 |
| float32 | [25 100 100] | TRUE | sum | fwd | 121695 | 41287 | 2.947537966 |
| float32 | [25 100 100] | TRUE | sum | bwd | 119615 | 24235 | 4.935630287 |
| float32 | [25 100 100] | TRUE | mean | fwd | 122863 | 43492 | 2.824956314 |
| float32 | [25 100 100] | TRUE | mean | bwd | 123439 | 24484 | 5.041619017 |
| float32 | [25 100 100] | FALSE | none | fwd | 106975 | 23168 | 4.617360152 |
| float32 | [25 100 100] | FALSE | none | bwd | 120671 | 28591 | 4.220593893 |
| float32 | [25 100 100] | FALSE | sum | fwd | 123199 | 46870 | 2.628525709 |
| float32 | [25 100 100] | FALSE | sum | bwd | 118879 | 27951 | 4.253121534 |
| float32 | [25 100 100] | FALSE | mean | fwd | 126463 | 48292 | 2.618715315 |
| float32 | [25 100 100] | FALSE | mean | bwd | 231566 | 28218 | 8.206322206 |
| float32 | [100 20 20 20] | TRUE | none | fwd | 200398 | 41411 | 4.839245611 |
| float32 | [100 20 20 20] | TRUE | none | bwd | 241470 | 65771 | 3.671374922 |
| float32 | [100 20 20 20] | TRUE | sum | fwd | 214286 | 69522 | 3.082276114 |
| float32 | [100 20 20 20] | TRUE | sum | bwd | 238925 | 65931 | 3.623864343 |
| float32 | [100 20 20 20] | TRUE | mean | fwd | 214414 | 69736 | 3.074652977 |
| float32 | [100 20 20 20] | TRUE | mean | bwd | 247566 | 65700 | 3.768127854 |
| float32 | [100 20 20 20] | FALSE | none | fwd | 200734 | 51990 | 3.861011733 |
| float32 | [100 20 20 20] | FALSE | none | bwd | 240942 | 69647 | 3.459474206 |
| float32 | [100 20 20 20] | FALSE | sum | fwd | 216590 | 83871 | 2.582418238 |
| float32 | [100 20 20 20] | FALSE | sum | bwd | 241663 | 69238 | 3.490323233 |
| float32 | [100 20 20 20] | FALSE | mean | fwd | 217742 | 81684 | 2.665662798 |
| float32 | [100 20 20 20] | FALSE | mean | bwd | 549323 | 69131 | 7.946116793 |
| float32 | [100 10 10 10 10] | TRUE | none | fwd | 249582 | 50337 | 4.958221587 |
| float32 | [100 10 10 10 10] | TRUE | none | bwd | 291789 | 80564 | 3.621828608 |
| float32 | [100 10 10 10 10] | TRUE | sum | fwd | 265982 | 78181 | 3.402130953 |
| float32 | [100 10 10 10 10] | TRUE | sum | bwd | 286189 | 79888 | 3.582377829 |
| float32 | [100 10 10 10 10] | TRUE | mean | fwd | 265854 | 79052 | 3.363026868 |
| float32 | [100 10 10 10 10] | TRUE | mean | bwd | 293070 | 80101 | 3.658755821 |
| float32 | [100 10 10 10 10] | FALSE | none | fwd | 248734 | 62356 | 3.988934505 |
| float32 | [100 10 10 10 10] | FALSE | none | bwd | 290846 | 84724 | 3.43286436 |
| float32 | [100 10 10 10 10] | FALSE | sum | fwd | 263806 | 96121 | 2.744519928 |
| float32 | [100 10 10 10 10] | FALSE | sum | bwd | 283677 | 83924 | 3.380165388 |
| float32 | [100 10 10 10 10] | FALSE | mean | fwd | 267854 | 94539 | 2.833264579 |
| float32 | [100 10 10 10 10] | FALSE | mean | bwd | 636091 | 83977 | 7.574585899 |
| float32 | [2000 3000] | TRUE | none | fwd | 1285317 | 278904 | 4.608456673 |
| float32 | [2000 3000] | TRUE | none | bwd | 1456660 | 457011 | 3.187363105 |
| float32 | [2000 3000] | TRUE | sum | fwd | 1331780 | 349973 | 3.805379272 |
| float32 | [2000 3000] | TRUE | sum | bwd | 1433027 | 452334 | 3.168072707 |
| float32 | [2000 3000] | TRUE | mean | fwd | 1334900 | 347252 | 3.844182323 |
| float32 | [2000 3000] | TRUE | mean | bwd | 1487507 | 451605 | 3.293823142 |
| float32 | [2000 3000] | FALSE | none | fwd | 1304789 | 328934 | 3.966719767 |
| float32 | [2000 3000] | FALSE | none | bwd | 1452675 | 469327 | 3.095229978 |
| float32 | [2000 3000] | FALSE | sum | fwd | 1336884 | 401940 | 3.326078519 |
| float32 | [2000 3000] | FALSE | sum | bwd | 1421475 | 463442 | 3.067212294 |
| float32 | [2000 3000] | FALSE | mean | fwd | 1338228 | 399895 | 3.346448443 |
| float32 | [2000 3000] | FALSE | mean | bwd | 3341635 | 462854 | 7.219630812 |
| float32 | [25 1000 1000] | TRUE | none | fwd | 5138451 | 1145640 | 4.485223107 |
| float32 | [25 1000 1000] | TRUE | none | bwd | 5819901 | 1889040 | 3.080877589 |
| float32 | [25 1000 1000] | TRUE | sum | fwd | 5243106 | 1367510 | 3.834053133 |
| float32 | [25 1000 1000] | TRUE | sum | bwd | 5678942 | 4201130 | 1.351765358 |
| float32 | [25 1000 1000] | TRUE | mean | fwd | 5244754 | 1366410 | 3.838345738 |
| float32 | [25 1000 1000] | TRUE | mean | bwd | 5918332 | 1862170 | 3.178191035 |
| float32 | [25 1000 1000] | FALSE | none | fwd | 5152050 | 4529510 | 1.137440915 |
| float32 | [25 1000 1000] | FALSE | none | bwd | 5811853 | 4624380 | 1.256785342 |
| float32 | [25 1000 1000] | FALSE | sum | fwd | 5250178 | 4761350 | 1.102665841 |
| float32 | [25 1000 1000] | FALSE | sum | bwd | 5668574 | 4549310 | 1.246029398 |
| float32 | [25 1000 1000] | FALSE | mean | fwd | 5241474 | 4763050 | 1.100444883 |
| float32 | [25 1000 1000] | FALSE | mean | bwd | 35102137 | 4549250 | 7.716027257 |
| float32 | [10 100 100 100] | TRUE | none | fwd | 2106141 | 460996 | 4.568675216 |
| float32 | [10 100 100 100] | TRUE | none | bwd | 2372939 | 758370 | 3.128999037 |
| float32 | [10 100 100 100] | TRUE | sum | fwd | 2164653 | 561128 | 3.857681313 |
| float32 | [10 100 100 100] | TRUE | sum | bwd | 2299532 | 747898 | 3.074659914 |
| float32 | [10 100 100 100] | TRUE | mean | fwd | 2161885 | 562141 | 3.84580559 |
| float32 | [10 100 100 100] | TRUE | mean | bwd | 2419099 | 748982 | 3.229849315 |
| float32 | [10 100 100 100] | FALSE | none | fwd | 2117037 | 1324770 | 1.598041169 |
| float32 | [10 100 100 100] | FALSE | none | bwd | 2369579 | 1379190 | 1.718094679 |
| float32 | [10 100 100 100] | FALSE | sum | fwd | 2168365 | 1431280 | 1.514983092 |
| float32 | [10 100 100 100] | FALSE | sum | bwd | 2303644 | 1356110 | 1.698714706 |
| float32 | [10 100 100 100] | FALSE | mean | fwd | 2168701 | 1433990 | 1.51235434 |
| float32 | [10 100 100 100] | FALSE | mean | bwd | 14042708 | 1356220 | 10.35429945 |
Float16
| dtype | size | is_contiguous | reduction | direction | ROCm | MIOpen | improvement |
|---|---|---|---|---|---|---|---|
| float16 | [20 30] | TRUE | none | fwd | 90351 | 12266 | 7.365970977 |
| float16 | [20 30] | TRUE | none | bwd | 85887 | 19147 | 4.48566355 |
| float16 | [20 30] | TRUE | sum | fwd | 96192 | 36338 | 2.647146238 |
| float16 | [20 30] | TRUE | sum | bwd | 83983 | 20498 | 4.097131427 |
| float16 | [20 30] | TRUE | mean | fwd | 95183 | 37405 | 2.544659805 |
| float16 | [20 30] | TRUE | mean | bwd | 83263 | 19929 | 4.177981836 |
| float16 | [20 30] | FALSE | none | fwd | 89791 | 12802 | 7.013825965 |
| float16 | [20 30] | FALSE | none | bwd | 85296 | 21017 | 4.058428891 |
| float16 | [20 30] | FALSE | sum | fwd | 98367 | 33802 | 2.910094077 |
| float16 | [20 30] | FALSE | sum | bwd | 80767 | 20217 | 3.995004204 |
| float16 | [20 30] | FALSE | mean | fwd | 96735 | 34247 | 2.824626975 |
| float16 | [20 30] | FALSE | mean | bwd | 135951 | 20110 | 6.760367976 |
| float16 | [5 10 10] | TRUE | none | fwd | 84943 | 14634 | 5.804496378 |
| float16 | [5 10 10] | TRUE | none | bwd | 85359 | 18492 | 4.615996106 |
| float16 | [5 10 10] | TRUE | sum | fwd | 93119 | 37785 | 2.464443562 |
| float16 | [5 10 10] | TRUE | sum | bwd | 80015 | 20021 | 3.996553619 |
| float16 | [5 10 10] | TRUE | mean | fwd | 91296 | 36327 | 2.513172021 |
| float16 | [5 10 10] | TRUE | mean | bwd | 81935 | 20164 | 4.063429875 |
| float16 | [5 10 10] | FALSE | none | fwd | 87808 | 13567 | 6.472175131 |
| float16 | [5 10 10] | FALSE | none | bwd | 83792 | 18581 | 4.509552769 |
| float16 | [5 10 10] | FALSE | sum | fwd | 95183 | 41946 | 2.269179421 |
| float16 | [5 10 10] | FALSE | sum | bwd | 80559 | 19097 | 4.218411269 |
| float16 | [5 10 10] | FALSE | mean | fwd | 94063 | 37020 | 2.5408698 |
| float16 | [5 10 10] | FALSE | mean | bwd | 143743 | 20181 | 7.122689659 |
| float16 | [2 5 10 10] | TRUE | none | fwd | 96864 | 14242 | 6.801291953 |
| float16 | [2 5 10 10] | TRUE | none | bwd | 88655 | 19594 | 4.524599367 |
| float16 | [2 5 10 10] | TRUE | sum | fwd | 106239 | 34104 | 3.115147783 |
| float16 | [2 5 10 10] | TRUE | sum | bwd | 89391 | 20893 | 4.278514335 |
| float16 | [2 5 10 10] | TRUE | mean | fwd | 105759 | 34370 | 3.077073029 |
| float16 | [2 5 10 10] | TRUE | mean | bwd | 91807 | 17923 | 5.122300954 |
| float16 | [2 5 10 10] | FALSE | none | fwd | 95023 | 14100 | 6.739219858 |
| float16 | [2 5 10 10] | FALSE | none | bwd | 91327 | 19790 | 4.614805457 |
| float16 | [2 5 10 10] | FALSE | sum | fwd | 102991 | 35046 | 2.938737659 |
| float16 | [2 5 10 10] | FALSE | sum | bwd | 90591 | 19772 | 4.581782318 |
| float16 | [2 5 10 10] | FALSE | mean | fwd | 103663 | 34672 | 2.989818874 |
| float16 | [2 5 10 10] | FALSE | mean | bwd | 146798 | 20430 | 7.185413607 |
| float16 | [25 300] | TRUE | none | fwd | 90751 | 9370 | 9.685272145 |
| float16 | [25 300] | TRUE | none | bwd | 91759 | 11735 | 7.819258628 |
| float16 | [25 300] | TRUE | sum | fwd | 100719 | 31934 | 3.153973821 |
| float16 | [25 300] | TRUE | sum | bwd | 90767 | 12784 | 7.100046934 |
| float16 | [25 300] | TRUE | mean | fwd | 97984 | 34121 | 2.871662612 |
| float16 | [25 300] | TRUE | mean | bwd | 93279 | 12766 | 7.306830644 |
| float16 | [25 300] | FALSE | none | fwd | 90191 | 9566 | 9.428287686 |
| float16 | [25 300] | FALSE | none | bwd | 91775 | 12429 | 7.383940784 |
| float16 | [25 300] | FALSE | sum | fwd | 100335 | 34299 | 2.925303945 |
| float16 | [25 300] | FALSE | sum | bwd | 89935 | 12891 | 6.976572803 |
| float16 | [25 300] | FALSE | mean | fwd | 99679 | 32006 | 3.114384803 |
| float16 | [25 300] | FALSE | mean | bwd | 136015 | 12891 | 10.55115972 |
| float16 | [25 100 100] | TRUE | none | fwd | 95295 | 16305 | 5.844526219 |
| float16 | [25 100 100] | TRUE | none | bwd | 105183 | 23933 | 4.394894079 |
| float16 | [25 100 100] | TRUE | sum | fwd | 109311 | 45839 | 2.38467244 |
| float16 | [25 100 100] | TRUE | sum | bwd | 103535 | 24217 | 4.275302473 |
| float16 | [25 100 100] | TRUE | mean | fwd | 112318 | 38975 | 2.881796023 |
| float16 | [25 100 100] | TRUE | mean | bwd | 107391 | 24306 | 4.41829178 |
| float16 | [25 100 100] | FALSE | none | fwd | 93999 | 20323 | 4.625252177 |
| float16 | [25 100 100] | FALSE | none | bwd | 104559 | 28147 | 3.714747575 |
| float16 | [25 100 100] | FALSE | sum | fwd | 111007 | 45554 | 2.436822233 |
| float16 | [25 100 100] | FALSE | sum | bwd | 103519 | 28911 | 3.580609457 |
| float16 | [25 100 100] | FALSE | mean | fwd | 111871 | 44523 | 2.51265638 |
| float16 | [25 100 100] | FALSE | mean | bwd | 213710 | 28787 | 7.423837149 |
| float16 | [100 20 20 20] | TRUE | none | fwd | 149903 | 40967 | 3.659115874 |
| float16 | [100 20 20 20] | TRUE | none | bwd | 163150 | 65735 | 2.481935042 |
| float16 | [100 20 20 20] | TRUE | sum | fwd | 167295 | 69113 | 2.420601045 |
| float16 | [100 20 20 20] | TRUE | sum | bwd | 164862 | 65664 | 2.510690789 |
| float16 | [100 20 20 20] | TRUE | mean | fwd | 160399 | 67264 | 2.384618815 |
| float16 | [100 20 20 20] | TRUE | mean | bwd | 167806 | 65735 | 2.552764889 |
| float16 | [100 20 20 20] | FALSE | none | fwd | 149711 | 50870 | 2.943011598 |
| float16 | [100 20 20 20] | FALSE | none | bwd | 163486 | 69931 | 2.337818707 |
| float16 | [100 20 20 20] | FALSE | sum | fwd | 163727 | 79106 | 2.069716583 |
| float16 | [100 20 20 20] | FALSE | sum | bwd | 167295 | 70055 | 2.388052245 |
| float16 | [100 20 20 20] | FALSE | mean | fwd | 164383 | 79586 | 2.06547634 |
| float16 | [100 20 20 20] | FALSE | mean | bwd | 471164 | 69931 | 6.737555591 |
| float16 | [100 10 10 10 10] | TRUE | none | fwd | 196527 | 50479 | 3.893242735 |
| float16 | [100 10 10 10 10] | TRUE | none | bwd | 211629 | 80492 | 2.629192963 |
| float16 | [100 10 10 10 10] | TRUE | sum | fwd | 205598 | 82413 | 2.494727774 |
| float16 | [100 10 10 10 10] | TRUE | sum | bwd | 209983 | 80155 | 2.619711808 |
| float16 | [100 10 10 10 10] | TRUE | mean | fwd | 209262 | 79159 | 2.643565482 |
| float16 | [100 10 10 10 10] | TRUE | mean | bwd | 215262 | 80048 | 2.689161503 |
| float16 | [100 10 10 10 10] | FALSE | none | fwd | 192430 | 62445 | 3.081591801 |
| float16 | [100 10 10 10 10] | FALSE | none | bwd | 206718 | 85239 | 2.425157498 |
| float16 | [100 10 10 10 10] | FALSE | sum | fwd | 207790 | 93134 | 2.231086392 |
| float16 | [100 10 10 10 10] | FALSE | sum | bwd | 207903 | 85150 | 2.441608925 |
| float16 | [100 10 10 10 10] | FALSE | mean | fwd | 209119 | 98681 | 2.119141476 |
| float16 | [100 10 10 10 10] | FALSE | mean | bwd | 558955 | 85133 | 6.565667837 |
| float16 | [2000 3000] | TRUE | none | fwd | 831897 | 279953 | 2.971559512 |
| float16 | [2000 3000] | TRUE | none | bwd | 863288 | 459143 | 1.880215968 |
| float16 | [2000 3000] | TRUE | sum | fwd | 860440 | 346806 | 2.481041274 |
| float16 | [2000 3000] | TRUE | sum | bwd | 863785 | 453293 | 1.905577629 |
| float16 | [2000 3000] | TRUE | mean | fwd | 863736 | 346254 | 2.494515587 |
| float16 | [2000 3000] | TRUE | mean | bwd | 890952 | 453328 | 1.965358416 |
| float16 | [2000 3000] | FALSE | none | fwd | 831721 | 322728 | 2.577157854 |
| float16 | [2000 3000] | FALSE | none | bwd | 859304 | 471246 | 1.823472242 |
| float16 | [2000 3000] | FALSE | sum | fwd | 862328 | 396551 | 2.174570232 |
| float16 | [2000 3000] | FALSE | sum | bwd | 866328 | 465929 | 1.859356254 |
| float16 | [2000 3000] | FALSE | mean | fwd | 864360 | 392852 | 2.200217894 |
| float16 | [2000 3000] | FALSE | mean | bwd | 2779080 | 466480 | 5.95755445 |
| float16 | [25 1000 1000] | TRUE | none | fwd | 3194948 | 1149220 | 2.780101286 |
| float16 | [25 1000 1000] | TRUE | none | bwd | 3291667 | 1894540 | 1.737449196 |
| float16 | [25 1000 1000] | TRUE | sum | fwd | 3247044 | 1366720 | 2.37579314 |
| float16 | [25 1000 1000] | TRUE | sum | bwd | 3301987 | 1870760 | 1.765051102 |
| float16 | [25 1000 1000] | TRUE | mean | fwd | 3251459 | 1371580 | 2.370593768 |
| float16 | [25 1000 1000] | TRUE | mean | bwd | 3396194 | 1871180 | 1.815001229 |
| float16 | [25 1000 1000] | FALSE | none | fwd | 3189828 | 3500840 | 0.91116075 |
| float16 | [25 1000 1000] | FALSE | none | bwd | 3271347 | 3584900 | 0.912535078 |
| float16 | [25 1000 1000] | FALSE | sum | fwd | 3241283 | 3733260 | 0.868217858 |
| float16 | [25 1000 1000] | FALSE | sum | bwd | 3314339 | 3548660 | 0.93396916 |
| float16 | [25 1000 1000] | FALSE | mean | fwd | 3246595 | 3704060 | 0.876496331 |
| float16 | [25 1000 1000] | FALSE | mean | bwd | 26391238 | 3550170 | 7.433795565 |
| float16 | [10 100 100 100] | TRUE | none | fwd | 1324852 | 462684 | 2.863405694 |
| float16 | [10 100 100 100] | TRUE | none | bwd | 1353908 | 760004 | 1.781448519 |
| float16 | [10 100 100 100] | TRUE | sum | fwd | 1360420 | 562407 | 2.418924373 |
| float16 | [10 100 100 100] | TRUE | sum | bwd | 1376628 | 751470 | 1.83191345 |
| float16 | [10 100 100 100] | TRUE | mean | fwd | 1359620 | 561357 | 2.422023775 |
| float16 | [10 100 100 100] | TRUE | mean | bwd | 1423780 | 750829 | 1.896277315 |
| float16 | [10 100 100 100] | FALSE | none | fwd | 1327300 | 1130510 | 1.174071879 |
| float16 | [10 100 100 100] | FALSE | none | bwd | 1375716 | 1240250 | 1.109224753 |
| float16 | [10 100 100 100] | FALSE | sum | fwd | 1363524 | 1235410 | 1.103701605 |
| float16 | [10 100 100 100] | FALSE | sum | bwd | 1381716 | 1209700 | 1.142197239 |
| float16 | [10 100 100 100] | FALSE | mean | fwd | 1370484 | 1234720 | 1.109955294 |
| float16 | [10 100 100 100] | FALSE | mean | bwd | 9185998 | 1210570 | 7.588159297 |
BFloat16
| dtype | size | is_contiguous | reduction | direction | ROCm | MIOpen | improvement |
|---|---|---|---|---|---|---|---|
| bfloat16 | [20 30] | TRUE | none | fwd | 96207 | 14631 | 6.575558745 |
| bfloat16 | [20 30] | TRUE | none | bwd | 97151 | 19395 | 5.009074504 |
| bfloat16 | [20 30] | TRUE | sum | fwd | 102928 | 35538 | 2.896280038 |
| bfloat16 | [20 30] | TRUE | sum | bwd | 91967 | 19573 | 4.69866653 |
| bfloat16 | [20 30] | TRUE | mean | fwd | 99871 | 34827 | 2.867631435 |
| bfloat16 | [20 30] | TRUE | mean | bwd | 93791 | 20000 | 4.68955 |
| bfloat16 | [20 30] | FALSE | none | fwd | 96815 | 14065 | 6.883398507 |
| bfloat16 | [20 30] | FALSE | none | bwd | 97199 | 20555 | 4.728727803 |
| bfloat16 | [20 30] | FALSE | sum | fwd | 101455 | 35082 | 2.891938886 |
| bfloat16 | [20 30] | FALSE | sum | bwd | 89583 | 20413 | 4.388526919 |
| bfloat16 | [20 30] | FALSE | mean | fwd | 101567 | 34638 | 2.932242046 |
| bfloat16 | [20 30] | FALSE | mean | bwd | 146815 | 21533 | 6.8181396 |
| bfloat16 | [5 10 10] | TRUE | none | fwd | 92959 | 14118 | 6.584431223 |
| bfloat16 | [5 10 10] | TRUE | none | bwd | 96559 | 19950 | 4.840050125 |
| bfloat16 | [5 10 10] | TRUE | sum | fwd | 98318 | 37554 | 2.618043351 |
| bfloat16 | [5 10 10] | TRUE | sum | bwd | 90863 | 22155 | 4.101241255 |
| bfloat16 | [5 10 10] | TRUE | mean | fwd | 97791 | 36487 | 2.680160057 |
| bfloat16 | [5 10 10] | TRUE | mean | bwd | 93711 | 19346 | 4.843947069 |
| bfloat16 | [5 10 10] | FALSE | none | fwd | 97967 | 13514 | 7.249297025 |
| bfloat16 | [5 10 10] | FALSE | none | bwd | 98847 | 17478 | 5.655509784 |
| bfloat16 | [5 10 10] | FALSE | sum | fwd | 103247 | 38940 | 2.65143811 |
| bfloat16 | [5 10 10] | FALSE | sum | bwd | 90495 | 20412 | 4.433421517 |
| bfloat16 | [5 10 10] | FALSE | mean | fwd | 102095 | 38194 | 2.673063832 |
| bfloat16 | [5 10 10] | FALSE | mean | bwd | 150063 | 21764 | 6.895010108 |
| bfloat16 | [2 5 10 10] | TRUE | none | fwd | 107023 | 14829 | 7.217142086 |
| bfloat16 | [2 5 10 10] | TRUE | none | bwd | 108798 | 19754 | 5.507644021 |
| bfloat16 | [2 5 10 10] | TRUE | sum | fwd | 114687 | 37109 | 3.090544073 |
| bfloat16 | [2 5 10 10] | TRUE | sum | bwd | 103999 | 20430 | 5.090504161 |
| bfloat16 | [2 5 10 10] | TRUE | mean | fwd | 115327 | 35491 | 3.249471697 |
| bfloat16 | [2 5 10 10] | TRUE | mean | bwd | 107359 | 20661 | 5.196215091 |
| bfloat16 | [2 5 10 10] | FALSE | none | fwd | 106335 | 14171 | 7.503704749 |
| bfloat16 | [2 5 10 10] | FALSE | none | bwd | 111647 | 20359 | 5.483913748 |
| bfloat16 | [2 5 10 10] | FALSE | sum | fwd | 112079 | 35615 | 3.14696055 |
| bfloat16 | [2 5 10 10] | FALSE | sum | bwd | 103039 | 20306 | 5.074313011 |
| bfloat16 | [2 5 10 10] | FALSE | mean | fwd | 113407 | 35171 | 3.224446277 |
| bfloat16 | [2 5 10 10] | FALSE | mean | bwd | 159327 | 20519 | 7.764852088 |
| bfloat16 | [25 300] | TRUE | none | fwd | 97855 | 9530 | 10.26810073 |
| bfloat16 | [25 300] | TRUE | none | bwd | 105823 | 12197 | 8.676149873 |
| bfloat16 | [25 300] | TRUE | sum | fwd | 106463 | 33019 | 3.224295103 |
| bfloat16 | [25 300] | TRUE | sum | bwd | 101424 | 12819 | 7.912005617 |
| bfloat16 | [25 300] | TRUE | mean | fwd | 105791 | 33979 | 3.113422997 |
| bfloat16 | [25 300] | TRUE | mean | bwd | 107375 | 13229 | 8.11663769 |
| bfloat16 | [25 300] | FALSE | none | fwd | 97375 | 9832 | 9.903885273 |
| bfloat16 | [25 300] | FALSE | none | bwd | 107343 | 12179 | 8.813777814 |
| bfloat16 | [25 300] | FALSE | sum | fwd | 113568 | 31703 | 3.582247737 |
| bfloat16 | [25 300] | FALSE | sum | bwd | 106911 | 13353 | 8.00651539 |
| bfloat16 | [25 300] | FALSE | mean | fwd | 111935 | 44363 | 2.523161193 |
| bfloat16 | [25 300] | FALSE | mean | bwd | 146111 | 13478 | 10.84070337 |
| bfloat16 | [25 100 100] | TRUE | none | fwd | 99871 | 16500 | 6.052787879 |
| bfloat16 | [25 100 100] | TRUE | none | bwd | 114143 | 24324 | 4.692608124 |
| bfloat16 | [25 100 100] | TRUE | sum | fwd | 117183 | 42780 | 2.739200561 |
| bfloat16 | [25 100 100] | TRUE | sum | bwd | 111823 | 24235 | 4.614111822 |
| bfloat16 | [25 100 100] | TRUE | mean | fwd | 117839 | 40647 | 2.899082343 |
| bfloat16 | [25 100 100] | TRUE | mean | bwd | 119551 | 24324 | 4.914939977 |
| bfloat16 | [25 100 100] | FALSE | none | fwd | 100527 | 20323 | 4.946464597 |
| bfloat16 | [25 100 100] | FALSE | none | bwd | 114286 | 28004 | 4.081059849 |
| bfloat16 | [25 100 100] | FALSE | sum | fwd | 116926 | 44896 | 2.604374555 |
| bfloat16 | [25 100 100] | FALSE | sum | bwd | 112799 | 28929 | 3.899166926 |
| bfloat16 | [25 100 100] | FALSE | mean | fwd | 122095 | 45856 | 2.662574145 |
| bfloat16 | [25 100 100] | FALSE | mean | bwd | 218494 | 28769 | 7.594772151 |
| bfloat16 | [100 20 20 20] | TRUE | none | fwd | 158623 | 41731 | 3.801083128 |
| bfloat16 | [100 20 20 20] | TRUE | none | bwd | 174159 | 66179 | 2.631635413 |
| bfloat16 | [100 20 20 20] | TRUE | sum | fwd | 177903 | 70447 | 2.525345295 |
| bfloat16 | [100 20 20 20] | TRUE | sum | bwd | 176190 | 66073 | 2.666596038 |
| bfloat16 | [100 20 20 20] | TRUE | mean | fwd | 173326 | 69309 | 2.500771906 |
| bfloat16 | [100 20 20 20] | TRUE | mean | bwd | 181471 | 66126 | 2.744321447 |
| bfloat16 | [100 20 20 20] | FALSE | none | fwd | 159102 | 50799 | 3.131990787 |
| bfloat16 | [100 20 20 20] | FALSE | none | bwd | 175567 | 69558 | 2.524037494 |
| bfloat16 | [100 20 20 20] | FALSE | sum | fwd | 175678 | 78714 | 2.231852021 |
| bfloat16 | [100 20 20 20] | FALSE | sum | bwd | 177918 | 70269 | 2.53195577 |
| bfloat16 | [100 20 20 20] | FALSE | mean | fwd | 180158 | 79141 | 2.276418039 |
| bfloat16 | [100 20 20 20] | FALSE | mean | bwd | 481308 | 70198 | 6.856434656 |
| bfloat16 | [100 10 10 10 10] | TRUE | none | fwd | 207758 | 51154 | 4.061422372 |
| bfloat16 | [100 10 10 10 10] | TRUE | none | bwd | 226318 | 81079 | 2.791326977 |
| bfloat16 | [100 10 10 10 10] | TRUE | sum | fwd | 222222 | 78536 | 2.829555872 |
| bfloat16 | [100 10 10 10 10] | TRUE | sum | bwd | 225198 | 80421 | 2.800238744 |
| bfloat16 | [100 10 10 10 10] | TRUE | mean | fwd | 223310 | 79212 | 2.819143564 |
| bfloat16 | [100 10 10 10 10] | TRUE | mean | bwd | 235326 | 80474 | 2.924248826 |
| bfloat16 | [100 10 10 10 10] | FALSE | none | fwd | 209134 | 62000 | 3.373129032 |
| bfloat16 | [100 10 10 10 10] | FALSE | none | bwd | 226510 | 84795 | 2.671265994 |
| bfloat16 | [100 10 10 10 10] | FALSE | sum | fwd | 224717 | 91516 | 2.455494121 |
| bfloat16 | [100 10 10 10 10] | FALSE | sum | bwd | 228014 | 85399 | 2.669984426 |
| bfloat16 | [100 10 10 10 10] | FALSE | mean | fwd | 232046 | 91729 | 2.529690719 |
| bfloat16 | [100 10 10 10 10] | FALSE | mean | bwd | 577627 | 85328 | 6.769489499 |
| bfloat16 | [2000 3000] | TRUE | none | fwd | 903896 | 282298 | 3.201921374 |
| bfloat16 | [2000 3000] | TRUE | none | bwd | 973383 | 461968 | 2.107035552 |
| bfloat16 | [2000 3000] | TRUE | sum | fwd | 939544 | 351623 | 2.672020886 |
| bfloat16 | [2000 3000] | TRUE | sum | bwd | 957976 | 456100 | 2.100363955 |
| bfloat16 | [2000 3000] | TRUE | mean | fwd | 940568 | 350253 | 2.685395985 |
| bfloat16 | [2000 3000] | TRUE | mean | bwd | 996151 | 456384 | 2.1827036 |
| bfloat16 | [2000 3000] | FALSE | none | fwd | 920776 | 322158 | 2.858150349 |
| bfloat16 | [2000 3000] | FALSE | none | bwd | 961591 | 472595 | 2.034704134 |
| bfloat16 | [2000 3000] | FALSE | sum | fwd | 940984 | 392744 | 2.395922026 |
| bfloat16 | [2000 3000] | FALSE | sum | bwd | 957048 | 467652 | 2.046496112 |
| bfloat16 | [2000 3000] | FALSE | mean | fwd | 937320 | 393651 | 2.381093913 |
| bfloat16 | [2000 3000] | FALSE | mean | bwd | 2838903 | 468060 | 6.065254455 |
| bfloat16 | [25 1000 1000] | TRUE | none | fwd | 3488673 | 1160440 | 3.00633639 |
| bfloat16 | [25 1000 1000] | TRUE | none | bwd | 3706511 | 1906140 | 1.944511421 |
| bfloat16 | [25 1000 1000] | TRUE | sum | fwd | 3560448 | 1380640 | 2.578838799 |
| bfloat16 | [25 1000 1000] | TRUE | sum | bwd | 3698463 | 1883710 | 1.963392985 |
| bfloat16 | [25 1000 1000] | TRUE | mean | fwd | 3566721 | 1381190 | 2.582353623 |
| bfloat16 | [25 1000 1000] | TRUE | mean | bwd | 3855262 | 1883640 | 2.046708501 |
| bfloat16 | [25 1000 1000] | FALSE | none | fwd | 3495441 | 3502480 | 0.997990281 |
| bfloat16 | [25 1000 1000] | FALSE | none | bwd | 3705151 | 3596210 | 1.030293281 |
| bfloat16 | [25 1000 1000] | FALSE | sum | fwd | 3544400 | 3716430 | 0.953710954 |
| bfloat16 | [25 1000 1000] | FALSE | sum | bwd | 3719167 | 3544240 | 1.049355292 |
| bfloat16 | [25 1000 1000] | FALSE | mean | fwd | 3573984 | 3727210 | 0.958889894 |
| bfloat16 | [25 1000 1000] | FALSE | mean | bwd | 26678548 | 3542150 | 7.531738633 |
| bfloat16 | [10 100 100 100] | TRUE | none | fwd | 1467459 | 467714 | 3.137513523 |
| bfloat16 | [10 100 100 100] | TRUE | none | bwd | 1542386 | 764625 | 2.017179663 |
| bfloat16 | [10 100 100 100] | TRUE | sum | fwd | 1480723 | 567615 | 2.608674894 |
| bfloat16 | [10 100 100 100] | TRUE | sum | bwd | 1528754 | 756446 | 2.020969111 |
| bfloat16 | [10 100 100 100] | TRUE | mean | fwd | 1493827 | 568237 | 2.628880203 |
| bfloat16 | [10 100 100 100] | TRUE | mean | bwd | 1591858 | 756339 | 2.104688506 |
| bfloat16 | [10 100 100 100] | FALSE | none | fwd | 1450499 | 1137090 | 1.275623741 |
| bfloat16 | [10 100 100 100] | FALSE | none | bwd | 1534466 | 1223550 | 1.254109763 |
| bfloat16 | [10 100 100 100] | FALSE | sum | fwd | 1482163 | 1237170 | 1.198026949 |
| bfloat16 | [10 100 100 100] | FALSE | sum | bwd | 1526034 | 1209820 | 1.261372766 |
| bfloat16 | [10 100 100 100] | FALSE | mean | fwd | 1484323 | 1235710 | 1.201190409 |
| bfloat16 | [10 100 100 100] | FALSE | mean | bwd | 9208158 | 1210750 | 7.605333884 |
It seems that part of this PR is also included in https://github.com/ROCm/MIOpen/pull/3146. May I ask you to collaborate with @littlecutebird and fix all the common parts of both PRs?
I have updated my code following the comments in #3146. Please take another look at my PR.
@junliume can you take a look at the Windows build stage, please? I added MIOPEN_INTERNALS_EXPORT but it still fails.
@iq136boy can you help me check the log of the Windows build stage, please?
@iq136boy can you help me check the log of the Windows build stage, please?
[ 95%] Linking CXX executable ..\..\bin\test_sigmoid_focal_loss.exe
lld-link: error: undefined symbol: __declspec(dllimport) enum miopenStatus_t __cdecl miopen::SigmoidFocalLossForward(struct miopen::Handle &, void *, unsigned __int64, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void *, float, float, enum miopenLossReductionMode_t)
lld-link: error: undefined symbol: __declspec(dllimport) enum miopenStatus_t __cdecl miopen::SigmoidFocalLossBackward(struct miopen::Handle &, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void *, struct miopen::TensorDescriptor const &, void *, float, float, enum miopenLossReductionMode_t)
lld-link: error: undefined symbol: __declspec(dllimport) unsigned __int64 __cdecl miopen::GetSigmoidFocalLossForwardWorkspaceSize(struct miopen::Handle &, struct miopen::TensorDescriptor const &, struct miopen::TensorDescriptor const &, struct miopen::TensorDescriptor const &, enum miopenLossReductionMode_t)
>>> referenced by C:\home\jenkins\agent\workspace\UIF2_MIOpen_PR-3143\MIOpen\test\gtest\sigmoid_focal_loss.hpp:299
>>> CMakeFiles/test_sigmoid_focal_loss.dir/sigmoid_focal_loss.cpp.obj:(protected: virtual void __cdecl SigmoidFocalLossFwdTest<float>::SetUp(void))
>>> referenced by C:\home\jenkins\agent\workspace\UIF2_MIOpen_PR-3143\MIOpen\test\gtest\sigmoid_focal_loss.hpp:299
>>> CMakeFiles/test_sigmoid_focal_loss.dir/sigmoid_focal_loss.cpp.obj:(protected: virtual void __cdecl SigmoidFocalLossFwdTest<class half_float::half>::SetUp(void))
>>> referenced by C:\home\jenkins\agent\workspace\UIF2_MIOpen_PR-3143\MIOpen\test\gtest\sigmoid_focal_loss.hpp:299
>>> CMakeFiles/test_sigmoid_focal_loss.dir/sigmoid_focal_loss.cpp.obj:(protected: virtual void __cdecl SigmoidFocalLossFwdTest<class bfloat16>::SetUp(void))
@iq136boy For this PR only, can you and your colleagues comment on the documentation problems? It's great to learn by example. Please tell us which parts you believe are important to have or to mention in the documentation.
MIOpen is moving to the new monorepo setup and all older unmerged PRs are being closed. Please re-open this as part of the new repo if these changes are still needed.