Implement MatrixDiag, MatrixSetDiag, MatrixDiagPart
Open
long10024070 opened this issue 10 months ago
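- Added forward and backward implementations of MatrixDiag, MatrixSetDiag, and MatrixDiagPart.
- Added driver tests and gtests for both directions.
- The new APIs are guarded by the MIOPEN_BETA_API macro.
- Comparison with ROCm PyTorch (Improvement is the ROCm pytorch value divided by the MIOpen HIP value):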

MatrixDiag float16

| Op_name | dtype | input_size | output_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagV3 | float16 | [512 2] | [512 2 2] | fwd | 26960 | 4989 | 5.40 |
| MatrixDiagV3 | float16 | [1024 4] | [1024 4 4] | fwd | 26080 | 6217 | 4.19 |
| MatrixDiagV3 | float16 | [256 8] | [256 8 8] | fwd | 23296 | 5246 | 4.44 |
| MatrixDiagV3 | float16 | [16 16] | [16 16 16] | fwd | 26848 | 4465 | 6.01 |
| MatrixDiagV3 | float16 | [32 32] | [32 32 32] | fwd | 30080 | 4656 | 6.46 |
| MatrixDiagV3 | float16 | [128 64] | [128 64 64] | fwd | 32320 | 12218 | 2.65 |
| MatrixDiagV3 | float16 | [32 128] | [32 128 128] | fwd | 25520 | 11127 | 2.29 |
| MatrixDiagV3 | float16 | [8 256] | [8 256 256] | fwd | 25200 | 10278 | 2.45 |
| MatrixDiagV3 | float16 | [2 512] | [2 512 512] | fwd | 25200 | 9275 | 2.72 |
| MatrixDiagV3 | float16 | [4 512] | [4 512 512] | fwd | 26976 | 15757 | 1.71 |
| MatrixDiagV3 | float16 | [512 2] | [512 2 2] | bwd | 8288 | 4809 | 1.72 |
| MatrixDiagV3 | float16 | [1024 4] | [1024 4 4] | bwd | 8240 | 5556 | 1.48 |
| MatrixDiagV3 | float16 | [1024 8] | [1024 8 8] | bwd | 8560 | 5599 | 1.53 |
| MatrixDiagV3 | float16 | [512 16] | [512 16 16] | bwd | 8752 | 5730 | 1.53 |
| MatrixDiagV3 | float16 | [1024 32] | [1024 32 32] | bwd | 9872 | 7508 | 1.31 |
| MatrixDiagV3 | float16 | [256 64] | [256 64 64] | bwd | 9776 | 6348 | 1.54 |
| MatrixDiagV3 | float16 | [1024 128] | [1024 128 128] | bwd | 23088 | 14896 | 1.55 |
| MatrixDiagV3 | float16 | [1024 256] | [1024 256 256] | bwd | 45456 | 25973 | 1.75 |
| MatrixDiagV3 | float16 | [8 512] | [8 512 512] | bwd | 9488 | 5759 | 1.65 |
| MatrixDiagV3 | float16 | [8 1024] | [8 1024 1024] | bwd | 9856 | 6387 | 1.54 |


MatrixDiag float32

| Op_name | dtype | input_size | output_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagV3 | float32 | [512 2] | [512 2 2] | fwd | 36592 | 5950 | 6.15 |
| MatrixDiagV3 | float32 | [1024 4] | [1024 4 4] | fwd | 24928 | 6275 | 3.97 |
| MatrixDiagV3 | float32 | [128 8] | [128 8 8] | fwd | 24208 | 5166 | 4.69 |
| MatrixDiagV3 | float32 | [32 16] | [32 16 16] | fwd | 24096 | 4741 | 5.08 |
| MatrixDiagV3 | float32 | [256 16] | [256 16 16] | fwd | 29248 | 6244 | 4.68 |
| MatrixDiagV3 | float32 | [16 32] | [16 32 32] | fwd | 23440 | 4549 | 5.15 |
| MatrixDiagV3 | float32 | [128 64] | [128 64 64] | fwd | 27183 | 11893 | 2.29 |
| MatrixDiagV3 | float32 | [64 128] | [64 128 128] | fwd | 28624 | 18249 | 1.57 |
| MatrixDiagV3 | float32 | [2 512] | [2 512 512] | fwd | 26704 | 9139 | 2.92 |
| MatrixDiagV3 | float32 | [4 512] | [4 512 512] | fwd | 28272 | 15775 | 1.79 |
| MatrixDiagV3 | float32 | [1024 2] | [1024 2 2] | bwd | 8400 | 5632 | 1.49 |
| MatrixDiagV3 | float32 | [1024 4] | [1024 4 4] | bwd | 8464 | 5908 | 1.43 |
| MatrixDiagV3 | float32 | [1024 8] | [1024 8 8] | bwd | 8784 | 5803 | 1.51 |
| MatrixDiagV3 | float32 | [512 16] | [512 16 16] | bwd | 9168 | 5989 | 1.53 |
| MatrixDiagV3 | float32 | [64 32] | [64 32 32] | bwd | 9072 | 5693 | 1.59 |
| MatrixDiagV3 | float32 | [128 64] | [128 64 64] | bwd | 9696 | 5932 | 1.63 |
| MatrixDiagV3 | float32 | [1024 128] | [1024 128 128] | bwd | 21648 | 14860 | 1.46 |
| MatrixDiagV3 | float32 | [64 256] | [64 256 256] | bwd | 9808 | 6556 | 1.50 |
| MatrixDiagV3 | float32 | [32 512] | [32 512 512] | bwd | 9904 | 7350 | 1.35 |
| MatrixDiagV3 | float32 | [16 1024] | [16 1024 1024] | bwd | 10576 | 7815 | 1.35 |


MatrixDiag bfloat16

| Op_name | dtype | input_size | output_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagV3 | bfloat16 | [1024 2] | [1024 2 2] | fwd | 23968 | 5591 | 4.29 |
| MatrixDiagV3 | bfloat16 | [1024 4] | [1024 4 4] | fwd | 23536 | 6083 | 3.87 |
| MatrixDiagV3 | bfloat16 | [512 8] | [512 8 8] | fwd | 23376 | 6075 | 3.85 |
| MatrixDiagV3 | bfloat16 | [64 16] | [64 16 16] | fwd | 48000 | 4694 | 10.23 |
| MatrixDiagV3 | bfloat16 | [8 32] | [8 32 32] | fwd | 28000 | 4261 | 6.57 |
| MatrixDiagV3 | bfloat16 | [256 64] | [256 64 64] | fwd | 30000 | 20371 | 1.47 |
| MatrixDiagV3 | bfloat16 | [32 128] | [32 128 128] | fwd | 33664 | 11036 | 3.05 |
| MatrixDiagV3 | bfloat16 | [16 256] | [16 256 256] | fwd | 45136 | 17489 | 2.58 |
| MatrixDiagV3 | bfloat16 | [32 256] | [32 256 256] | fwd | 39952 | 31324 | 1.28 |
| MatrixDiagV3 | bfloat16 | [2 512] | [2 512 512] | fwd | 29712 | 9354 | 3.18 |
| MatrixDiagV3 | bfloat16 | [512 2] | [512 2 2] | bwd | 7824 | 4905 | 1.60 |
| MatrixDiagV3 | bfloat16 | [1024 4] | [1024 4 4] | bwd | 7952 | 5675 | 1.40 |
| MatrixDiagV3 | bfloat16 | [1024 8] | [1024 8 8] | bwd | 8400 | 5643 | 1.49 |
| MatrixDiagV3 | bfloat16 | [256 16] | [256 16 16] | bwd | 8400 | 5738 | 1.46 |
| MatrixDiagV3 | bfloat16 | [256 32] | [256 32 32] | bwd | 8768 | 5999 | 1.46 |
| MatrixDiagV3 | bfloat16 | [512 64] | [512 64 64] | bwd | 10400 | 7888 | 1.32 |
| MatrixDiagV3 | bfloat16 | [1024 128] | [1024 128 128] | bwd | 22512 | 14873 | 1.51 |
| MatrixDiagV3 | bfloat16 | [1024 256] | [1024 256 256] | bwd | 45648 | 25898 | 1.76 |
| MatrixDiagV3 | bfloat16 | [16 512] | [16 512 512] | bwd | 9536 | 6073 | 1.57 |
| MatrixDiagV3 | bfloat16 | [16 1024] | [16 1024 1024] | bwd | 10192 | 7576 | 1.35 |


MatrixSetDiag float16

| Op_name | dtype | input_size | diag_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixSetDiagV3 | float16 | [256 2 2] | [256 2] | fwd | 48368 | 6112 | 7.91 |
| MatrixSetDiagV3 | float16 | [1024 4 4] | [1024 4] | fwd | 45280 | 6025 | 7.52 |
| MatrixSetDiagV3 | float16 | [128 8 8] | [128 8] | fwd | 51632 | 5700 | 9.06 |
| MatrixSetDiagV3 | float16 | [512 16 16] | [512 16] | fwd | 44320 | 6419 | 6.90 |
| MatrixSetDiagV3 | float16 | [128 32 32] | [128 32] | fwd | 36944 | 6490 | 5.69 |
| MatrixSetDiagV3 | float16 | [16 64 64] | [16 64] | fwd | 37104 | 6235 | 5.95 |
| MatrixSetDiagV3 | float16 | [4 128 128] | [4 128] | fwd | 35823 | 6054 | 5.92 |
| MatrixSetDiagV3 | float16 | [4 256 256] | [4 256] | fwd | 27824 | 9779 | 2.85 |
| MatrixSetDiagV3 | float16 | [8 256 256] | [8 256] | fwd | 29552 | 11979 | 2.47 |
| MatrixSetDiagV3 | float16 | [2 512 512] | [2 512] | fwd | 32800 | 11326 | 2.90 |
| MatrixSetDiagV3 | float16 | [128 2 2] | [128 2] | bwd | 190734 | 20290 | 9.40 |
| MatrixSetDiagV3 | float16 | [1024 4 4] | [1024 4] | bwd | 203006 | 23807 | 8.53 |
| MatrixSetDiagV3 | float16 | [1024 8 8] | [1024 8] | bwd | 153311 | 18052 | 8.49 |
| MatrixSetDiagV3 | float16 | [512 16 16] | [512 16] | bwd | 153647 | 20780 | 7.39 |
| MatrixSetDiagV3 | float16 | [512 32 32] | [512 32] | bwd | 120255 | 25044 | 4.80 |
| MatrixSetDiagV3 | float16 | [64 64 64] | [64 64] | bwd | 135743 | 21401 | 6.34 |
| MatrixSetDiagV3 | float16 | [64 128 128] | [64 128] | bwd | 134079 | 33490 | 4.00 |
| MatrixSetDiagV3 | float16 | [16 256 256] | [16 256] | bwd | 154383 | 32471 | 4.75 |
| MatrixSetDiagV3 | float16 | [8 512 512] | [8 512] | bwd | 136111 | 43920 | 3.10 |
| MatrixSetDiagV3 | float16 | [4 1024 1024] | [4 1024] | bwd | 156191 | 68072 | 2.29 |


MatrixSetDiag float32

| Op_name | dtype | input_size | diag_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixSetDiagV3 | float32 | [1024 2 2] | [1024 2] | fwd | 28512 | 5879 | 4.85 |
| MatrixSetDiagV3 | float32 | [512 4 4] | [512 4] | fwd | 29104 | 5763 | 5.05 |
| MatrixSetDiagV3 | float32 | [1024 8 8] | [1024 8] | fwd | 29712 | 6067 | 4.90 |
| MatrixSetDiagV3 | float32 | [512 16 16] | [512 16] | fwd | 30064 | 6170 | 4.87 |
| MatrixSetDiagV3 | float32 | [128 32 32] | [128 32] | fwd | 29216 | 6312 | 4.63 |
| MatrixSetDiagV3 | float32 | [32 64 64] | [32 64] | fwd | 33680 | 6155 | 5.47 |
| MatrixSetDiagV3 | float32 | [8 128 128] | [8 128] | fwd | 28800 | 6094 | 4.73 |
| MatrixSetDiagV3 | float32 | [2 256 256] | [2 256] | fwd | 28432 | 5976 | 4.76 |
| MatrixSetDiagV3 | float32 | [1024 2 2] | [1024 2] | bwd | 162623 | 21037 | 7.73 |
| MatrixSetDiagV3 | float32 | [64 4 4] | [64 4] | bwd | 141951 | 17181 | 8.26 |
| MatrixSetDiagV3 | float32 | [512 8 8] | [512 8] | bwd | 136095 | 18094 | 7.52 |
| MatrixSetDiagV3 | float32 | [1024 16 16] | [1024 16] | bwd | 134687 | 22277 | 6.05 |
| MatrixSetDiagV3 | float32 | [512 32 32] | [512 32] | bwd | 134271 | 25808 | 5.20 |
| MatrixSetDiagV3 | float32 | [128 64 64] | [128 64] | bwd | 172558 | 29060 | 5.94 |
| MatrixSetDiagV3 | float32 | [16 128 128] | [16 128] | bwd | 151167 | 22422 | 6.74 |
| MatrixSetDiagV3 | float32 | [2 256 256] | [2 256] | bwd | 142543 | 22317 | 6.39 |
| MatrixSetDiagV3 | float32 | [8 512 512] | [8 512] | bwd | 144399 | 45533 | 3.17 |
| MatrixSetDiagV3 | float32 | [8 1024 1024] | [8 1024] | bwd | 267102 | 127155 | 2.10 |


MatrixSetDiag bfloat16

| Op_name | dtype | input_size | diag_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixSetDiagV3 | bfloat16 | [1024 2 2] | [1024 2] | fwd | 28976 | 5785 | 5.01 |
| MatrixSetDiagV3 | bfloat16 | [1024 4 4] | [1024 4] | fwd | 30384 | 5787 | 5.25 |
| MatrixSetDiagV3 | bfloat16 | [512 8 8] | [512 8] | fwd | 30128 | 5664 | 5.32 |
| MatrixSetDiagV3 | bfloat16 | [512 16 16] | [512 16] | fwd | 29168 | 6079 | 4.80 |
| MatrixSetDiagV3 | bfloat16 | [256 32 32] | [256 32] | fwd | 30320 | 9145 | 3.32 |
| MatrixSetDiagV3 | bfloat16 | [64 64 64] | [64 64] | fwd | 29744 | 9027 | 3.30 |
| MatrixSetDiagV3 | bfloat16 | [32 128 128] | [32 128] | fwd | 31056 | 11885 | 2.61 |
| MatrixSetDiagV3 | bfloat16 | [4 256 256] | [4 256] | fwd | 27922 | 8523 | 3.28 |
| MatrixSetDiagV3 | bfloat16 | [8 256 256] | [8 256] | fwd | 29856 | 11547 | 2.59 |
| MatrixSetDiagV3 | bfloat16 | [2 512 512] | [2 512] | fwd | 29296 | 10915 | 2.68 |
| MatrixSetDiagV3 | bfloat16 | [1024 2 2] | [1024 2] | bwd | 110719 | 19820 | 5.59 |
| MatrixSetDiagV3 | bfloat16 | [512 4 4] | [512 4] | bwd | 136687 | 18999 | 7.19 |
| MatrixSetDiagV3 | bfloat16 | [1024 8 8] | [1024 8] | bwd | 107055 | 17357 | 6.17 |
| MatrixSetDiagV3 | bfloat16 | [1024 16 16] | [1024 16] | bwd | 96287 | 23717 | 4.06 |
| MatrixSetDiagV3 | bfloat16 | [128 32 32] | [128 32] | bwd | 102703 | 18238 | 5.63 |
| MatrixSetDiagV3 | bfloat16 | [32 64 64] | [32 64] | bwd | 105007 | 17220 | 6.10 |
| MatrixSetDiagV3 | bfloat16 | [64 128 128] | [64 128] | bwd | 117007 | 33493 | 3.49 |
| MatrixSetDiagV3 | bfloat16 | [32 256 256] | [32 256] | bwd | 112991 | 48159 | 2.35 |
| MatrixSetDiagV3 | bfloat16 | [2 512 512] | [2 512] | bwd | 123359 | 27704 | 4.45 |
| MatrixSetDiagV3 | bfloat16 | [4 1024 1024] | [4 1024] | bwd | 149454 | 90338 | 1.65 |


MatrixDiagPart float16

| Op_name | dtype | input_size | output_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagPartV2 | float16 | [1024 4 4] | [1024 4] | fwd | 8128 | 6361 | 1.28 |
| MatrixDiagPartV2 | float16 | [1024 8 8] | [1024 8] | fwd | 8112 | 6214 | 1.31 |
| MatrixDiagPartV2 | float16 | [256 16 16] | [256 16] | fwd | 27148 | 6175 | 4.40 |
| MatrixDiagPartV2 | float16 | [1024 16 16] | [1024 16] | fwd | 16841 | 6445 | 2.61 |
| MatrixDiagPartV2 | float16 | [1024 32 32] | [1024 32] | fwd | 27019 | 8140 | 3.32 |
| MatrixDiagPartV2 | float16 | [512 64 64] | [512 64] | fwd | 51552 | 8591 | 6.00 |
| MatrixDiagPartV2 | float16 | [64 128 128] | [64 128] | fwd | 16782 | 6627 | 2.53 |
| MatrixDiagPartV2 | float16 | [4 512 512] | [4 512] | fwd | 9040 | 6236 | 1.45 |
| MatrixDiagPartV2 | float16 | [128 512 512] | [128 512] | fwd | 38311 | 12350 | 3.10 |
| MatrixDiagPartV2 | float16 | [16 1024 1024] | [16 1024] | fwd | 10544 | 8147 | 1.29 |
| MatrixDiagPartV2 | float16 | [512 2 2] | [512 2] | bwd | 26208 | 4475 | 5.86 |
| MatrixDiagPartV2 | float16 | [128 4 4] | [128 4] | bwd | 23936 | 4240 | 5.65 |
| MatrixDiagPartV2 | float16 | [128 8 8] | [128 8] | bwd | 24560 | 4164 | 5.90 |
| MatrixDiagPartV2 | float16 | [256 16 16] | [256 16] | bwd | 36688 | 5732 | 6.40 |
| MatrixDiagPartV2 | float16 | [16 32 32] | [16 32] | bwd | 39408 | 3751 | 10.51 |
| MatrixDiagPartV2 | float16 | [128 64 64] | [128 64] | bwd | 34624 | 12522 | 2.77 |
| MatrixDiagPartV2 | float16 | [64 128 128] | [64 128] | bwd | 32591 | 17250 | 1.89 |
| MatrixDiagPartV2 | float16 | [16 256 256] | [16 256] | bwd | 32800 | 15927 | 2.06 |
| MatrixDiagPartV2 | float16 | [2 512 512] | [2 512] | bwd | 31984 | 8319 | 3.84 |
| MatrixDiagPartV2 | float16 | [2 1024 1024] | [2 1024] | bwd | 33664 | 23881 | 1.41 |


MatrixDiagPart float32

| Op_name | dtype | input_size | output_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagPartV2 | float32 | [1024 2 2] | [1024 2] | fwd | 8912 | 6238 | 1.43 |
| MatrixDiagPartV2 | float32 | [1024 4 4] | [1024 4] | fwd | 8096 | 6062 | 1.34 |
| MatrixDiagPartV2 | float32 | [256 8 8] | [256 8] | fwd | 8240 | 6060 | 1.36 |
| MatrixDiagPartV2 | float32 | [512 16 16] | [512 16] | fwd | 8944 | 6555 | 1.36 |
| MatrixDiagPartV2 | float32 | [512 32 32] | [512 32] | fwd | 8928 | 6985 | 1.28 |
| MatrixDiagPartV2 | float32 | [128 64 64] | [128 64] | fwd | 8896 | 6553 | 1.36 |
| MatrixDiagPartV2 | float32 | [4 128 128] | [4 128] | fwd | 12064 | 8001 | 1.51 |
| MatrixDiagPartV2 | float32 | [8 256 256] | [8 256] | fwd | 8608 | 5336 | 1.61 |
| MatrixDiagPartV2 | float32 | [4 512 512] | [4 512] | fwd | 9424 | 6301 | 1.50 |
| MatrixDiagPartV2 | float32 | [8 1024 1024] | [8 1024] | fwd | 10272 | 7243 | 1.42 |
| MatrixDiagPartV2 | float32 | [1024 2 2] | [1024 2] | bwd | 24848 | 5968 | 4.16 |
| MatrixDiagPartV2 | float32 | [128 4 4] | [128 4] | bwd | 23120 | 4525 | 5.11 |
| MatrixDiagPartV2 | float32 | [2 8 8] | [2 8] | bwd | 21504 | 6123 | 3.51 |
| MatrixDiagPartV2 | float32 | [4 8 8] | [4 8] | bwd | 20320 | 6366 | 3.19 |
| MatrixDiagPartV2 | float32 | [64 8 8] | [64 8] | bwd | 23360 | 4193 | 5.57 |
| MatrixDiagPartV2 | float32 | [32 16 16] | [32 16] | bwd | 23008 | 4127 | 5.57 |
| MatrixDiagPartV2 | float32 | [2 32 32] | [2 32] | bwd | 26560 | 4008 | 6.63 |
| MatrixDiagPartV2 | float32 | [512 64 64] | [512 64] | bwd | 57055 | 34863 | 1.64 |
| MatrixDiagPartV2 | float32 | [2 128 128] | [2 128] | bwd | 52672 | 3785 | 13.92 |
| MatrixDiagPartV2 | float32 | [32 256 256] | [32 256] | bwd | 64000 | 28904 | 2.21 |
| MatrixDiagPartV2 | float32 | [4 1024 1024] | [4 1024] | bwd | 59456 | 48589 | 1.22 |


MatrixDiagPart bfloat16

| Op_name | dtype | input_size | output_size | direction | ROCm pytorch | MIOpen HIP | Improvement |
|---|---|---|---|---|---:|---:|---:|
| MatrixDiagPartV2 | bfloat16 | [8 256 256] | [8 256] | fwd | 6960 | 5627 | 1.24 |
| MatrixDiagPartV2 | bfloat16 | [16 256 256] | [16 256] | fwd | 6832 | 5663 | 1.21 |
| MatrixDiagPartV2 | bfloat16 | [2 512 512] | [2 512] | fwd | 8704 | 6964 | 1.25 |
| MatrixDiagPartV2 | bfloat16 | [4 512 512] | [4 512] | fwd | 8976 | 7006 | 1.28 |
| MatrixDiagPartV2 | bfloat16 | [8 512 512] | [8 512] | fwd | 8096 | 6513 | 1.24 |
| MatrixDiagPartV2 | bfloat16 | [16 512 512] | [16 512] | fwd | 8816 | 6865 | 1.28 |
| MatrixDiagPartV2 | bfloat16 | [2 1024 1024] | [2 1024] | fwd | 8944 | 6433 | 1.39 |
| MatrixDiagPartV2 | bfloat16 | [4 1024 1024] | [4 1024] | fwd | 8960 | 6608 | 1.36 |
| MatrixDiagPartV2 | bfloat16 | [8 1024 1024] | [8 1024] | fwd | 9136 | 7138 | 1.28 |
| MatrixDiagPartV2 | bfloat16 | [512 4 4] | [512 4] | bwd | 23776 | 4601 | 5.17 |
| MatrixDiagPartV2 | bfloat16 | [1024 4 4] | [1024 4] | bwd | 28128 | 5604 | 5.02 |
| MatrixDiagPartV2 | bfloat16 | [512 8 8] | [512 8] | bwd | 23616 | 5659 | 4.17 |
| MatrixDiagPartV2 | bfloat16 | [1024 8 8] | [1024 8] | bwd | 27792 | 5688 | 4.89 |
| MatrixDiagPartV2 | bfloat16 | [1024 32 32] | [1024 32] | bwd | 27360 | 19128 | 1.43 |
| MatrixDiagPartV2 | bfloat16 | [16 128 128] | [16 128] | bwd | 24512 | 6574 | 3.73 |
| MatrixDiagPartV2 | bfloat16 | [32 128 128] | [32 128] | bwd | 26352 | 10118 | 2.60 |
| MatrixDiagPartV2 | bfloat16 | [8 256 256] | [8 256] | bwd | 33040 | 9427 | 3.50 |
| MatrixDiagPartV2 | bfloat16 | [4 512 512] | [4 512] | bwd | 27216 | 14506 | 1.88 |
| MatrixDiagPartV2 | bfloat16 | [2 1024 1024] | [2 1024] | bwd | 37472 | 24024 | 1.56 |


| op_name | type | average |
|---|---|---:|
| MatrixDiag | float16 | 3.02 |
| MatrixDiag | float32 | 2.78 |
| MatrixDiag | bfloat16 | 2.97 |
| MatrixSetDiag | float16 | 5.33 |
| MatrixSetDiag | float32 | 4.82 |
| MatrixSetDiag | bfloat16 | 4.24 |
| MatrixDiagPart | float16 | 3.65 |
| MatrixDiagPart | float32 | 3.33 |
| MatrixDiagPart | bfloat16 | 3.23 |

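
As a reading aid for the shapes listed in the benchmark rows, here is a minimal pure-Python sketch of what the three operations compute in the simplest case: main diagonal only, square inner matrices, one batch dimension. The function names are illustrative and are not the MIOpen API. Any extra options the V2/V3 variants may accept (for example diagonal offsets or padding) are not modeled here; that they exist at all is an assumption, not something this description states.

```python
# Reference semantics only (assumed simplest case: main diagonal, square matrices).
# Names are hypothetical and do not correspond to MIOpen entry points.

def matrix_diag(diag):
    """[B][N] vectors -> [B][N][N] matrices with the vector on the main diagonal."""
    return [[[v if i == j else 0 for j in range(len(row))]
             for i, v in enumerate(row)]
            for row in diag]

def matrix_diag_part(mats):
    """[B][N][N] matrices -> [B][N] main diagonals."""
    return [[m[i][i] for i in range(len(m))] for m in mats]

def matrix_set_diag(mats, diag):
    """Copy of mats ([B][N][N]) with each main diagonal replaced by diag ([B][N])."""
    return [[[d[i] if i == j else m[i][j] for j in range(len(m[i]))]
             for i in range(len(m))]
            for m, d in zip(mats, diag)]

def matrix_set_diag_backward(grad_out):
    """Gradients of matrix_set_diag: (grad wrt input matrices, grad wrt diag)."""
    grad_input = [[[0 if i == j else g[i][j] for j in range(len(g[i]))]
                   for i in range(len(g))]
                  for g in grad_out]
    return grad_input, matrix_diag_part(grad_out)

# In this simple case the backward pass of matrix_diag is matrix_diag_part applied
# to the upstream gradient, and the backward pass of matrix_diag_part is matrix_diag.
diag = [[1, 2], [3, 4]]            # shape [2 2]
mats = matrix_diag(diag)           # shape [2 2 2]
updated = matrix_set_diag([[[5, 6], [7, 8]]], [[1, 2]])
grad_input, grad_diag = matrix_set_diag_backward([[[1, 1], [1, 1]]])
```

In this sketch `matrix_diag_part(matrix_diag(d))` returns `d` unchanged, which is a convenient sanity check when comparing an implementation against a reference.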