MIOpen
Implement SGD
- Added SGD forward and backward.
- Added a driver test and a gtest for SGD.
- The new API is guarded by the `MIOPEN_BETA_API` macro.
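The PR text does not spell out the update rule. The sketch below assumes the kernels follow PyTorch `torch.optim.SGD` semantics (weight decay, momentum, dampening, Nesterov), which is consistent with the `weight_decay` column in the benchmarks. It shows a scalar CPU reference of the kind a gtest could compare GPU output against. `sgd_step` and its parameter names are hypothetical and are not MIOpen's API.

```python
from typing import List, Optional, Tuple


def sgd_step(
    param: List[float],
    grad: List[float],
    momentum_buffer: Optional[List[float]],
    lr: float,
    momentum: float = 0.0,
    dampening: float = 0.0,
    weight_decay: float = 0.0,
    nesterov: bool = False,
) -> Tuple[List[float], Optional[List[float]]]:
    """One SGD step (assumed PyTorch-style); returns (new_param, new_buffer)."""
    new_param: List[float] = []
    buf_out: List[float] = []
    for i, (p, g) in enumerate(zip(param, grad)):
        # L2 weight decay is folded into the gradient.
        if weight_decay != 0.0:
            g = g + weight_decay * p
        if momentum != 0.0:
            if momentum_buffer is None:
                # First step: the momentum buffer starts as the gradient.
                b = g
            else:
                b = momentum * momentum_buffer[i] + (1.0 - dampening) * g
            buf_out.append(b)
            # Nesterov looks ahead along the updated buffer.
            g = g + momentum * b if nesterov else b
        new_param.append(p - lr * g)
    # No buffer is produced when momentum is disabled.
    return new_param, (buf_out if momentum != 0.0 else None)
```

Passing `None` as the momentum buffer models the first step, where PyTorch initializes the buffer to the gradient. A kernel following these semantics would need some equivalent "buffer initialized" signal.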
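
In the per-type tables below, `MIOPEN_over_Rocm` equals `rocm_kernel_avg / MIOPEN`, so values above 1 mean the MIOpen kernel is faster.

Average over all cases: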
SGD
| Type | Forward |
|---|---|
| float16 | 5.97 |
| float32 | 7.67 |
| bfloat16 | 5.99 |
FP16
| op_name | dtype | input_size | weight_decay | contiguous | rocm_kernel_avg | MIOPEN | MIOPEN_over_Rocm |
|---|---|---|---|---|---|---|---|
| SGDForward | float16 | [640 320 3 3] | 0.0005 | contiguous | 51455 | 4747 | 10.83947756 |
| SGDForward | float16 | [640 320 3 3] | 0.0005 | noncontiguous | 52047 | 9423 | 5.523400191 |
| SGDForward | float16 | [640 512 3 3] | 0.0005 | contiguous | 81039 | 4711 | 17.20208024 |
| SGDForward | float16 | [640 512 3 3] | 0.0005 | noncontiguous | 80431 | 9192 | 8.75010879 |
| SGDForward | float16 | [640 640 1 1] | 0.0005 | contiguous | 43888 | 4924 | 8.913078798 |
| SGDForward | float16 | [640 640 1 1] | 0.0005 | noncontiguous | 44400 | 9032 | 4.915854739 |
| SGDForward | float16 | [768 3072] | 0.01 | contiguous | 65984 | 4587 | 14.38500109 |
| SGDForward | float16 | [768 3072] | 0.01 | noncontiguous | 65295 | 9032 | 7.229295837 |
| SGDForward | float16 | [768 768] | 0.01 | contiguous | 44496 | 4907 | 9.067862238 |
| SGDForward | float16 | [768 768] | 0.01 | noncontiguous | 45168 | 9458 | 4.77563967 |
| SGDForward | float16 | [8 64] | 0 | contiguous | 6064 | 4675 | 1.297112299 |
| SGDForward | float16 | [8 64] | 0 | noncontiguous | 5936 | 8836 | 0.671797193 |
| SGDForward | float16 | [80 1536] | 0 | contiguous | 22368 | 5049 | 4.430184195 |
| SGDForward | float16 | [80 1536] | 0 | noncontiguous | 22544 | 9440 | 2.388135593 |
| SGDForward | float16 | [80 512 5] | 0 | contiguous | 23072 | 4729 | 4.878832734 |
| SGDForward | float16 | [80 512 5] | 0 | noncontiguous | 22928 | 8907 | 2.574155159 |
| SGDForward | float16 | [9521 512] | 0 | contiguous | 61200 | 4711 | 12.99087243 |
| SGDForward | float16 | [9521 512] | 0 | noncontiguous | 60240 | 9529 | 6.321754644 |
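As a quick check of how `MIOPEN_over_Rocm` relates to the two timing columns, this sketch recomputes it for a few FP16 rows copied from the table above. It assumes both timing columns use the same (unstated) unit, so the ratio is unit-free. `speedup` and `FP16_ROWS` are illustrative names.

```python
from typing import List, Tuple

# (input_size, layout, rocm_kernel_avg, MIOPEN) copied from the FP16 table.
FP16_ROWS: List[Tuple[str, str, int, int]] = [
    ("[640 320 3 3]", "contiguous", 51455, 4747),
    ("[640 320 3 3]", "noncontiguous", 52047, 9423),
    ("[8 64]", "contiguous", 6064, 4675),
    ("[8 64]", "noncontiguous", 5936, 8836),
]


def speedup(rocm_time: float, miopen_time: float) -> float:
    """MIOPEN_over_Rocm = rocm_kernel_avg / MIOPEN (higher means MIOpen is faster)."""
    return rocm_time / miopen_time


if __name__ == "__main__":
    for size, layout, rocm, miopen in FP16_ROWS:
        print(f"{size:>14} {layout:>13} {speedup(rocm, miopen):.4f}")
```

Values below 1, as in the `[8 64]` noncontiguous row, mean the ROCm baseline was faster for that case.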
FP32
| op_name | dtype | input_size | weight_decay | contiguous | rocm_kernel_avg | MIOPEN | MIOPEN_over_Rocm |
|---|---|---|---|---|---|---|---|
| SGDForward | float32 | [640 320 3 3] | 0.0005 | contiguous | 74768 | 4729 | 15.81053077 |
| SGDForward | float32 | [640 320 3 3] | 0.0005 | noncontiguous | 75199 | 9227 | 8.149886204 |
| SGDForward | float32 | [640 512 3 3] | 0.0005 | contiguous | 121951 | 4675 | 26.0857754 |
| SGDForward | float32 | [640 512 3 3] | 0.0005 | noncontiguous | 122479 | 9316 | 13.14716617 |
| SGDForward | float32 | [640 640 1 1] | 0.0005 | contiguous | 48464 | 4835 | 10.02357808 |
| SGDForward | float32 | [640 640 1 1] | 0.0005 | noncontiguous | 48544 | 9725 | 4.991670951 |
| SGDForward | float32 | [768 3072] | 0.01 | contiguous | 103551 | 4675 | 22.14994652 |
| SGDForward | float32 | [768 3072] | 0.01 | noncontiguous | 102191 | 9227 | 11.07521405 |
| SGDForward | float32 | [768 768] | 0.01 | contiguous | 52304 | 5191 | 10.0759006 |
| SGDForward | float32 | [768 768] | 0.01 | noncontiguous | 52575 | 9369 | 5.611591419 |
| SGDForward | float32 | [8 64] | 0 | contiguous | 6784 | 4764 | 1.424013434 |
| SGDForward | float32 | [8 64] | 0 | noncontiguous | 6336 | 9191 | 0.689370036 |
| SGDForward | float32 | [80 1536] | 0 | contiguous | 24816 | 5280 | 4.7 |
| SGDForward | float32 | [80 1536] | 0 | noncontiguous | 24128 | 9654 | 2.499274912 |
| SGDForward | float32 | [80 512 5] | 0 | contiguous | 24432 | 4640 | 5.265517241 |
| SGDForward | float32 | [80 512 5] | 0 | noncontiguous | 24416 | 9209 | 2.651319361 |
| SGDForward | float32 | [9521 512] | 0 | contiguous | 87679 | 4765 | 18.40062959 |
| SGDForward | float32 | [9521 512] | 0 | noncontiguous | 89471 | 9920 | 9.019254032 |
BFP16
| op_name | dtype | input_size | weight_decay | contiguous | rocm_kernel_avg | MIOPEN | MIOPEN_over_Rocm |
|---|---|---|---|---|---|---|---|
| SGDForward | bfloat16 | [640 320 3 3] | 0.0005 | contiguous | 56320 | 5156 | 10.92319628 |
| SGDForward | bfloat16 | [640 320 3 3] | 0.0005 | noncontiguous | 56016 | 8907 | 6.288986191 |
| SGDForward | bfloat16 | [640 512 3 3] | 0.0005 | contiguous | 81760 | 5813 | 14.06502666 |
| SGDForward | bfloat16 | [640 512 3 3] | 0.0005 | noncontiguous | 80863 | 9672 | 8.360525227 |
| SGDForward | bfloat16 | [640 640 1 1] | 0.0005 | contiguous | 49680 | 5298 | 9.377123443 |
| SGDForward | bfloat16 | [640 640 1 1] | 0.0005 | noncontiguous | 49999 | 9405 | 5.316214779 |
| SGDForward | bfloat16 | [768 3072] | 0.01 | contiguous | 67791 | 5226 | 12.97187141 |
| SGDForward | bfloat16 | [768 3072] | 0.01 | noncontiguous | 68079 | 9743 | 6.987478189 |
| SGDForward | bfloat16 | [768 768] | 0.01 | contiguous | 50671 | 5245 | 9.660819828 |
| SGDForward | bfloat16 | [768 768] | 0.01 | noncontiguous | 50912 | 9191 | 5.539331955 |
| SGDForward | bfloat16 | [8 64] | 0 | contiguous | 5488 | 5369 | 1.022164276 |
| SGDForward | bfloat16 | [8 64] | 0 | noncontiguous | 5776 | 9316 | 0.620008587 |
| SGDForward | bfloat16 | [80 1536] | 0 | contiguous | 24432 | 5173 | 4.722984728 |
| SGDForward | bfloat16 | [80 1536] | 0 | noncontiguous | 25200 | 9796 | 2.572478563 |
| SGDForward | bfloat16 | [80 512 5] | 0 | contiguous | 25440 | 5547 | 4.586262845 |
| SGDForward | bfloat16 | [80 512 5] | 0 | noncontiguous | 25328 | 9334 | 2.713520463 |
| SGDForward | bfloat16 | [9521 512] | 0 | contiguous | 60848 | 5600 | 10.86571429 |
| SGDForward | bfloat16 | [9521 512] | 0 | noncontiguous | 60832 | 9334 | 6.517248768 |
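In all three tables, the `MIOPEN` time for a noncontiguous row is about 1.7 to 2.1 times that of the matching contiguous row. The benchmark does not say how its noncontiguous inputs were laid out. Assuming they were strided views (for example, a transposed view), each logical element must be mapped through the tensor's strides to find its memory offset instead of using the flat index directly. The sketch below shows that mapping. All helper names are hypothetical and are not MIOpen functions.

```python
from typing import List, Sequence


def contiguous_strides(shape: Sequence[int]) -> List[int]:
    """Row-major (C-order) strides, in elements, for a packed tensor."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return strides


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True if strides match a packed row-major layout (size-1 dims ignored)."""
    return all(
        size == 1 or st == cs
        for size, st, cs in zip(shape, strides, contiguous_strides(shape))
    )


def unravel(flat: int, shape: Sequence[int]) -> List[int]:
    """Convert a flat logical index into a multi-dimensional index."""
    idx = [0] * len(shape)
    for d in range(len(shape) - 1, -1, -1):
        idx[d] = flat % shape[d]
        flat //= shape[d]
    return idx


def element_offset(index: Sequence[int], strides: Sequence[int]) -> int:
    """Memory offset (in elements) of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))
```

For example, an `[8 64]` tensor that is a transposed view of a packed `[64 8]` buffer has strides `[1, 8]`. Logical element 1 then sits at memory offset 8 rather than 1.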