Implement MarginRankingLoss

Open hieule88 opened this issue 10 months ago • 0 comments

  • Added MarginRankingLoss forward and backward.

  • Added driver test and gtest for MarginRankingLoss.

  • New API is guarded by MIOPEN_BETA_API macro.

  • Average improvement over ROCm across all cases:

| Type | Forward | Backward |
|---|---|---|
| float16 | 6.20 | 2.61 |
| float32 | 5.69 | 2.29 |
| bfloat16 | 5.22 | 2.64 |

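
For readers unfamiliar with the operator, the forward and backward passes added here can be sketched in plain Python. This is a minimal reference sketch, assuming the commonly used definition `loss = max(0, -y * (x1 - x2) + margin)` with `none`, `sum`, and `mean` reductions. The function names and argument order are hypothetical and are not the MIOpen API, and the choice of a zero gradient exactly at the hinge point is an assumption of this sketch.

```python
def margin_ranking_loss_forward(x1, x2, y, margin=0.0, reduction="none"):
    """Reference forward pass over flat Python lists (hypothetical helper)."""
    # Per-element hinge: max(0, -y * (x1 - x2) + margin)
    losses = [max(0.0, -t * (a - b) + margin) for a, b, t in zip(x1, x2, y)]
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction}")


def margin_ranking_loss_backward(x1, x2, y, grad_output, margin=0.0, reduction="none"):
    """Reference backward pass; returns (grad_x1, grad_x2) as lists.

    grad_output is a list for reduction="none" and a scalar otherwise.
    """
    n = len(x1)
    if reduction == "none":
        upstream = list(grad_output)
    elif reduction == "sum":
        upstream = [grad_output] * n
    elif reduction == "mean":
        upstream = [grad_output / n] * n
    else:
        raise ValueError(f"unknown reduction: {reduction}")

    grad_x1, grad_x2 = [], []
    for a, b, t, g in zip(x1, x2, y, upstream):
        # The hinge is active only where the unclamped value is positive.
        # Assumption of this sketch: the gradient is zero exactly at the hinge point.
        active = (-t * (a - b) + margin) > 0.0
        grad_x1.append(-t * g if active else 0.0)
        grad_x2.append(t * g if active else 0.0)
    return grad_x1, grad_x2


if __name__ == "__main__":
    x1, x2, y = [1.0, 2.0, 3.0], [2.0, 1.0, 3.0], [1.0, 1.0, -1.0]
    print(margin_ranking_loss_forward(x1, x2, y, margin=0.5, reduction="sum"))
    print(margin_ranking_loss_backward(x1, x2, y, 1.0, margin=0.5, reduction="sum"))
```

A list-based reference like this is only useful for checking small cases by hand; it says nothing about how the kernels in this change handle noncontiguous tensors or reduced-precision types.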

FWD-FP16

| op_name | dtype | size | contiguous | reduction | direction | rocm_op_avg | miopen_kernel_time | improvement over rocm |
|---|---|---|---|---|---|---|---|---|
| marginrankingloss | float16 | [3 28 28] | contiguous | sum | fwd | 204094 | 31892 | 6.399535934 |
| marginrankingloss | float16 | [3 28 28] | noncontiguous | sum | fwd | 200862 | 29120 | 6.897733516 |
| marginrankingloss | float16 | [2 28 28] | contiguous | sum | fwd | 100847 | 27769 | 3.631639598 |
| marginrankingloss | float16 | [2 28 28] | noncontiguous | sum | fwd | 182878 | 29724 | 6.152536671 |
| marginrankingloss | float16 | [28 28 28] | contiguous | sum | fwd | 137231 | 26987 | 5.085078001 |
| marginrankingloss | float16 | [28 28 28] | noncontiguous | sum | fwd | 125135 | 25156 | 4.974359994 |
| marginrankingloss | float16 | [12 28 28] | contiguous | sum | fwd | 190126 | 25173 | 7.552774798 |
| marginrankingloss | float16 | [12 28 28] | noncontiguous | sum | fwd | 134703 | 24604 | 5.474841489 |
| marginrankingloss | float16 | [1 28 28] | contiguous | sum | fwd | 142303 | 26507 | 5.368506432 |
| marginrankingloss | float16 | [1 28 28] | noncontiguous | sum | fwd | 145871 | 26916 | 5.419490266 |
| marginrankingloss | float16 | [6 28 28] | contiguous | sum | fwd | 102863 | 26027 | 3.952165059 |
| marginrankingloss | float16 | [6 28 28] | noncontiguous | sum | fwd | 158063 | 26344 | 5.999962041 |
| marginrankingloss | float16 | [34 28 28] | contiguous | sum | fwd | 145663 | 24122 | 6.038595473 |
| marginrankingloss | float16 | [34 28 28] | noncontiguous | sum | fwd | 95967 | 26593 | 3.608731621 |
| marginrankingloss | float16 | [16 1 512 512] | contiguous | sum | fwd | 246494 | 160926 | 1.531722655 |
| marginrankingloss | float16 | [16 1 512 512] | noncontiguous | sum | fwd | 329165 | 853499 | 0.385665361 |
| marginrankingloss | float16 | [273 80] | contiguous | sum | fwd | 101471 | 30112 | 3.369786132 |
| marginrankingloss | float16 | [273 80] | noncontiguous | sum | fwd | 107183 | 24815 | 4.319282692 |
| marginrankingloss | float16 | [16 3 160 160] | contiguous | sum | fwd | 195518 | 65611 | 2.979957629 |
| marginrankingloss | float16 | [16 3 160 160] | noncontiguous | sum | fwd | 184574 | 160588 | 1.149363589 |


FWD-FP32

| op_name | dtype | size | contiguous | reduction | direction | rocm_op_avg | miopen_kernel_time | improvement over rocm |
|---|---|---|---|---|---|---|---|---|
| marginrankingloss | float32 | [3 28 28] | contiguous | sum | fwd | 102399 | 27270 | 3.755005501 |
| marginrankingloss | float32 | [3 28 28] | noncontiguous | sum | fwd | 102079 | 27999 | 3.645808779 |
| marginrankingloss | float32 | [2 28 28] | contiguous | sum | fwd | 107775 | 29688 | 3.630254648 |
| marginrankingloss | float32 | [2 28 28] | noncontiguous | sum | fwd | 105871 | 28426 | 3.724442412 |
| marginrankingloss | float32 | [28 28 28] | contiguous | sum | fwd | 107759 | 28817 | 3.739424645 |
| marginrankingloss | float32 | [28 28 28] | noncontiguous | sum | fwd | 102047 | 26186 | 3.897006034 |
| marginrankingloss | float32 | [12 28 28] | contiguous | sum | fwd | 99919 | 36852 | 2.711358949 |
| marginrankingloss | float32 | [12 28 28] | noncontiguous | sum | fwd | 101839 | 25670 | 3.967238021 |
| marginrankingloss | float32 | [1 28 28] | contiguous | sum | fwd | 99119 | 29439 | 3.366928224 |
| marginrankingloss | float32 | [1 28 28] | noncontiguous | sum | fwd | 102495 | 27591 | 3.714798304 |
| marginrankingloss | float32 | [6 28 28] | contiguous | sum | fwd | 103039 | 25830 | 3.989121177 |
| marginrankingloss | float32 | [6 28 28] | noncontiguous | sum | fwd | 100335 | 26631 | 3.767601667 |
| marginrankingloss | float32 | [34 28 28] | contiguous | sum | fwd | 101583 | 24959 | 4.069994791 |
| marginrankingloss | float32 | [34 28 28] | noncontiguous | sum | fwd | 104399 | 30968 | 3.371189615 |
| marginrankingloss | float32 | [16 1 512 512] | contiguous | sum | fwd | 277933 | 161009 | 1.72619543 |
| marginrankingloss | float32 | [16 1 512 512] | noncontiguous | sum | fwd | 675930 | 1664723 | 0.406031514 |
| marginrankingloss | float32 | [273 80] | contiguous | sum | fwd | 109071 | 28639 | 3.80847795 |
| marginrankingloss | float32 | [273 80] | noncontiguous | sum | fwd | 101871 | 25386 | 4.012881116 |
| marginrankingloss | float32 | [16 3 160 160] | contiguous | sum | fwd | 193854 | 65296 | 2.968849547 |
| marginrankingloss | float32 | [16 3 160 160] | noncontiguous | sum | fwd | 182814 | 319331 | 0.572490613 |


FWD-BFP16

| op_name | dtype | size | contiguous | reduction | direction | rocm_op_avg | miopen_kernel_time | improvement over rocm |
|---|---|---|---|---|---|---|---|---|
| marginrankingloss | bfloat16 | [3 28 28] | contiguous | sum | fwd | 161647 | 27999 | 5.773313333 |
| marginrankingloss | bfloat16 | [3 28 28] | noncontiguous | sum | fwd | 129103 | 27786 | 4.646332686 |
| marginrankingloss | bfloat16 | [2 28 28] | contiguous | sum | fwd | 108015 | 28977 | 3.727611554 |
| marginrankingloss | bfloat16 | [2 28 28] | noncontiguous | sum | fwd | 120863 | 28373 | 4.259789236 |
| marginrankingloss | bfloat16 | [28 28 28] | contiguous | sum | fwd | 109615 | 26044 | 4.208838888 |
| marginrankingloss | bfloat16 | [28 28 28] | noncontiguous | sum | fwd | 114063 | 25902 | 4.403636785 |
| marginrankingloss | bfloat16 | [12 28 28] | contiguous | sum | fwd | 116623 | 28657 | 4.069616499 |
| marginrankingloss | bfloat16 | [12 28 28] | noncontiguous | sum | fwd | 108079 | 25955 | 4.164091697 |
| marginrankingloss | bfloat16 | [1 28 28] | contiguous | sum | fwd | 107423 | 28195 | 3.810001773 |
| marginrankingloss | bfloat16 | [1 28 28] | noncontiguous | sum | fwd | 111679 | 26879 | 4.154879274 |
| marginrankingloss | bfloat16 | [6 28 28] | contiguous | sum | fwd | 104831 | 26471 | 3.960220619 |
| marginrankingloss | bfloat16 | [6 28 28] | noncontiguous | sum | fwd | 114511 | 26702 | 4.288480264 |
| marginrankingloss | bfloat16 | [34 28 28] | contiguous | sum | fwd | 113983 | 24640 | 4.625933442 |
| marginrankingloss | bfloat16 | [34 28 28] | noncontiguous | sum | fwd | 107503 | 26613 | 4.039491978 |
| marginrankingloss | bfloat16 | [16 1 512 512] | contiguous | sum | fwd | 245870 | 160566 | 1.53127063 |
| marginrankingloss | bfloat16 | [16 1 512 512] | noncontiguous | sum | fwd | 323149 | 838934 | 0.385190015 |
| marginrankingloss | bfloat16 | [273 80] | contiguous | sum | fwd | 89775 | 25582 | 3.509303416 |
| marginrankingloss | bfloat16 | [273 80] | noncontiguous | sum | fwd | 113759 | 25528 | 4.456244124 |
| marginrankingloss | bfloat16 | [16 3 160 160] | contiguous | sum | fwd | 116319 | 64407 | 1.805999348 |
| marginrankingloss | bfloat16 | [16 3 160 160] | noncontiguous | sum | fwd | 111167 | 158575 | 0.701037364 |


BWD-FP16

| op_name | dtype | size | contiguous | reduction | direction | rocm_kernel_avg | miopen_kernel_time | improvement over rocm |
|---|---|---|---|---|---|---|---|---|
| marginrankingloss | float16 | [16 3 80 80] | contiguous | none | bwd | 24416 | 17831 | 1.369300656 |
| marginrankingloss | float16 | [16 3 80 80] | noncontiguous | none | bwd | 48511 | 33902 | 1.43091853 |
| marginrankingloss | float16 | [294 80] | contiguous | none | bwd | 33040 | 11004 | 3.002544529 |
| marginrankingloss | float16 | [294 80] | noncontiguous | none | bwd | 43808 | 11502 | 3.808728917 |
| marginrankingloss | float16 | [16 3 40 40] | contiguous | none | bwd | 22016 | 11271 | 1.953331559 |
| marginrankingloss | float16 | [16 3 40 40] | noncontiguous | none | bwd | 36959 | 13618 | 2.713981495 |
| marginrankingloss | float16 | [261 80] | contiguous | none | bwd | 27408 | 11164 | 2.455034038 |
| marginrankingloss | float16 | [261 80] | noncontiguous | none | bwd | 40832 | 11164 | 3.657470441 |
| marginrankingloss | float16 | [16 3 20 20] | contiguous | none | bwd | 29920 | 11147 | 2.684130259 |
| marginrankingloss | float16 | [16 3 20 20] | noncontiguous | none | bwd | 43632 | 11324 | 3.853055457 |


BWD-FP32

| op_name | dtype | size | contiguous | reduction | direction | rocm_kernel_avg | miopen_kernel_time | improvement over rocm |
|---|---|---|---|---|---|---|---|---|
| marginrankingloss | float32 | [16 3 80 80] | contiguous | none | bwd | 23488 | 18453 | 1.272855362 |
| marginrankingloss | float32 | [16 3 80 80] | noncontiguous | none | bwd | 48607 | 36746 | 1.322783432 |
| marginrankingloss | float32 | [294 80] | contiguous | none | bwd | 28144 | 11075 | 2.541218962 |
| marginrankingloss | float32 | [294 80] | noncontiguous | none | bwd | 36848 | 11200 | 3.29 |
| marginrankingloss | float32 | [16 3 40 40] | contiguous | none | bwd | 19840 | 11111 | 1.785617856 |
| marginrankingloss | float32 | [16 3 40 40] | noncontiguous | none | bwd | 31904 | 15626 | 2.04172533 |
| marginrankingloss | float32 | [261 80] | contiguous | none | bwd | 23360 | 10560 | 2.212121212 |
| marginrankingloss | float32 | [261 80] | noncontiguous | none | bwd | 33152 | 11004 | 3.012722646 |
| marginrankingloss | float32 | [16 3 20 20] | contiguous | none | bwd | 24528 | 10133 | 2.420605941 |
| marginrankingloss | float32 | [16 3 20 20] | noncontiguous | none | bwd | 34656 | 11413 | 3.036537282 |


BWD-BFP16

| op_name | dtype | size | contiguous | reduction | direction | rocm_kernel_avg | miopen_kernel_time | improvement over rocm |
|---|---|---|---|---|---|---|---|---|
| marginrankingloss | bfloat16 | [16 3 80 80] | contiguous | sum | bwd | 24128 | 16622 | 1.451570208 |
| marginrankingloss | bfloat16 | [16 3 80 80] | noncontiguous | sum | bwd | 48144 | 33688 | 1.429114225 |
| marginrankingloss | bfloat16 | [294 80] | contiguous | sum | bwd | 31088 | 9546 | 3.256652001 |
| marginrankingloss | bfloat16 | [294 80] | noncontiguous | sum | bwd | 42928 | 9795 | 4.382644206 |
| marginrankingloss | bfloat16 | [16 3 40 40] | contiguous | sum | bwd | 21328 | 9813 | 2.173443391 |
| marginrankingloss | bfloat16 | [16 3 40 40] | noncontiguous | sum | bwd | 35296 | 13493 | 2.615874898 |
| marginrankingloss | bfloat16 | [261 80] | contiguous | sum | bwd | 27232 | 9866 | 2.760186499 |
| marginrankingloss | bfloat16 | [261 80] | noncontiguous | sum | bwd | 39024 | 9849 | 3.962229668 |
| marginrankingloss | bfloat16 | [16 3 20 20] | contiguous | sum | bwd | 28320 | 9422 | 3.005731267 |
| marginrankingloss | bfloat16 | [16 3 20 20] | noncontiguous | sum | bwd | 40288 | 10186 | 3.955232672 |
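
The "improvement over rocm" column in the tables above appears to be the ROCm time divided by the MIOpen kernel time, which can be checked by hand against any row. The sketch below reproduces two FWD-FP16 rows under that assumption. The helper names are hypothetical, and the time unit is not stated in the tables, so the values are treated as unitless.

```python
def improvement_over_rocm(rocm_time, miopen_time):
    """Ratio of ROCm time to MIOpen time; values above 1 mean MIOpen took less time."""
    return rocm_time / miopen_time


def mean_improvement(rows):
    """Arithmetic mean of the per-row ratios for (rocm_time, miopen_time) pairs."""
    ratios = [improvement_over_rocm(r, m) for r, m in rows]
    return sum(ratios) / len(ratios)


if __name__ == "__main__":
    # Two FWD-FP16 rows copied from the tables above: (rocm_op_avg, miopen_kernel_time)
    small_contiguous = (204094, 31892)      # size [3 28 28], contiguous
    large_noncontiguous = (329165, 853499)  # size [16 1 512 512], noncontiguous

    print(round(improvement_over_rocm(*small_contiguous), 6))
    print(round(improvement_over_rocm(*large_noncontiguous), 6))
    print(round(mean_improvement([small_contiguous, large_noncontiguous]), 6))
```

Under this reading, ratios below 1 (the noncontiguous `[16 1 512 512]` and `[16 3 160 160]` forward cases) mark shapes where the MIOpen kernel took longer than the ROCm baseline.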

hieule88, Mar 10 '25 06:03