
Implement SGD

Open hieule88 opened this issue 10 months ago • 0 comments

  • Added SGD forward and backward.

  • Added driver test and gtest for SGD.

  • The new API is guarded by the MIOPEN_BETA_API macro.

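  • Average MIOPEN_over_Rocm ratio over all cases (rocm_kernel_avg divided by the MIOPEN time, so values above 1 mean MIOpen is faster):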

  • SGD

Type Forward
float16 5.97
float32 7.67
bfloat16 5.99
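
For readers unfamiliar with the operation being benchmarked, here is a minimal reference sketch of one SGD step with weight decay. It assumes PyTorch-style semantics (the decay term is added to the gradient before scaling by the learning rate) and leaves out momentum, dampening, and Nesterov. The issue does not spell out the exact update rule or signature of the MIOpen kernel, so the function name `sgd_step` and its parameters are hypothetical and are not the MIOpen API.

```python
def sgd_step(params, grads, lr, weight_decay=0.0):
    """Return updated parameters after one plain SGD step (no momentum).

    Assumption: PyTorch-style weight decay, i.e. the effective gradient is
    grad + weight_decay * param. This is an illustrative reference only,
    not the MIOpen implementation.
    """
    if len(params) != len(grads):
        raise ValueError("params and grads must have the same length")
    updated = []
    for p, g in zip(params, grads):
        g_eff = g + weight_decay * p      # fold weight decay into the gradient
        updated.append(p - lr * g_eff)    # gradient descent update
    return updated


if __name__ == "__main__":
    # weight_decay values 0, 0.0005 and 0.01 appear in the benchmark tables
    print(sgd_step([1.0, -2.0], [0.5, 0.5], lr=0.1, weight_decay=0.01))
```

A CPU reference like this is the kind of thing a unit test could compare kernel output against, within a tolerance that depends on the dtype.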

FP16
op_name dtype input_size weight_decay contiguous rocm_kernel_avg MIOPEN MIOPEN_over_Rocm
SGDForward float16 [640 320 3 3] 0.0005 contiguous 51455 4747 10.83947756
SGDForward float16 [640 320 3 3] 0.0005 noncontiguous 52047 9423 5.523400191
SGDForward float16 [640 512 3 3] 0.0005 contiguous 81039 4711 17.20208024
SGDForward float16 [640 512 3 3] 0.0005 noncontiguous 80431 9192 8.75010879
SGDForward float16 [640 640 1 1] 0.0005 contiguous 43888 4924 8.913078798
SGDForward float16 [640 640 1 1] 0.0005 noncontiguous 44400 9032 4.915854739
SGDForward float16 [768 3072] 0.01 contiguous 65984 4587 14.38500109
SGDForward float16 [768 3072] 0.01 noncontiguous 65295 9032 7.229295837
SGDForward float16 [768 768] 0.01 contiguous 44496 4907 9.067862238
SGDForward float16 [768 768] 0.01 noncontiguous 45168 9458 4.77563967
SGDForward float16 [8 64] 0 contiguous 6064 4675 1.297112299
SGDForward float16 [8 64] 0 noncontiguous 5936 8836 0.671797193
SGDForward float16 [80 1536] 0 contiguous 22368 5049 4.430184195
SGDForward float16 [80 1536] 0 noncontiguous 22544 9440 2.388135593
SGDForward float16 [80 512 5] 0 contiguous 23072 4729 4.878832734
SGDForward float16 [80 512 5] 0 noncontiguous 22928 8907 2.574155159
SGDForward float16 [9521 512] 0 contiguous 61200 4711 12.99087243
SGDForward float16 [9521 512] 0 noncontiguous 60240 9529 6.321754644

FP32
op_name dtype input_size weight_decay contiguous rocm_kernel_avg MIOPEN MIOPEN_over_Rocm
SGDForward float32 [640 320 3 3] 0.0005 contiguous 74768 4729 15.81053077
SGDForward float32 [640 320 3 3] 0.0005 noncontiguous 75199 9227 8.149886204
SGDForward float32 [640 512 3 3] 0.0005 contiguous 121951 4675 26.0857754
SGDForward float32 [640 512 3 3] 0.0005 noncontiguous 122479 9316 13.14716617
SGDForward float32 [640 640 1 1] 0.0005 contiguous 48464 4835 10.02357808
SGDForward float32 [640 640 1 1] 0.0005 noncontiguous 48544 9725 4.991670951
SGDForward float32 [768 3072] 0.01 contiguous 103551 4675 22.14994652
SGDForward float32 [768 3072] 0.01 noncontiguous 102191 9227 11.07521405
SGDForward float32 [768 768] 0.01 contiguous 52304 5191 10.0759006
SGDForward float32 [768 768] 0.01 noncontiguous 52575 9369 5.611591419
SGDForward float32 [8 64] 0 contiguous 6784 4764 1.424013434
SGDForward float32 [8 64] 0 noncontiguous 6336 9191 0.689370036
SGDForward float32 [80 1536] 0 contiguous 24816 5280 4.7
SGDForward float32 [80 1536] 0 noncontiguous 24128 9654 2.499274912
SGDForward float32 [80 512 5] 0 contiguous 24432 4640 5.265517241
SGDForward float32 [80 512 5] 0 noncontiguous 24416 9209 2.651319361
SGDForward float32 [9521 512] 0 contiguous 87679 4765 18.40062959
SGDForward float32 [9521 512] 0 noncontiguous 89471 9920 9.019254032

BFP16
op_name dtype input_size weight_decay contiguous rocm_kernel_avg MIOPEN MIOPEN_over_Rocm
SGDForward bfloat16 [640 320 3 3] 0.0005 contiguous 56320 5156 10.92319628
SGDForward bfloat16 [640 320 3 3] 0.0005 noncontiguous 56016 8907 6.288986191
SGDForward bfloat16 [640 512 3 3] 0.0005 contiguous 81760 5813 14.06502666
SGDForward bfloat16 [640 512 3 3] 0.0005 noncontiguous 80863 9672 8.360525227
SGDForward bfloat16 [640 640 1 1] 0.0005 contiguous 49680 5298 9.377123443
SGDForward bfloat16 [640 640 1 1] 0.0005 noncontiguous 49999 9405 5.316214779
SGDForward bfloat16 [768 3072] 0.01 contiguous 67791 5226 12.97187141
SGDForward bfloat16 [768 3072] 0.01 noncontiguous 68079 9743 6.987478189
SGDForward bfloat16 [768 768] 0.01 contiguous 50671 5245 9.660819828
SGDForward bfloat16 [768 768] 0.01 noncontiguous 50912 9191 5.539331955
SGDForward bfloat16 [8 64] 0 contiguous 5488 5369 1.022164276
SGDForward bfloat16 [8 64] 0 noncontiguous 5776 9316 0.620008587
SGDForward bfloat16 [80 1536] 0 contiguous 24432 5173 4.722984728
SGDForward bfloat16 [80 1536] 0 noncontiguous 25200 9796 2.572478563
SGDForward bfloat16 [80 512 5] 0 contiguous 25440 5547 4.586262845
SGDForward bfloat16 [80 512 5] 0 noncontiguous 25328 9334 2.713520463
SGDForward bfloat16 [9521 512] 0 contiguous 60848 5600 10.86571429
SGDForward bfloat16 [9521 512] 0 noncontiguous 60832 9334 6.517248768
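
The MIOPEN_over_Rocm column in the tables above matches rocm_kernel_avg divided by the MIOPEN time. The sketch below reproduces that arithmetic for two rows copied from the FP16 table. The helper name `speedup` is hypothetical, and the time unit is not stated in the issue; it cancels out in the ratio.

```python
def speedup(rocm_kernel_avg, miopen_avg):
    """Ratio of the ROCm kernel time to the MIOpen time (same unit for both)."""
    if miopen_avg <= 0:
        raise ValueError("miopen_avg must be positive")
    return rocm_kernel_avg / miopen_avg


# (label, rocm_kernel_avg, MIOPEN) taken from the FP16 table
rows = [
    ("float16 [640 320 3 3] contiguous", 51455, 4747),
    ("float16 [8 64] noncontiguous", 5936, 8836),
]

if __name__ == "__main__":
    for label, rocm, miopen in rows:
        print(f"{label}: {speedup(rocm, miopen):.2f}x")
```

A ratio below 1, as in the noncontiguous [8 64] rows, means the MIOpen kernel was slower than the ROCm baseline for that case.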

hieule88, Mar 10 '25 06:03