
Implement MatrixDiag, MatrixSetDiag, MatrixDiagPart

Open · long10024070 opened this issue 10 months ago

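  • Added forward and backward implementations of MatrixDiag, MatrixSetDiag, and MatrixDiagPart.
  • Added driver tests and gtests for both directions.
  • The new APIs are guarded by the MIOPEN_BETA_API macro.
  • Comparison with ROCm PyTorch (Improvement = ROCm PyTorch time / MIOpen HIP time, so values above 1 mean MIOpen is faster):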
MatrixDiag float16

| Op_name | dtype | input_size | output_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagV3 | float16 | [512 2] | [512 2 2] | fwd | 26960 | 4989 | 5.40 |
| MatrixDiagV3 | float16 | [1024 4] | [1024 4 4] | fwd | 26080 | 6217 | 4.19 |
| MatrixDiagV3 | float16 | [256 8] | [256 8 8] | fwd | 23296 | 5246 | 4.44 |
| MatrixDiagV3 | float16 | [16 16] | [16 16 16] | fwd | 26848 | 4465 | 6.01 |
| MatrixDiagV3 | float16 | [32 32] | [32 32 32] | fwd | 30080 | 4656 | 6.46 |
| MatrixDiagV3 | float16 | [128 64] | [128 64 64] | fwd | 32320 | 12218 | 2.65 |
| MatrixDiagV3 | float16 | [32 128] | [32 128 128] | fwd | 25520 | 11127 | 2.29 |
| MatrixDiagV3 | float16 | [8 256] | [8 256 256] | fwd | 25200 | 10278 | 2.45 |
| MatrixDiagV3 | float16 | [2 512] | [2 512 512] | fwd | 25200 | 9275 | 2.72 |
| MatrixDiagV3 | float16 | [4 512] | [4 512 512] | fwd | 26976 | 15757 | 1.71 |
| MatrixDiagV3 | float16 | [512 2] | [512 2 2] | bwd | 8288 | 4809 | 1.72 |
| MatrixDiagV3 | float16 | [1024 4] | [1024 4 4] | bwd | 8240 | 5556 | 1.48 |
| MatrixDiagV3 | float16 | [1024 8] | [1024 8 8] | bwd | 8560 | 5599 | 1.53 |
| MatrixDiagV3 | float16 | [512 16] | [512 16 16] | bwd | 8752 | 5730 | 1.53 |
| MatrixDiagV3 | float16 | [1024 32] | [1024 32 32] | bwd | 9872 | 7508 | 1.31 |
| MatrixDiagV3 | float16 | [256 64] | [256 64 64] | bwd | 9776 | 6348 | 1.54 |
| MatrixDiagV3 | float16 | [1024 128] | [1024 128 128] | bwd | 23088 | 14896 | 1.55 |
| MatrixDiagV3 | float16 | [1024 256] | [1024 256 256] | bwd | 45456 | 25973 | 1.75 |
| MatrixDiagV3 | float16 | [8 512] | [8 512 512] | bwd | 9488 | 5759 | 1.65 |
| MatrixDiagV3 | float16 | [8 1024] | [8 1024 1024] | bwd | 9856 | 6387 | 1.54 |

MatrixDiag float32

| Op_name | dtype | input_size | output_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagV3 | float32 | [512 2] | [512 2 2] | fwd | 36592 | 5950 | 6.15 |
| MatrixDiagV3 | float32 | [1024 4] | [1024 4 4] | fwd | 24928 | 6275 | 3.97 |
| MatrixDiagV3 | float32 | [128 8] | [128 8 8] | fwd | 24208 | 5166 | 4.69 |
| MatrixDiagV3 | float32 | [32 16] | [32 16 16] | fwd | 24096 | 4741 | 5.08 |
| MatrixDiagV3 | float32 | [256 16] | [256 16 16] | fwd | 29248 | 6244 | 4.68 |
| MatrixDiagV3 | float32 | [16 32] | [16 32 32] | fwd | 23440 | 4549 | 5.15 |
| MatrixDiagV3 | float32 | [128 64] | [128 64 64] | fwd | 27183 | 11893 | 2.29 |
| MatrixDiagV3 | float32 | [64 128] | [64 128 128] | fwd | 28624 | 18249 | 1.57 |
| MatrixDiagV3 | float32 | [2 512] | [2 512 512] | fwd | 26704 | 9139 | 2.92 |
| MatrixDiagV3 | float32 | [4 512] | [4 512 512] | fwd | 28272 | 15775 | 1.79 |
| MatrixDiagV3 | float32 | [1024 2] | [1024 2 2] | bwd | 8400 | 5632 | 1.49 |
| MatrixDiagV3 | float32 | [1024 4] | [1024 4 4] | bwd | 8464 | 5908 | 1.43 |
| MatrixDiagV3 | float32 | [1024 8] | [1024 8 8] | bwd | 8784 | 5803 | 1.51 |
| MatrixDiagV3 | float32 | [512 16] | [512 16 16] | bwd | 9168 | 5989 | 1.53 |
| MatrixDiagV3 | float32 | [64 32] | [64 32 32] | bwd | 9072 | 5693 | 1.59 |
| MatrixDiagV3 | float32 | [128 64] | [128 64 64] | bwd | 9696 | 5932 | 1.63 |
| MatrixDiagV3 | float32 | [1024 128] | [1024 128 128] | bwd | 21648 | 14860 | 1.46 |
| MatrixDiagV3 | float32 | [64 256] | [64 256 256] | bwd | 9808 | 6556 | 1.50 |
| MatrixDiagV3 | float32 | [32 512] | [32 512 512] | bwd | 9904 | 7350 | 1.35 |
| MatrixDiagV3 | float32 | [16 1024] | [16 1024 1024] | bwd | 10576 | 7815 | 1.35 |

MatrixDiag bfloat16

| Op_name | dtype | input_size | output_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagV3 | bfloat16 | [1024 2] | [1024 2 2] | fwd | 23968 | 5591 | 4.29 |
| MatrixDiagV3 | bfloat16 | [1024 4] | [1024 4 4] | fwd | 23536 | 6083 | 3.87 |
| MatrixDiagV3 | bfloat16 | [512 8] | [512 8 8] | fwd | 23376 | 6075 | 3.85 |
| MatrixDiagV3 | bfloat16 | [64 16] | [64 16 16] | fwd | 48000 | 4694 | 10.23 |
| MatrixDiagV3 | bfloat16 | [8 32] | [8 32 32] | fwd | 28000 | 4261 | 6.57 |
| MatrixDiagV3 | bfloat16 | [256 64] | [256 64 64] | fwd | 30000 | 20371 | 1.47 |
| MatrixDiagV3 | bfloat16 | [32 128] | [32 128 128] | fwd | 33664 | 11036 | 3.05 |
| MatrixDiagV3 | bfloat16 | [16 256] | [16 256 256] | fwd | 45136 | 17489 | 2.58 |
| MatrixDiagV3 | bfloat16 | [32 256] | [32 256 256] | fwd | 39952 | 31324 | 1.28 |
| MatrixDiagV3 | bfloat16 | [2 512] | [2 512 512] | fwd | 29712 | 9354 | 3.18 |
| MatrixDiagV3 | bfloat16 | [512 2] | [512 2 2] | bwd | 7824 | 4905 | 1.60 |
| MatrixDiagV3 | bfloat16 | [1024 4] | [1024 4 4] | bwd | 7952 | 5675 | 1.40 |
| MatrixDiagV3 | bfloat16 | [1024 8] | [1024 8 8] | bwd | 8400 | 5643 | 1.49 |
| MatrixDiagV3 | bfloat16 | [256 16] | [256 16 16] | bwd | 8400 | 5738 | 1.46 |
| MatrixDiagV3 | bfloat16 | [256 32] | [256 32 32] | bwd | 8768 | 5999 | 1.46 |
| MatrixDiagV3 | bfloat16 | [512 64] | [512 64 64] | bwd | 10400 | 7888 | 1.32 |
| MatrixDiagV3 | bfloat16 | [1024 128] | [1024 128 128] | bwd | 22512 | 14873 | 1.51 |
| MatrixDiagV3 | bfloat16 | [1024 256] | [1024 256 256] | bwd | 45648 | 25898 | 1.76 |
| MatrixDiagV3 | bfloat16 | [16 512] | [16 512 512] | bwd | 9536 | 6073 | 1.57 |
| MatrixDiagV3 | bfloat16 | [16 1024] | [16 1024 1024] | bwd | 10192 | 7576 | 1.35 |

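The MatrixDiag tables map an input of shape [B N] to an output of shape [B N N]. As a hedged illustration of what the forward and backward passes compute for these shapes, here is a pure-Python reference sketch on nested lists. It assumes the main diagonal only (offset 0, zero fill, square output), which matches the benchmarked shapes but not necessarily every option a MatrixDiagV3-style op supports; `matrix_diag_fwd` and `matrix_diag_bwd` are hypothetical names, not MIOpen API calls.

```python
# Hypothetical reference sketch of MatrixDiag (main diagonal only, offset 0),
# on plain nested lists: forward [B, N] -> [B, N, N], backward [B, N, N] -> [B, N].

def matrix_diag_fwd(x):
    """Place each row of x on the main diagonal of an N x N zero matrix."""
    out = []
    for row in x:  # one batch element at a time
        n = len(row)
        mat = [[0.0] * n for _ in range(n)]
        for i, v in enumerate(row):
            mat[i][i] = v  # only diagonal entries are written
        out.append(mat)
    return out


def matrix_diag_bwd(grad_out):
    """Gradient w.r.t. x is the main diagonal of grad_out."""
    return [[mat[i][i] for i in range(len(mat))] for mat in grad_out]


x = [[1.0, 2.0], [3.0, 4.0]]
y = matrix_diag_fwd(x)
print(y)                   # [[[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 4.0]]]
print(matrix_diag_bwd(y))  # [[1.0, 2.0], [3.0, 4.0]]
```

The backward pass reads only the diagonal of the incoming gradient; off-diagonal gradient values do not contribute.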
MatrixSetDiag float16

| Op_name | dtype | input_size | diag_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixSetDiagV3 | float16 | [256 2 2] | [256 2] | fwd | 48368 | 6112 | 7.91 |
| MatrixSetDiagV3 | float16 | [1024 4 4] | [1024 4] | fwd | 45280 | 6025 | 7.52 |
| MatrixSetDiagV3 | float16 | [128 8 8] | [128 8] | fwd | 51632 | 5700 | 9.06 |
| MatrixSetDiagV3 | float16 | [512 16 16] | [512 16] | fwd | 44320 | 6419 | 6.90 |
| MatrixSetDiagV3 | float16 | [128 32 32] | [128 32] | fwd | 36944 | 6490 | 5.69 |
| MatrixSetDiagV3 | float16 | [16 64 64] | [16 64] | fwd | 37104 | 6235 | 5.95 |
| MatrixSetDiagV3 | float16 | [4 128 128] | [4 128] | fwd | 35823 | 6054 | 5.92 |
| MatrixSetDiagV3 | float16 | [4 256 256] | [4 256] | fwd | 27824 | 9779 | 2.85 |
| MatrixSetDiagV3 | float16 | [8 256 256] | [8 256] | fwd | 29552 | 11979 | 2.47 |
| MatrixSetDiagV3 | float16 | [2 512 512] | [2 512] | fwd | 32800 | 11326 | 2.90 |
| MatrixSetDiagV3 | float16 | [128 2 2] | [128 2] | bwd | 190734 | 20290 | 9.40 |
| MatrixSetDiagV3 | float16 | [1024 4 4] | [1024 4] | bwd | 203006 | 23807 | 8.53 |
| MatrixSetDiagV3 | float16 | [1024 8 8] | [1024 8] | bwd | 153311 | 18052 | 8.49 |
| MatrixSetDiagV3 | float16 | [512 16 16] | [512 16] | bwd | 153647 | 20780 | 7.39 |
| MatrixSetDiagV3 | float16 | [512 32 32] | [512 32] | bwd | 120255 | 25044 | 4.80 |
| MatrixSetDiagV3 | float16 | [64 64 64] | [64 64] | bwd | 135743 | 21401 | 6.34 |
| MatrixSetDiagV3 | float16 | [64 128 128] | [64 128] | bwd | 134079 | 33490 | 4.00 |
| MatrixSetDiagV3 | float16 | [16 256 256] | [16 256] | bwd | 154383 | 32471 | 4.75 |
| MatrixSetDiagV3 | float16 | [8 512 512] | [8 512] | bwd | 136111 | 43920 | 3.10 |
| MatrixSetDiagV3 | float16 | [4 1024 1024] | [4 1024] | bwd | 156191 | 68072 | 2.29 |

MatrixSetDiag float32

| Op_name | dtype | input_size | diag_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixSetDiagV3 | float32 | [1024 2 2] | [1024 2] | fwd | 28512 | 5879 | 4.85 |
| MatrixSetDiagV3 | float32 | [512 4 4] | [512 4] | fwd | 29104 | 5763 | 5.05 |
| MatrixSetDiagV3 | float32 | [1024 8 8] | [1024 8] | fwd | 29712 | 6067 | 4.90 |
| MatrixSetDiagV3 | float32 | [512 16 16] | [512 16] | fwd | 30064 | 6170 | 4.87 |
| MatrixSetDiagV3 | float32 | [128 32 32] | [128 32] | fwd | 29216 | 6312 | 4.63 |
| MatrixSetDiagV3 | float32 | [32 64 64] | [32 64] | fwd | 33680 | 6155 | 5.47 |
| MatrixSetDiagV3 | float32 | [8 128 128] | [8 128] | fwd | 28800 | 6094 | 4.73 |
| MatrixSetDiagV3 | float32 | [2 256 256] | [2 256] | fwd | 28432 | 5976 | 4.76 |
| MatrixSetDiagV3 | float32 | [1024 2 2] | [1024 2] | bwd | 162623 | 21037 | 7.73 |
| MatrixSetDiagV3 | float32 | [64 4 4] | [64 4] | bwd | 141951 | 17181 | 8.26 |
| MatrixSetDiagV3 | float32 | [512 8 8] | [512 8] | bwd | 136095 | 18094 | 7.52 |
| MatrixSetDiagV3 | float32 | [1024 16 16] | [1024 16] | bwd | 134687 | 22277 | 6.05 |
| MatrixSetDiagV3 | float32 | [512 32 32] | [512 32] | bwd | 134271 | 25808 | 5.20 |
| MatrixSetDiagV3 | float32 | [128 64 64] | [128 64] | bwd | 172558 | 29060 | 5.94 |
| MatrixSetDiagV3 | float32 | [16 128 128] | [16 128] | bwd | 151167 | 22422 | 6.74 |
| MatrixSetDiagV3 | float32 | [2 256 256] | [2 256] | bwd | 142543 | 22317 | 6.39 |
| MatrixSetDiagV3 | float32 | [8 512 512] | [8 512] | bwd | 144399 | 45533 | 3.17 |
| MatrixSetDiagV3 | float32 | [8 1024 1024] | [8 1024] | bwd | 267102 | 127155 | 2.10 |

MatrixSetDiag bfloat16

| Op_name | dtype | input_size | diag_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixSetDiagV3 | bfloat16 | [1024 2 2] | [1024 2] | fwd | 28976 | 5785 | 5.01 |
| MatrixSetDiagV3 | bfloat16 | [1024 4 4] | [1024 4] | fwd | 30384 | 5787 | 5.25 |
| MatrixSetDiagV3 | bfloat16 | [512 8 8] | [512 8] | fwd | 30128 | 5664 | 5.32 |
| MatrixSetDiagV3 | bfloat16 | [512 16 16] | [512 16] | fwd | 29168 | 6079 | 4.80 |
| MatrixSetDiagV3 | bfloat16 | [256 32 32] | [256 32] | fwd | 30320 | 9145 | 3.32 |
| MatrixSetDiagV3 | bfloat16 | [64 64 64] | [64 64] | fwd | 29744 | 9027 | 3.30 |
| MatrixSetDiagV3 | bfloat16 | [32 128 128] | [32 128] | fwd | 31056 | 11885 | 2.61 |
| MatrixSetDiagV3 | bfloat16 | [4 256 256] | [4 256] | fwd | 27922 | 8523 | 3.28 |
| MatrixSetDiagV3 | bfloat16 | [8 256 256] | [8 256] | fwd | 29856 | 11547 | 2.59 |
| MatrixSetDiagV3 | bfloat16 | [2 512 512] | [2 512] | fwd | 29296 | 10915 | 2.68 |
| MatrixSetDiagV3 | bfloat16 | [1024 2 2] | [1024 2] | bwd | 110719 | 19820 | 5.59 |
| MatrixSetDiagV3 | bfloat16 | [512 4 4] | [512 4] | bwd | 136687 | 18999 | 7.19 |
| MatrixSetDiagV3 | bfloat16 | [1024 8 8] | [1024 8] | bwd | 107055 | 17357 | 6.17 |
| MatrixSetDiagV3 | bfloat16 | [1024 16 16] | [1024 16] | bwd | 96287 | 23717 | 4.06 |
| MatrixSetDiagV3 | bfloat16 | [128 32 32] | [128 32] | bwd | 102703 | 18238 | 5.63 |
| MatrixSetDiagV3 | bfloat16 | [32 64 64] | [32 64] | bwd | 105007 | 17220 | 6.10 |
| MatrixSetDiagV3 | bfloat16 | [64 128 128] | [64 128] | bwd | 117007 | 33493 | 3.49 |
| MatrixSetDiagV3 | bfloat16 | [32 256 256] | [32 256] | bwd | 112991 | 48159 | 2.35 |
| MatrixSetDiagV3 | bfloat16 | [2 512 512] | [2 512] | bwd | 123359 | 27704 | 4.45 |
| MatrixSetDiagV3 | bfloat16 | [4 1024 1024] | [4 1024] | bwd | 149454 | 90338 | 1.65 |

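For MatrixSetDiag the tables pair an input of shape [B N N] with a diagonal of shape [B N]. A hedged pure-Python reference sketch of both directions follows, again assuming square matrices and the main diagonal only; `matrix_set_diag_fwd` and `matrix_set_diag_bwd` are hypothetical names, and the backward returns gradients for both inputs.

```python
import copy

# Hypothetical reference sketch of MatrixSetDiag (square matrices, main diagonal).


def matrix_set_diag_fwd(x, diag):
    """Copy x ([B, N, N]) and overwrite its main diagonal with diag ([B, N])."""
    out = copy.deepcopy(x)  # leave the caller's input untouched
    for mat, d in zip(out, diag):
        for i, v in enumerate(d):
            mat[i][i] = v
    return out


def matrix_set_diag_bwd(grad_out):
    """Return (grad_x, grad_diag) for an incoming gradient of shape [B, N, N].

    grad_x is grad_out with the diagonal zeroed, because those entries of x
    were overwritten in the forward pass; grad_diag is the diagonal of grad_out.
    """
    grad_x = copy.deepcopy(grad_out)
    grad_diag = []
    for mat in grad_x:
        grad_diag.append([mat[i][i] for i in range(len(mat))])  # read before zeroing
        for i in range(len(mat)):
            mat[i][i] = 0.0
    return grad_x, grad_diag


x = [[[1.0, 2.0], [3.0, 4.0]]]
print(matrix_set_diag_fwd(x, [[9.0, 8.0]]))  # [[[9.0, 2.0], [3.0, 8.0]]]
print(x)                                     # [[[1.0, 2.0], [3.0, 4.0]]] (unchanged)
print(matrix_set_diag_bwd([[[1.0, 1.0], [1.0, 1.0]]]))
# ([[[0.0, 1.0], [1.0, 0.0]]], [[1.0, 1.0]])
```

Unlike MatrixDiag, the backward pass here produces two outputs, a full [B N N] gradient for the input plus a [B N] gradient for the diagonal.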
MatrixDiagPart float16

| Op_name | dtype | input_size | output_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagPartV2 | float16 | [1024 4 4] | [1024 4] | fwd | 8128 | 6361 | 1.28 |
| MatrixDiagPartV2 | float16 | [1024 8 8] | [1024 8] | fwd | 8112 | 6214 | 1.31 |
| MatrixDiagPartV2 | float16 | [256 16 16] | [256 16] | fwd | 27148 | 6175 | 4.40 |
| MatrixDiagPartV2 | float16 | [1024 16 16] | [1024 16] | fwd | 16841 | 6445 | 2.61 |
| MatrixDiagPartV2 | float16 | [1024 32 32] | [1024 32] | fwd | 27019 | 8140 | 3.32 |
| MatrixDiagPartV2 | float16 | [512 64 64] | [512 64] | fwd | 51552 | 8591 | 6.00 |
| MatrixDiagPartV2 | float16 | [64 128 128] | [64 128] | fwd | 16782 | 6627 | 2.53 |
| MatrixDiagPartV2 | float16 | [4 512 512] | [4 512] | fwd | 9040 | 6236 | 1.45 |
| MatrixDiagPartV2 | float16 | [128 512 512] | [128 512] | fwd | 38311 | 12350 | 3.10 |
| MatrixDiagPartV2 | float16 | [16 1024 1024] | [16 1024] | fwd | 10544 | 8147 | 1.29 |
| MatrixDiagPartV2 | float16 | [512 2 2] | [512 2] | bwd | 26208 | 4475 | 5.86 |
| MatrixDiagPartV2 | float16 | [128 4 4] | [128 4] | bwd | 23936 | 4240 | 5.65 |
| MatrixDiagPartV2 | float16 | [128 8 8] | [128 8] | bwd | 24560 | 4164 | 5.90 |
| MatrixDiagPartV2 | float16 | [256 16 16] | [256 16] | bwd | 36688 | 5732 | 6.40 |
| MatrixDiagPartV2 | float16 | [16 32 32] | [16 32] | bwd | 39408 | 3751 | 10.51 |
| MatrixDiagPartV2 | float16 | [128 64 64] | [128 64] | bwd | 34624 | 12522 | 2.77 |
| MatrixDiagPartV2 | float16 | [64 128 128] | [64 128] | bwd | 32591 | 17250 | 1.89 |
| MatrixDiagPartV2 | float16 | [16 256 256] | [16 256] | bwd | 32800 | 15927 | 2.06 |
| MatrixDiagPartV2 | float16 | [2 512 512] | [2 512] | bwd | 31984 | 8319 | 3.84 |
| MatrixDiagPartV2 | float16 | [2 1024 1024] | [2 1024] | bwd | 33664 | 23881 | 1.41 |

MatrixDiagPart float32

| Op_name | dtype | input_size | output_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagPartV2 | float32 | [1024 2 2] | [1024 2] | fwd | 8912 | 6238 | 1.43 |
| MatrixDiagPartV2 | float32 | [1024 4 4] | [1024 4] | fwd | 8096 | 6062 | 1.34 |
| MatrixDiagPartV2 | float32 | [256 8 8] | [256 8] | fwd | 8240 | 6060 | 1.36 |
| MatrixDiagPartV2 | float32 | [512 16 16] | [512 16] | fwd | 8944 | 6555 | 1.36 |
| MatrixDiagPartV2 | float32 | [512 32 32] | [512 32] | fwd | 8928 | 6985 | 1.28 |
| MatrixDiagPartV2 | float32 | [128 64 64] | [128 64] | fwd | 8896 | 6553 | 1.36 |
| MatrixDiagPartV2 | float32 | [4 128 128] | [4 128] | fwd | 12064 | 8001 | 1.51 |
| MatrixDiagPartV2 | float32 | [8 256 256] | [8 256] | fwd | 8608 | 5336 | 1.61 |
| MatrixDiagPartV2 | float32 | [4 512 512] | [4 512] | fwd | 9424 | 6301 | 1.50 |
| MatrixDiagPartV2 | float32 | [8 1024 1024] | [8 1024] | fwd | 10272 | 7243 | 1.42 |
| MatrixDiagPartV2 | float32 | [1024 2 2] | [1024 2] | bwd | 24848 | 5968 | 4.16 |
| MatrixDiagPartV2 | float32 | [128 4 4] | [128 4] | bwd | 23120 | 4525 | 5.11 |
| MatrixDiagPartV2 | float32 | [2 8 8] | [2 8] | bwd | 21504 | 6123 | 3.51 |
| MatrixDiagPartV2 | float32 | [4 8 8] | [4 8] | bwd | 20320 | 6366 | 3.19 |
| MatrixDiagPartV2 | float32 | [64 8 8] | [64 8] | bwd | 23360 | 4193 | 5.57 |
| MatrixDiagPartV2 | float32 | [32 16 16] | [32 16] | bwd | 23008 | 4127 | 5.57 |
| MatrixDiagPartV2 | float32 | [2 32 32] | [2 32] | bwd | 26560 | 4008 | 6.63 |
| MatrixDiagPartV2 | float32 | [512 64 64] | [512 64] | bwd | 57055 | 34863 | 1.64 |
| MatrixDiagPartV2 | float32 | [2 128 128] | [2 128] | bwd | 52672 | 3785 | 13.92 |
| MatrixDiagPartV2 | float32 | [32 256 256] | [32 256] | bwd | 64000 | 28904 | 2.21 |
| MatrixDiagPartV2 | float32 | [4 1024 1024] | [4 1024] | bwd | 59456 | 48589 | 1.22 |

MatrixDiagPart bfloat16

| Op_name | dtype | input_size | output_size | direction | ROCm PyTorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagPartV2 | bfloat16 | [8 256 256] | [8 256] | fwd | 6960 | 5627 | 1.24 |
| MatrixDiagPartV2 | bfloat16 | [16 256 256] | [16 256] | fwd | 6832 | 5663 | 1.21 |
| MatrixDiagPartV2 | bfloat16 | [2 512 512] | [2 512] | fwd | 8704 | 6964 | 1.25 |
| MatrixDiagPartV2 | bfloat16 | [4 512 512] | [4 512] | fwd | 8976 | 7006 | 1.28 |
| MatrixDiagPartV2 | bfloat16 | [8 512 512] | [8 512] | fwd | 8096 | 6513 | 1.24 |
| MatrixDiagPartV2 | bfloat16 | [16 512 512] | [16 512] | fwd | 8816 | 6865 | 1.28 |
| MatrixDiagPartV2 | bfloat16 | [2 1024 1024] | [2 1024] | fwd | 8944 | 6433 | 1.39 |
| MatrixDiagPartV2 | bfloat16 | [4 1024 1024] | [4 1024] | fwd | 8960 | 6608 | 1.36 |
| MatrixDiagPartV2 | bfloat16 | [8 1024 1024] | [8 1024] | fwd | 9136 | 7138 | 1.28 |
| MatrixDiagPartV2 | bfloat16 | [512 4 4] | [512 4] | bwd | 23776 | 4601 | 5.17 |
| MatrixDiagPartV2 | bfloat16 | [1024 4 4] | [1024 4] | bwd | 28128 | 5604 | 5.02 |
| MatrixDiagPartV2 | bfloat16 | [512 8 8] | [512 8] | bwd | 23616 | 5659 | 4.17 |
| MatrixDiagPartV2 | bfloat16 | [1024 8 8] | [1024 8] | bwd | 27792 | 5688 | 4.89 |
| MatrixDiagPartV2 | bfloat16 | [1024 32 32] | [1024 32] | bwd | 27360 | 19128 | 1.43 |
| MatrixDiagPartV2 | bfloat16 | [16 128 128] | [16 128] | bwd | 24512 | 6574 | 3.73 |
| MatrixDiagPartV2 | bfloat16 | [32 128 128] | [32 128] | bwd | 26352 | 10118 | 2.60 |
| MatrixDiagPartV2 | bfloat16 | [8 256 256] | [8 256] | bwd | 33040 | 9427 | 3.50 |
| MatrixDiagPartV2 | bfloat16 | [4 512 512] | [4 512] | bwd | 27216 | 14506 | 1.88 |
| MatrixDiagPartV2 | bfloat16 | [2 1024 1024] | [2 1024] | bwd | 37472 | 24024 | 1.56 |

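MatrixDiagPart goes the other way, from [B N N] to [B N]. A hedged pure-Python reference sketch under the same assumptions (square matrices, main diagonal only); `matrix_diag_part_fwd` and `matrix_diag_part_bwd` are hypothetical names, not MIOpen API calls.

```python
# Hypothetical reference sketch of MatrixDiagPart (square matrices, main diagonal).


def matrix_diag_part_fwd(x):
    """Extract the main diagonal of each matrix: [B, N, N] -> [B, N]."""
    return [[mat[i][i] for i in range(len(mat))] for mat in x]


def matrix_diag_part_bwd(grad_out):
    """Scatter grad_out ([B, N]) onto the diagonal of zero matrices: -> [B, N, N]."""
    grad_x = []
    for row in grad_out:
        n = len(row)
        mat = [[0.0] * n for _ in range(n)]  # off-diagonal inputs get zero gradient
        for i, g in enumerate(row):
            mat[i][i] = g
        grad_x.append(mat)
    return grad_x


x = [[[1.0, 2.0], [3.0, 4.0]]]
print(matrix_diag_part_fwd(x))              # [[1.0, 4.0]]
print(matrix_diag_part_bwd([[1.0, 4.0]]))   # [[[1.0, 0.0], [0.0, 4.0]]]
```

The backward pass here is the same computation as a MatrixDiag forward pass, and the forward pass matches a MatrixDiag backward pass, which is why the input and output shapes in these tables are swapped relative to the MatrixDiag tables.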
  • Average improvement over all cases:

| Op_name | dtype | Average improvement |
|---|---|---:|
| MatrixDiag | float16 | 3.02 |
| MatrixDiag | float32 | 2.78 |
| MatrixDiag | bfloat16 | 2.97 |
| MatrixSetDiag | float16 | 5.33 |
| MatrixSetDiag | float32 | 4.82 |
| MatrixSetDiag | bfloat16 | 4.24 |
| MatrixDiagPart | float16 | 3.65 |
| MatrixDiagPart | float32 | 3.33 |
| MatrixDiagPart | bfloat16 | 3.23 |
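The Improvement column in the tables above is consistent with the ratio ROCm PyTorch / MIOpen HIP rounded to two decimals (for example 26960 / 4989 ≈ 5.40). The short check below recomputes it for three rows copied from those tables; the `improvement` helper is illustrative only and not part of MIOpen.

```python
# Rows copied from the benchmark tables:
# (op, dtype, direction, rocm_pytorch, miopen_hip, reported_improvement)
rows = [
    ("MatrixDiagV3", "float16", "fwd", 26960, 4989, 5.40),
    ("MatrixSetDiagV3", "float16", "bwd", 190734, 20290, 9.40),
    ("MatrixDiagPartV2", "float32", "bwd", 52672, 3785, 13.92),
]


def improvement(rocm_pytorch, miopen_hip):
    """Speedup of MIOpen over ROCm PyTorch; values above 1 mean MIOpen is faster."""
    return round(rocm_pytorch / miopen_hip, 2)


for op, dtype, direction, pt, mi, reported in rows:
    # Recompute the ratio and compare it with the value reported in the table.
    assert improvement(pt, mi) == reported, (op, dtype, direction)
    print(f"{op} {dtype} {direction}: {improvement(pt, mi):.2f}")
```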

long10024070 · Feb 13 '25 07:02