MIOpen
Implement SoftMarginLoss
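- Added the SoftMarginLoss operation, forward and backward. On average it is faster than ROCm for every data type and direction (mean improvement between 2.46x and 3.30x). A few large contiguous reduced forward cases in fp16 and bfp16 are slightly slower (0.88 to 0.97).
- The new API is guarded by the MIOPEN_BETA_API macro. Added two kernels: SoftMarginLossForward5d and SoftMarginLossBackward5d.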
- Added driver test and gtest for SoftMarginLoss.
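- Performance compared to ROCm (Improvement = ROCm time / MIOpen time, so values above 1 mean MIOpen is faster; each summary table lists the mean improvement per data type):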
Unreduced:
| type | Forward | Backward |
|---|---|---|
| float32 | 2.50 | 3.30 |
| float16 | 2.46 | 3.12 |
| bfloat16 | 2.51 | 3.30 |
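
As a reading aid, here is a minimal sketch of how the Improvement column and the summary figures appear to be derived: each improvement is the ROCm time divided by the MIOpen time, and the summary value looks like the arithmetic mean of those ratios (the 30 unreduced fp32 forward ratios average to roughly 2.50). The function names and the three sample rows are chosen for illustration only and are not part of the PR's tooling.

```python
def improvement(rocm_time, miopen_time):
    """Speedup of MIOpen over ROCm: values above 1 mean MIOpen is faster."""
    return rocm_time / miopen_time


def mean_improvement(rows):
    """Arithmetic mean of the per-case improvement ratios."""
    ratios = [improvement(rocm, miopen) for rocm, miopen in rows]
    return sum(ratios) / len(ratios)


# First three rows of the unreduced fp32 forward table: (ROCm, MIOpen).
sample_rows = [
    (312830, 213670),
    (126889, 71484),
    (89579, 57439),
]

for rocm, miopen in sample_rows:
    print(f"{rocm} / {miopen} = {improvement(rocm, miopen):.2f}")

print(f"mean over sample rows: {mean_improvement(sample_rows):.2f}")
```

A mean of ratios weights every test case equally, so the many small tensors (where launch overhead dominates) pull the summary figure up more than the few large ones.
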
fp32 forward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 312830 | 213670 | 1.46 |
| [32 80 870] | [69600 1 80] | FALSE | 126889 | 71484 | 1.78 |
| [32 80 870] | [69600 870 1] | TRUE | 89579 | 57439 | 1.56 |
| [4 182403 91] | [16598673 91 1] | TRUE | 2231121 | 1555260 | 1.43 |
| [1534680] | [1] | TRUE | 64798 | 41155 | 1.57 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 156073 | 103217 | 1.51 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 33934 | 19786 | 1.72 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 25663 | 9564 | 2.68 |
| [32756 80] | [85 1] | FALSE | 109723 | 66630 | 1.65 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 153257 | 119786 | 1.28 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 49854 | 32924 | 1.51 |
| [22311 80] | [85 1] | FALSE | 79085 | 47395 | 1.67 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 28703 | 12035 | 2.38 |
| [8 4] | [4 1] | TRUE | 17823 | 10631 | 1.68 |
| [56 4] | [4 1] | TRUE | 15104 | 11306 | 1.34 |
| [131 4] | [4 1] | TRUE | 20512 | 9511 | 2.16 |
| [10000] | [1] | TRUE | 21167 | 7306 | 2.90 |
| [200 50] | [50 1] | TRUE | 23648 | 7288 | 3.24 |
| [20 50 10] | [500 10 1] | TRUE | 20128 | 7235 | 2.78 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 23567 | 7324 | 3.22 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 21584 | 8142 | 2.65 |
| [10000] | [3] | FALSE | 18880 | 7235 | 2.61 |
| [200 50] | [1 200] | FALSE | 28879 | 7253 | 3.98 |
| [200 50] | [505 1] | FALSE | 29087 | 7324 | 3.97 |
| [20 50 10] | [1 20 1000] | FALSE | 28447 | 7608 | 3.74 |
| [20 50 10] | [7575 15 1] | FALSE | 27359 | 7271 | 3.76 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 27007 | 7324 | 3.69 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 29456 | 7093 | 4.15 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 24496 | 8017 | 3.06 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 31439 | 8391 | 3.75 |
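
The `cont` column seems to mark whether the strides equal the densely packed row-major strides for the given size. The sketch below illustrates that reading. It is an assumption about how the column was filled in, not MIOpen's actual contiguity check, and the helper names are hypothetical.

```python
def packed_strides(shape):
    """Row-major strides (in elements) of a densely packed tensor."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension back to the first.
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def is_contiguous(shape, strides):
    """True when the strides match the packed row-major layout exactly."""
    return list(strides) == packed_strides(shape)


# A few (input_size, stride_size) pairs taken from the table above.
cases = [
    ([256, 4, 8732], [34928, 8732, 1]),
    ([32, 80, 870], [69600, 1, 80]),
    ([16, 1, 512, 512], [262144, 262144, 512, 1]),
    ([10000], [3]),
]

for shape, strides in cases:
    print(shape, strides, is_contiguous(shape, strides))
```
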
fp32 backward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 553089 | 232674 | 2.38 |
| [32 80 870] | [69600 1 80] | FALSE | 299391 | 73208 | 4.09 |
| [32 80 870] | [69600 870 1] | TRUE | 161303 | 61492 | 2.62 |
| [4 182403 91] | [16598673 91 1] | TRUE | 4006075 | 1703260 | 2.35 |
| [1534680] | [1] | TRUE | 116011 | 43999 | 2.64 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 273716 | 111466 | 2.46 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 46222 | 18808 | 2.46 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 31038 | 9493 | 3.27 |
| [32756 80] | [85 1] | FALSE | 190648 | 71661 | 2.66 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 196184 | 123004 | 1.59 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 64541 | 33279 | 1.94 |
| [22311 80] | [85 1] | FALSE | 139178 | 50595 | 2.75 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 36431 | 12337 | 2.95 |
| [8 4] | [4 1] | TRUE | 27519 | 12622 | 2.18 |
| [56 4] | [4 1] | TRUE | 23840 | 11360 | 2.10 |
| [131 4] | [4 1] | TRUE | 27472 | 10435 | 2.63 |
| [10000] | [1] | TRUE | 27423 | 7786 | 3.52 |
| [200 50] | [50 1] | TRUE | 28080 | 8302 | 3.38 |
| [20 50 10] | [500 10 1] | TRUE | 28000 | 7999 | 3.50 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 27343 | 8337 | 3.28 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 27247 | 8888 | 3.07 |
| [10000] | [3] | FALSE | 28352 | 7999 | 3.54 |
| [200 50] | [1 200] | FALSE | 42255 | 8373 | 5.05 |
| [200 50] | [505 1] | FALSE | 35823 | 7751 | 4.62 |
| [20 50 10] | [1 20 1000] | FALSE | 44127 | 8213 | 5.37 |
| [20 50 10] | [7575 15 1] | FALSE | 32879 | 8160 | 4.03 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 42688 | 8284 | 5.15 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 33519 | 8124 | 4.13 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 46415 | 8871 | 5.23 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 35791 | 8640 | 4.14 |
fp16 forward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 245058 | 214701 | 1.14 |
| [32 80 870] | [69600 1 80] | FALSE | 111946 | 67626 | 1.66 |
| [32 80 870] | [69600 870 1] | TRUE | 71116 | 57030 | 1.25 |
| [4 182403 91] | [16598673 91 1] | TRUE | 1704623 | 1565680 | 1.09 |
| [1534680] | [1] | TRUE | 51181 | 40746 | 1.26 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 122427 | 103182 | 1.19 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 29743 | 17724 | 1.68 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 23279 | 9422 | 2.47 |
| [32756 80] | [85 1] | FALSE | 87740 | 66346 | 1.32 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 120027 | 92443 | 1.30 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 41742 | 28586 | 1.46 |
| [22311 80] | [85 1] | FALSE | 63229 | 47021 | 1.34 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 22895 | 11395 | 2.01 |
| [8 4] | [4 1] | TRUE | 18463 | 9226 | 2.00 |
| [56 4] | [4 1] | TRUE | 16768 | 10080 | 1.66 |
| [131 4] | [4 1] | TRUE | 22543 | 9173 | 2.46 |
| [10000] | [1] | TRUE | 23711 | 7182 | 3.30 |
| [200 50] | [50 1] | TRUE | 24336 | 7182 | 3.39 |
| [20 50 10] | [500 10 1] | TRUE | 23520 | 7360 | 3.20 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 25071 | 7199 | 3.48 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 24208 | 8071 | 3.00 |
| [10000] | [3] | FALSE | 20176 | 7235 | 2.79 |
| [200 50] | [1 200] | FALSE | 26015 | 7395 | 3.52 |
| [200 50] | [505 1] | FALSE | 26991 | 7235 | 3.73 |
| [20 50 10] | [1 20 1000] | FALSE | 26127 | 7235 | 3.61 |
| [20 50 10] | [7575 15 1] | FALSE | 26064 | 7057 | 3.69 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 27296 | 7288 | 3.75 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 26624 | 7146 | 3.73 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 26720 | 7840 | 3.41 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 32159 | 8320 | 3.87 |
fp16 backward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 337901 | 233368 | 1.45 |
| [32 80 870] | [69600 1 80] | FALSE | 249874 | 69777 | 3.58 |
| [32 80 870] | [69600 870 1] | TRUE | 103578 | 61635 | 1.68 |
| [4 182403 91] | [16598673 91 1] | TRUE | 2320300 | 1712220 | 1.36 |
| [1534680] | [1] | TRUE | 71181 | 43484 | 1.64 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 175272 | 111946 | 1.57 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 42190 | 16746 | 2.52 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 31822 | 9600 | 3.31 |
| [32756 80] | [85 1] | FALSE | 125770 | 71519 | 1.76 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 139818 | 94755 | 1.48 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 55549 | 28142 | 1.97 |
| [22311 80] | [85 1] | FALSE | 85789 | 50524 | 1.70 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 35278 | 11413 | 3.09 |
| [8 4] | [4 1] | TRUE | 28079 | 12924 | 2.17 |
| [56 4] | [4 1] | TRUE | 23776 | 12195 | 1.95 |
| [131 4] | [4 1] | TRUE | 28975 | 10453 | 2.77 |
| [10000] | [1] | TRUE | 30447 | 8088 | 3.76 |
| [200 50] | [50 1] | TRUE | 29456 | 8231 | 3.58 |
| [20 50 10] | [500 10 1] | TRUE | 29728 | 8177 | 3.64 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 29664 | 8106 | 3.66 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 30063 | 8871 | 3.39 |
| [10000] | [3] | FALSE | 29616 | 8160 | 3.63 |
| [200 50] | [1 200] | FALSE | 42367 | 8284 | 5.11 |
| [200 50] | [505 1] | FALSE | 35807 | 8071 | 4.44 |
| [20 50 10] | [1 20 1000] | FALSE | 43918 | 8213 | 5.35 |
| [20 50 10] | [7575 15 1] | FALSE | 34831 | 7964 | 4.37 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 43135 | 8302 | 5.20 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 33759 | 8017 | 4.21 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 45439 | 8515 | 5.34 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 36959 | 9155 | 4.04 |
bfp16 forward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 270257 | 219466 | 1.23 |
| [32 80 870] | [69600 1 80] | FALSE | 116921 | 67590 | 1.73 |
| [32 80 870] | [69600 870 1] | TRUE | 76108 | 59324 | 1.28 |
| [4 182403 91] | [16598673 91 1] | TRUE | 1896068 | 1595560 | 1.19 |
| [1534680] | [1] | TRUE | 54781 | 41937 | 1.31 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 132810 | 105831 | 1.25 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 30495 | 17475 | 1.75 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 23135 | 9333 | 2.48 |
| [32756 80] | [85 1] | FALSE | 94156 | 68177 | 1.38 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 122075 | 91715 | 1.33 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 41950 | 28320 | 1.48 |
| [22311 80] | [85 1] | FALSE | 67629 | 48124 | 1.41 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 24447 | 11431 | 2.14 |
| [8 4] | [4 1] | TRUE | 20224 | 11200 | 1.81 |
| [56 4] | [4 1] | TRUE | 16368 | 10542 | 1.55 |
| [131 4] | [4 1] | TRUE | 24799 | 10115 | 2.45 |
| [10000] | [1] | TRUE | 25344 | 7555 | 3.35 |
| [200 50] | [50 1] | TRUE | 26704 | 7199 | 3.71 |
| [20 50 10] | [500 10 1] | TRUE | 23472 | 7288 | 3.22 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 24736 | 7768 | 3.18 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 23519 | 8533 | 2.76 |
| [10000] | [3] | FALSE | 21263 | 7235 | 2.94 |
| [200 50] | [1 200] | FALSE | 28783 | 7448 | 3.86 |
| [200 50] | [505 1] | FALSE | 28639 | 7342 | 3.90 |
| [20 50 10] | [1 20 1000] | FALSE | 28143 | 7235 | 3.89 |
| [20 50 10] | [7575 15 1] | FALSE | 27856 | 7342 | 3.79 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 27952 | 7431 | 3.76 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 27439 | 7484 | 3.67 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 26911 | 8213 | 3.28 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 33871 | 8160 | 4.15 |
bfp16 backward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 387226 | 240141 | 1.61 |
| [32 80 870] | [69600 1 80] | FALSE | 256274 | 71395 | 3.59 |
| [32 80 870] | [69600 870 1] | TRUE | 112138 | 64017 | 1.75 |
| [4 182403 91] | [16598673 91 1] | TRUE | 2719733 | 1754320 | 1.55 |
| [1534680] | [1] | TRUE | 79645 | 45191 | 1.76 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 197064 | 115057 | 1.71 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 46255 | 18737 | 2.47 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 35343 | 10453 | 3.38 |
| [32756 80] | [85 1] | FALSE | 139306 | 73848 | 1.89 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 148474 | 96479 | 1.54 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 59357 | 29706 | 2.00 |
| [22311 80] | [85 1] | FALSE | 98076 | 51875 | 1.89 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 39998 | 13386 | 2.99 |
| [8 4] | [4 1] | TRUE | 30128 | 11751 | 2.56 |
| [56 4] | [4 1] | TRUE | 25055 | 12373 | 2.02 |
| [131 4] | [4 1] | TRUE | 32559 | 10506 | 3.10 |
| [10000] | [1] | TRUE | 34687 | 8568 | 4.05 |
| [200 50] | [50 1] | TRUE | 34544 | 8408 | 4.11 |
| [20 50 10] | [500 10 1] | TRUE | 33648 | 8639 | 3.89 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 34335 | 8764 | 3.92 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 33983 | 9173 | 3.70 |
| [10000] | [3] | FALSE | 31600 | 8533 | 3.70 |
| [200 50] | [1 200] | FALSE | 48271 | 8302 | 5.81 |
| [200 50] | [505 1] | FALSE | 40383 | 8533 | 4.73 |
| [20 50 10] | [1 20 1000] | FALSE | 47374 | 8604 | 5.51 |
| [20 50 10] | [7575 15 1] | FALSE | 38431 | 8444 | 4.55 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 46383 | 8515 | 5.45 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 36944 | 8444 | 4.38 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 49263 | 9475 | 5.20 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 40592 | 9475 | 4.28 |
Reduced:
| type | Forward | Backward |
|---|---|---|
| float32 | 3.26 | 2.88 |
| float16 | 3.04 | 2.66 |
| bfloat16 | 3.10 | 2.84 |
fp32 forward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 369111 | 301989 | 1.22 |
| [32 80 870] | [69600 1 80] | FALSE | 161182 | 106168 | 1.52 |
| [32 80 870] | [69600 870 1] | TRUE | 122654 | 91341 | 1.34 |
| [4 182403 91] | [16598673 91 1] | TRUE | 2476240 | 2107970 | 1.17 |
| [1534680] | [1] | TRUE | 170494 | 69972 | 2.44 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 200662 | 152301 | 1.32 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 105519 | 37333 | 2.83 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 103921 | 24408 | 4.26 |
| [32756 80] | [85 1] | FALSE | 148782 | 105724 | 1.41 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 178350 | 145973 | 1.22 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 89535 | 51928 | 1.72 |
| [22311 80] | [85 1] | FALSE | 109711 | 78826 | 1.39 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 71823 | 30346 | 2.37 |
| [8 4] | [4 1] | TRUE | 99951 | 20497 | 4.88 |
| [56 4] | [4 1] | TRUE | 95135 | 20444 | 4.65 |
| [131 4] | [4 1] | TRUE | 101231 | 26737 | 3.79 |
| [10000] | [1] | TRUE | 120670 | 24017 | 5.02 |
| [200 50] | [50 1] | TRUE | 107343 | 23093 | 4.65 |
| [20 50 10] | [500 10 1] | TRUE | 107951 | 23466 | 4.60 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 116206 | 22630 | 5.14 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 99311 | 24639 | 4.03 |
| [10000] | [3] | FALSE | 114110 | 22915 | 4.98 |
| [200 50] | [1 200] | FALSE | 101343 | 22986 | 4.41 |
| [200 50] | [505 1] | FALSE | 107390 | 23128 | 4.64 |
| [20 50 10] | [1 20 1000] | FALSE | 107598 | 26453 | 4.07 |
| [20 50 10] | [7575 15 1] | FALSE | 95119 | 22435 | 4.24 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 110078 | 25937 | 4.24 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 85247 | 24053 | 3.54 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 85183 | 32764 | 2.60 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 99375 | 24888 | 3.99 |
fp32 backward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 528898 | 248549 | 2.13 |
| [32 80 870] | [69600 1 80] | FALSE | 252066 | 75057 | 3.36 |
| [32 80 870] | [69600 870 1] | TRUE | 153719 | 65421 | 2.35 |
| [4 182403 91] | [16598673 91 1] | TRUE | 3838133 | 1819290 | 2.11 |
| [1534680] | [1] | TRUE | 111195 | 46666 | 2.38 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 262661 | 118595 | 2.21 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 45726 | 18862 | 2.42 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 30014 | 9599 | 3.13 |
| [32756 80] | [85 1] | FALSE | 184584 | 76497 | 2.41 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 192808 | 123839 | 1.56 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 63757 | 33439 | 1.91 |
| [22311 80] | [85 1] | FALSE | 132763 | 53919 | 2.46 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 34639 | 12373 | 2.80 |
| [8 4] | [4 1] | TRUE | 27759 | 12533 | 2.21 |
| [56 4] | [4 1] | TRUE | 24015 | 12871 | 1.87 |
| [131 4] | [4 1] | TRUE | 24752 | 10168 | 2.43 |
| [10000] | [1] | TRUE | 24687 | 8515 | 2.90 |
| [200 50] | [50 1] | TRUE | 26800 | 8586 | 3.12 |
| [20 50 10] | [500 10 1] | TRUE | 24624 | 8586 | 2.87 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 24831 | 8231 | 3.02 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 26736 | 9315 | 2.87 |
| [10000] | [3] | FALSE | 28815 | 8622 | 3.34 |
| [200 50] | [1 200] | FALSE | 36079 | 8782 | 4.11 |
| [200 50] | [505 1] | FALSE | 32527 | 8924 | 3.64 |
| [20 50 10] | [1 20 1000] | FALSE | 37264 | 8320 | 4.48 |
| [20 50 10] | [7575 15 1] | FALSE | 30239 | 8284 | 3.65 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 36687 | 8675 | 4.23 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 28655 | 8391 | 3.41 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 37087 | 9653 | 3.84 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 32399 | 10079 | 3.21 |
fp16 forward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 291221 | 302291 | 0.96 |
| [32 80 870] | [69600 1 80] | FALSE | 141342 | 103821 | 1.36 |
| [32 80 870] | [69600 870 1] | TRUE | 133998 | 91519 | 1.46 |
| [4 182403 91] | [16598673 91 1] | TRUE | 1853128 | 2113120 | 0.88 |
| [1534680] | [1] | TRUE | 119247 | 69279 | 1.72 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 158849 | 152177 | 1.04 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 79967 | 35786 | 2.23 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 93090 | 26382 | 3.53 |
| [32756 80] | [85 1] | FALSE | 119166 | 103875 | 1.15 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 146478 | 126950 | 1.15 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 105903 | 47875 | 2.21 |
| [22311 80] | [85 1] | FALSE | 90287 | 77706 | 1.16 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 64847 | 29884 | 2.17 |
| [8 4] | [4 1] | TRUE | 108239 | 20462 | 5.29 |
| [56 4] | [4 1] | TRUE | 96959 | 20888 | 4.64 |
| [131 4] | [4 1] | TRUE | 101678 | 26186 | 3.88 |
| [10000] | [1] | TRUE | 108911 | 25920 | 4.20 |
| [200 50] | [50 1] | TRUE | 122207 | 26417 | 4.63 |
| [20 50 10] | [500 10 1] | TRUE | 98911 | 23004 | 4.30 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 111774 | 30986 | 3.61 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 95423 | 25564 | 3.73 |
| [10000] | [3] | FALSE | 111503 | 23128 | 4.82 |
| [200 50] | [1 200] | FALSE | 110559 | 23342 | 4.74 |
| [200 50] | [505 1] | FALSE | 98271 | 28551 | 3.44 |
| [20 50 10] | [1 20 1000] | FALSE | 104975 | 22933 | 4.58 |
| [20 50 10] | [7575 15 1] | FALSE | 100159 | 23893 | 4.19 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 106479 | 26168 | 4.07 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 84959 | 28106 | 3.02 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 99134 | 24568 | 4.04 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 92783 | 32141 | 2.89 |
fp16 backward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 341916 | 250932 | 1.36 |
| [32 80 870] | [69600 1 80] | FALSE | 201044 | 72781 | 2.76 |
| [32 80 870] | [69600 870 1] | TRUE | 102250 | 65830 | 1.55 |
| [4 182403 91] | [16598673 91 1] | TRUE | 2344090 | 1837820 | 1.28 |
| [1534680] | [1] | TRUE | 72333 | 46844 | 1.54 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 174856 | 120053 | 1.46 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 41646 | 18595 | 2.24 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 30831 | 10666 | 2.89 |
| [32756 80] | [85 1] | FALSE | 125979 | 76852 | 1.64 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 140906 | 96675 | 1.46 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 54558 | 29848 | 1.83 |
| [22311 80] | [85 1] | FALSE | 87149 | 53990 | 1.61 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 34863 | 13546 | 2.57 |
| [8 4] | [4 1] | TRUE | 25327 | 12391 | 2.04 |
| [56 4] | [4 1] | TRUE | 24303 | 13208 | 1.84 |
| [131 4] | [4 1] | TRUE | 27407 | 10560 | 2.60 |
| [10000] | [1] | TRUE | 28512 | 8444 | 3.38 |
| [200 50] | [50 1] | TRUE | 27120 | 8675 | 3.13 |
| [20 50 10] | [500 10 1] | TRUE | 27184 | 8533 | 3.19 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 27823 | 9511 | 2.93 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 28655 | 9173 | 3.12 |
| [10000] | [3] | FALSE | 28463 | 8640 | 3.29 |
| [200 50] | [1 200] | FALSE | 36799 | 9262 | 3.97 |
| [200 50] | [505 1] | FALSE | 31359 | 8408 | 3.73 |
| [20 50 10] | [1 20 1000] | FALSE | 36351 | 8817 | 4.12 |
| [20 50 10] | [7575 15 1] | FALSE | 31007 | 8657 | 3.58 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 36143 | 9102 | 3.97 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 29152 | 8942 | 3.26 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 37519 | 9404 | 3.99 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 33167 | 9724 | 3.41 |
bfp16 forward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 315541 | 302344 | 1.04 |
| [32 80 870] | [69600 1 80] | FALSE | 147134 | 104071 | 1.41 |
| [32 80 870] | [69600 870 1] | TRUE | 111966 | 93119 | 1.20 |
| [4 182403 91] | [16598673 91 1] | TRUE | 2060933 | 2116840 | 0.97 |
| [1534680] | [1] | TRUE | 104399 | 69652 | 1.50 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 170487 | 152390 | 1.12 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 76767 | 35786 | 2.15 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 97528 | 24871 | 3.92 |
| [32756 80] | [85 1] | FALSE | 128750 | 103768 | 1.24 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 148734 | 126470 | 1.18 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 84079 | 52924 | 1.59 |
| [22311 80] | [85 1] | FALSE | 97359 | 77706 | 1.25 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 70831 | 34382 | 2.06 |
| [8 4] | [4 1] | TRUE | 102815 | 19822 | 5.19 |
| [56 4] | [4 1] | TRUE | 104142 | 19786 | 5.26 |
| [131 4] | [4 1] | TRUE | 121294 | 26364 | 4.60 |
| [10000] | [1] | TRUE | 101311 | 23679 | 4.28 |
| [200 50] | [50 1] | TRUE | 123118 | 33208 | 3.71 |
| [20 50 10] | [500 10 1] | TRUE | 99215 | 24159 | 4.11 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 110846 | 25173 | 4.40 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 111215 | 35839 | 3.10 |
| [10000] | [3] | FALSE | 106463 | 25937 | 4.10 |
| [200 50] | [1 200] | FALSE | 111246 | 24017 | 4.63 |
| [200 50] | [505 1] | FALSE | 101679 | 23342 | 4.36 |
| [20 50 10] | [1 20 1000] | FALSE | 98543 | 23182 | 4.25 |
| [20 50 10] | [7575 15 1] | FALSE | 112414 | 23200 | 4.85 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 81775 | 25920 | 3.15 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 83263 | 23857 | 3.49 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 120350 | 24213 | 4.97 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 97231 | 25528 | 3.81 |
bfp16 backward
| input_size | stride_size | cont | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|
| [256 4 8732] | [34928 8732 1] | TRUE | 386314 | 255394 | 1.51 |
| [32 80 870] | [69600 1 80] | FALSE | 209828 | 73315 | 2.86 |
| [32 80 870] | [69600 870 1] | TRUE | 113049 | 67448 | 1.68 |
| [4 182403 91] | [16598673 91 1] | TRUE | 2712646 | 1870170 | 1.45 |
| [1534680] | [1] | TRUE | 80653 | 47750 | 1.69 |
| [16 1 512 512] | [262144 262144 512 1] | TRUE | 196600 | 122062 | 1.61 |
| [2 3 160 160] | [6528000 2176000 13600 85] | FALSE | 44174 | 18808 | 2.35 |
| [2 3 80 80] | [1632000 544000 6800 85] | FALSE | 31758 | 10791 | 2.94 |
| [32756 80] | [85 1] | FALSE | 137530 | 78310 | 1.76 |
| [64 3 80 80] | [1632000 544000 6800 85] | FALSE | 150010 | 96212 | 1.56 |
| [64 3 40 40] | [408000 136000 3400 85] | FALSE | 58142 | 29777 | 1.95 |
| [22311 80] | [85 1] | FALSE | 99308 | 54790 | 1.81 |
| [64 3 20 20] | [102000 34000 1700 85] | FALSE | 37006 | 13671 | 2.71 |
| [8 4] | [4 1] | TRUE | 28847 | 12355 | 2.33 |
| [56 4] | [4 1] | TRUE | 25599 | 12853 | 1.99 |
| [131 4] | [4 1] | TRUE | 31311 | 11715 | 2.67 |
| [10000] | [1] | TRUE | 29471 | 8604 | 3.43 |
| [200 50] | [50 1] | TRUE | 29776 | 8408 | 3.54 |
| [20 50 10] | [500 10 1] | TRUE | 29488 | 8586 | 3.43 |
| [4 25 4 25] | [2500 100 25 1] | TRUE | 29567 | 8711 | 3.39 |
| [12 3 4 5 6] | [360 120 30 6 1] | TRUE | 31199 | 9706 | 3.21 |
| [10000] | [3] | FALSE | 29391 | 8640 | 3.40 |
| [200 50] | [1 200] | FALSE | 39599 | 9226 | 4.29 |
| [200 50] | [505 1] | FALSE | 33711 | 8640 | 3.90 |
| [20 50 10] | [1 20 1000] | FALSE | 38655 | 9013 | 4.29 |
| [20 50 10] | [7575 15 1] | FALSE | 33711 | 8515 | 3.96 |
| [4 25 4 25] | [1 16 4 400] | FALSE | 38272 | 9208 | 4.16 |
| [4 25 4 25] | [5859 217 31 1] | FALSE | 30335 | 8888 | 3.41 |
| [12 3 4 5 6] | [360 120 6 24 1] | FALSE | 41151 | 9777 | 4.21 |
| [12 3 4 5 6] | [5760 960 120 12 1] | FALSE | 36864 | 9671 | 3.81 |
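
For readers unfamiliar with the operation being benchmarked, the following is a plain-Python reference sketch of soft margin loss, assuming the usual definition `log(1 + exp(-t * x))` with targets `t` of +1 or -1 and the `none` / `sum` / `mean` reduction modes found in PyTorch's `SoftMarginLoss`. It is not the code of the `SoftMarginLossForward5d` or `SoftMarginLossBackward5d` kernels, the function names are hypothetical, and which reduction the "Reduced" tables used is not stated in this PR.

```python
import math


def soft_margin_loss_forward(x, t, reduction="none"):
    """Per-element loss log(1 + exp(-t * x)), optionally reduced."""
    losses = [math.log1p(math.exp(-ti * xi)) for xi, ti in zip(x, t)]
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction}")


def soft_margin_loss_backward(x, t, grad_out, reduction="none"):
    """Gradient with respect to x.

    d/dx log(1 + exp(-t * x)) = -t / (1 + exp(t * x)).
    grad_out is a list for reduction="none" and a scalar otherwise.
    """
    n = len(x)
    if reduction == "none":
        upstream = list(grad_out)
    elif reduction == "sum":
        upstream = [grad_out] * n
    elif reduction == "mean":
        upstream = [grad_out / n] * n
    else:
        raise ValueError(f"unknown reduction: {reduction}")
    return [
        -ti / (1.0 + math.exp(ti * xi)) * g
        for xi, ti, g in zip(x, t, upstream)
    ]


x = [0.0, 2.0, -1.5]
t = [1.0, -1.0, 1.0]
print(soft_margin_loss_forward(x, t))
print(soft_margin_loss_forward(x, t, reduction="mean"))
print(soft_margin_loss_backward(x, t, 1.0, reduction="mean"))
```

This sketch is not numerically hardened: `math.exp` overflows for very large arguments, which a production kernel would need to guard against.
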
@CAHEK7 this PR has a very similar code structure to the MultiMarginLoss PR you reviewed before.
@littlecutebird the PR has some build issues, could you follow up with the fix? Thanks!
@junliume please take a look at https://github.com/ROCm/MIOpen/pull/3166#issuecomment-2382642934
@junliume CI/CD passed