Implement CosineSimilarity
cognaiger9 opened this issue 10 months ago
- Add the CosineSimilarity operation with forward and backward kernels.
- Add a driver and gtests for the kernels.
- MIOpen outperforms ROCm when:
  - the number of output elements exceeds 20,000.
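As a host-side reference for checking the forward and backward kernels, the per-output-element math can be sketched as below. This is a minimal sketch, not the MIOpen implementation: it assumes the definition `x1·x2 / max(‖x1‖·‖x2‖, eps)` with `eps = 1e-8` (the default of `torch.nn.CosineSimilarity`); the PR does not state which epsilon handling the kernels use, and the function names are hypothetical.

```python
import math

EPS = 1e-8  # assumed epsilon; not stated in the PR


def _reductions(x1, x2):
    """The three reductions both passes need: dot product and the two L2 norms."""
    dot = sum(a * b for a, b in zip(x1, x2))
    n1 = math.sqrt(sum(a * a for a in x1))
    n2 = math.sqrt(sum(b * b for b in x2))
    return dot, n1, n2


def cosine_similarity_fwd(x1, x2, eps=EPS):
    """Cosine similarity of two equal-length vectors (one output element)."""
    dot, n1, n2 = _reductions(x1, x2)
    return dot / max(n1 * n2, eps)


def cosine_similarity_bwd(x1, x2, grad_out, eps=EPS):
    """Gradients of cosine_similarity_fwd w.r.t. x1 and x2, scaled by grad_out."""
    dot, n1, n2 = _reductions(x1, x2)
    denom = n1 * n2
    if denom <= eps:
        # Denominator is clamped to the constant eps, so only the dot product varies.
        g1 = [grad_out * b / eps for b in x2]
        g2 = [grad_out * a / eps for a in x1]
        return g1, g2
    y = dot / denom
    # dy/dx1 = x2 / (n1*n2) - y * x1 / n1^2, and symmetrically for x2.
    g1 = [grad_out * (b / denom - y * a / (n1 * n1)) for a, b in zip(x1, x2)]
    g2 = [grad_out * (a / denom - y * b / (n2 * n2)) for a, b in zip(x1, x2)]
    return g1, g2
```

A device kernel would evaluate this once per output element, reducing along the chosen dimension; the backward pass needs the same three reductions as the forward pass, which a kernel can either recompute or cache.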
Average speedup of MIOpen over ROCm (ROCm time / MIOpen time):

| type | fwd | bwd |
|---|---|---|
| float16 | 1.86 | 1.72 |
| float32 | 1.69 | 1.82 |
| bfloat16 | 2.04 | 1.96 |
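The averages appear to be the arithmetic mean of the per-case "MIOpen vs ROCm" ratios in the detailed benchmark tables; that is an assumption, since the PR does not say how they were averaged. As a check, averaging the ten float16 forward ratios this way reproduces 1.86. The helper names below are hypothetical.

```python
from statistics import mean

# Per-case "MIOpen vs ROCm" ratios for float16 forward, copied from the detailed table.
FP16_FWD_RATIOS = [1.23, 1.16, 1.46, 2.11, 1.51, 1.89, 1.17, 2.44, 3.96, 1.71]


def speedup(rocm_time, miopen_time):
    """Per-case speedup: how many times faster MIOpen is than ROCm."""
    return rocm_time / miopen_time


def average_speedup(ratios):
    """Arithmetic mean of per-case speedups, rounded like the summary table."""
    return round(mean(ratios), 2)


print(round(speedup(426253, 347020), 2))  # first float16 fwd row
print(average_speedup(FP16_FWD_RATIOS))
```

A geometric mean is a common alternative for averaging ratios, because it weights a 2x speedup and a 2x slowdown symmetrically; the arithmetic mean is used here only because it matches the reported numbers.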
Detailed benchmarks
float16
| op_name | dtype | input1 | input2 | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| Cosine Similarities | float16 | [100 512 200] | [100 512 200] | contiguous | fwd | 426253 | 347020 | 1.23 |
| Cosine Similarities | float16 | [100 512 200] | [100 512 200] | contiguous | bwd | 1462341 | 947513 | 1.54 |
| Cosine Similarities | float16 | [100 1024 200] | [100 1024 200] | contiguous | fwd | 798794 | 688528 | 1.16 |
| Cosine Similarities | float16 | [100 1024 200] | [100 1024 200] | contiguous | bwd | 2784315 | 1897090 | 1.47 |
| Cosine Similarities | float16 | [48 512 512] | [48 512 512] | contiguous | fwd | 515276 | 352034 | 1.46 |
| Cosine Similarities | float16 | [48 512 512] | [48 512 512] | contiguous | bwd | 1786899 | 966217 | 1.85 |
| Cosine Similarities | float16 | [48 512 512] | [48 512 512] | noncontiguous | fwd | 2821500 | 1338300 | 2.11 |
| Cosine Similarities | float16 | [48 512 512] | [48 512 512] | noncontiguous | bwd | 6846559 | 4324440 | 1.58 |
| Cosine Similarities | float16 | [48 1024 512] | [48 1024 512] | noncontiguous | fwd | 5630840 | 3726720 | 1.51 |
| Cosine Similarities | float16 | [48 1024 512] | [48 1024 512] | noncontiguous | bwd | 14110539 | 10911400 | 1.29 |
| Cosine Similarities | float16 | [48 2048 512] | [48 2048 512] | contiguous | fwd | 2693597 | 1422960 | 1.89 |
| Cosine Similarities | float16 | [48 2048 512] | [48 2048 512] | contiguous | bwd | 7302252 | 3825210 | 1.91 |
| Cosine Similarities | float16 | [48 2048 512] | [48 2048 512] | noncontiguous | fwd | 10481655 | 8927840 | 1.17 |
| Cosine Similarities | float16 | [48 2048 512] | [48 2048 512] | noncontiguous | bwd | 33497829 | 27936900 | 1.20 |
| Cosine Similarities | float16 | [8192 512 8] | [8192 512 8] | contiguous | fwd | 1855731 | 760218 | 2.44 |
| Cosine Similarities | float16 | [8192 512 8] | [8192 512 8] | contiguous | bwd | 4881614 | 3379450 | 1.44 |
| Cosine Similarities | float16 | [8192 512 8] | [8192 512 8] | noncontiguous | fwd | 3529784 | 890530 | 3.96 |
| Cosine Similarities | float16 | [8192 512 8] | [8192 512 8] | noncontiguous | bwd | 10447912 | 3381180 | 3.09 |
| Cosine Similarities | float16 | [8192 512 16] | [8192 512 16] | contiguous | fwd | 2646878 | 1552040 | 1.71 |
| Cosine Similarities | float16 | [8192 512 16] | [8192 512 16] | contiguous | bwd | 8911075 | 4909360 | 1.82 |
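The 20,000 threshold in the summary is stated in output elements rather than input shape. Cosine similarity reduces one dimension, so the output size is the product of the remaining extents. The PR does not say which dimension the benchmarks reduce; the sketch below assumes `dim=1`, the default of `torch.nn.CosineSimilarity`, and `output_numel` is a hypothetical helper.

```python
from math import prod


def output_numel(shape, dim=1):
    """Number of output elements when cosine similarity reduces dimension `dim`."""
    return prod(extent for i, extent in enumerate(shape) if i != dim)


# Shapes from the benchmark tables; reducing dim=1 is an assumption.
BENCH_SHAPES = [(100, 512, 200), (48, 512, 512), (8192, 512, 8), (8192, 512, 16)]
for shape in BENCH_SHAPES:
    print(shape, output_numel(shape))
```

Under the `dim=1` assumption, every shape in the tables yields at least 20,000 output elements; reducing a different dimension would change these counts.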
float32
| op_name | dtype | input1 | input2 | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| Cosine Similarities | float32 | [100 512 200] | [100 512 200] | contiguous | fwd | 577659 | 370984 | 1.56 |
| Cosine Similarities | float32 | [100 512 200] | [100 512 200] | contiguous | bwd | 1849378 | 1052470 | 1.76 |
| Cosine Similarities | float32 | [100 1024 200] | [100 1024 200] | contiguous | fwd | 1136999 | 746626 | 1.52 |
| Cosine Similarities | float32 | [100 1024 200] | [100 1024 200] | contiguous | bwd | 3618709 | 2147220 | 1.69 |
| Cosine Similarities | float32 | [100 2048 200] | [100 2048 200] | contiguous | fwd | 2632380 | 1524860 | 1.73 |
| Cosine Similarities | float32 | [100 2048 200] | [100 2048 200] | contiguous | bwd | 8504242 | 4356880 | 1.95 |
| Cosine Similarities | float32 | [48 512 512] | [48 512 512] | contiguous | fwd | 837514 | 383944 | 2.18 |
| Cosine Similarities | float32 | [48 512 512] | [48 512 512] | contiguous | bwd | 2348079 | 1042270 | 2.25 |
| Cosine Similarities | float32 | [48 512 512] | [48 512 512] | noncontiguous | fwd | 3506695 | 1811310 | 1.94 |
| Cosine Similarities | float32 | [48 512 512] | [48 512 512] | noncontiguous | bwd | 9166318 | 5483190 | 1.67 |
| Cosine Similarities | float32 | [48 1024 512] | [48 1024 512] | contiguous | fwd | 1160312 | 738841 | 1.57 |
| Cosine Similarities | float32 | [48 1024 512] | [48 1024 512] | contiguous | bwd | 4295953 | 2045450 | 2.10 |
| Cosine Similarities | float32 | [48 1024 512] | [48 1024 512] | noncontiguous | fwd | 6700240 | 5319740 | 1.26 |
| Cosine Similarities | float32 | [48 1024 512] | [48 1024 512] | noncontiguous | bwd | 19835299 | 14577400 | 1.36 |
| Cosine Similarities | float32 | [48 2048 512] | [48 2048 512] | contiguous | fwd | 2119473 | 1474480 | 1.44 |
| Cosine Similarities | float32 | [48 2048 512] | [48 2048 512] | contiguous | bwd | 8352629 | 4055320 | 2.06 |
| Cosine Similarities | float32 | [48 2048 512] | [48 2048 512] | noncontiguous | fwd | 12393130 | 9287640 | 1.33 |
| Cosine Similarities | float32 | [48 2048 512] | [48 2048 512] | noncontiguous | bwd | 45345268 | 31751600 | 1.43 |
| Cosine Similarities | float32 | [8192 512 8] | [8192 512 8] | contiguous | fwd | 2172193 | 903117 | 2.41 |
| Cosine Similarities | float32 | [8192 512 8] | [8192 512 8] | contiguous | bwd | 6099446 | 3222970 | 1.89 |
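The noncontiguous rows exercise inputs whose elements are not laid out densely in row-major order, so the kernels must address them through strides. The PR does not say how the noncontiguous tensors were produced; a transposed view (same buffer, permuted strides) is assumed below purely for illustration, and the helper names are hypothetical.

```python
def contiguous_strides(shape):
    """Row-major (C-order) strides, measured in elements."""
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))


def offset(index, strides):
    """Flat buffer offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


def is_contiguous(shape, strides):
    """Strict check; ignores the usual exemption for size-1 dimensions."""
    return tuple(strides) == contiguous_strides(shape)


shape = (48, 512, 512)
dense = contiguous_strides(shape)           # (262144, 512, 1)
transposed = (dense[0], dense[2], dense[1])  # view with the last two dims swapped
print(is_contiguous(shape, dense), is_contiguous(shape, transposed))
```

With strided addressing, consecutive elements along the reduced dimension may be far apart in memory, which is one reason the noncontiguous timings above are several times larger than the contiguous ones for the same shape.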
bfloat16
| op_name | dtype | input1 | input2 | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| Cosine Similarities | bfloat16 | [100 1024 200] | [100 1024 200] | contiguous | fwd | 846298 | 693346 | 1.22 |
| Cosine Similarities | bfloat16 | [100 1024 200] | [100 1024 200] | contiguous | bwd | 2977722 | 1905840 | 1.56 |
| Cosine Similarities | bfloat16 | [48 512 512] | [48 512 512] | contiguous | fwd | 520316 | 354896 | 1.47 |
| Cosine Similarities | bfloat16 | [48 512 512] | [48 512 512] | contiguous | bwd | 1880802 | 1017110 | 1.85 |
| Cosine Similarities | bfloat16 | [48 512 512] | [48 512 512] | noncontiguous | fwd | 2847387 | 1336790 | 2.13 |
| Cosine Similarities | bfloat16 | [48 512 512] | [48 512 512] | noncontiguous | bwd | 6875854 | 4354790 | 1.58 |
| Cosine Similarities | bfloat16 | [48 1024 512] | [48 1024 512] | contiguous | fwd | 850234 | 740263 | 1.15 |
| Cosine Similarities | bfloat16 | [48 1024 512] | [48 1024 512] | contiguous | bwd | 3515447 | 2027460 | 1.73 |
| Cosine Similarities | bfloat16 | [48 1024 512] | [48 1024 512] | noncontiguous | fwd | 6123076 | 3721010 | 1.65 |
| Cosine Similarities | bfloat16 | [48 1024 512] | [48 1024 512] | noncontiguous | bwd | 14947030 | 13634400 | 1.10 |
| Cosine Similarities | bfloat16 | [8192 512 8] | [8192 512 8] | contiguous | fwd | 1918707 | 745036 | 2.58 |
| Cosine Similarities | bfloat16 | [8192 512 8] | [8192 512 8] | contiguous | bwd | 5203020 | 3372960 | 1.54 |
| Cosine Similarities | bfloat16 | [8192 512 8] | [8192 512 8] | noncontiguous | fwd | 3599208 | 862476 | 4.17 |
| Cosine Similarities | bfloat16 | [8192 512 8] | [8192 512 8] | noncontiguous | bwd | 11091124 | 3497940 | 3.17 |
| Cosine Similarities | bfloat16 | [8192 512 16] | [8192 512 16] | contiguous | fwd | 2800925 | 1561570 | 1.79 |
| Cosine Similarities | bfloat16 | [8192 512 16] | [8192 512 16] | contiguous | bwd | 9585856 | 4725360 | 2.03 |
| Cosine Similarities | bfloat16 | [8192 512 32] | [8192 512 32] | contiguous | fwd | 4520643 | 2975230 | 1.52 |
| Cosine Similarities | bfloat16 | [8192 512 32] | [8192 512 32] | contiguous | bwd | 18413546 | 8472450 | 2.17 |
| Cosine Similarities | bfloat16 | [4096 512 8] | [4096 512 8] | contiguous | fwd | 1258264 | 469138 | 2.68 |
| Cosine Similarities | bfloat16 | [4096 512 8] | [4096 512 8] | contiguous | bwd | 3423051 | 1197620 | 2.86 |
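bfloat16 keeps float32's 8-bit exponent but only 7 explicit mantissa bits, i.e. the upper 16 bits of a float32 word. A host-side reference for the bfloat16 rows therefore has to round float32 values to bfloat16 consistently with the device. The sketch below assumes round-to-nearest-even, a common choice; the PR does not state the rounding mode the driver or gtests use, and the helper names are hypothetical.

```python
import struct


def float_to_bf16_bits(x):
    """Round a float32-representable value to bfloat16 (nearest-even); return the 16-bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return (bits >> 16) | 0x0040  # NaN: keep it a quiet NaN after truncation
    # Add 0x7FFF plus the lowest kept bit, so exact ties round toward an even result.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(b):
    """Expand a bfloat16 bit pattern to float32 by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def round_to_bf16(x):
    return bf16_bits_to_float(float_to_bf16_bits(x))


print(hex(float_to_bf16_bits(1.0)), round_to_bf16(1.01171875))
```

Because the spacing of bfloat16 values near 1.0 is 2^-7, reference comparisons for these rows typically need a much looser tolerance than the float32 rows.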