Implement embedding
cognaiger9 opened this issue 10 months ago
- Add the Embedding operation with backward kernels.
- Add a driver and gtest for the kernels.
- MIOpen performs better than ROCm when:
  - the split dimension is 0
  - the number of elements in the input tensor is less than 400,000 (forward case) or less than 800,000 (backward case)
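For readers unfamiliar with the operation, here is a CPU reference for what the forward and backward kernels compute. It is a minimal sketch, not MIOpen's API. The function names, the flattened-index layout and the `prefer_miopen` dispatch rule are hypothetical. Only the 400,000 / 800,000 element thresholds and the split-dimension condition come from the list above. The forward pass is a row gather. The backward pass is a scatter-add into the weight gradient, where repeated indices must accumulate.

```python
from math import prod
from typing import List, Sequence

Matrix = List[List[float]]


def embedding_forward(indices: Sequence[int], weight: Matrix) -> Matrix:
    """Gather: output row i is a copy of weight[indices[i]].

    `indices` is the flattened input tensor (a [40, 512] input gives
    20,480 indices); `weight` holds one row per embedding.
    """
    num_embeddings = len(weight)
    out: Matrix = []
    for idx in indices:
        if not 0 <= idx < num_embeddings:
            raise IndexError(f"index {idx} outside [0, {num_embeddings})")
        out.append(list(weight[idx]))
    return out


def embedding_backward(indices: Sequence[int], grad_out: Matrix,
                       num_embeddings: int) -> Matrix:
    """Scatter-add: grad_weight[indices[i]] += grad_out[i].

    Repeated indices accumulate into the same row, which is why a parallel
    GPU version of this step typically needs atomics or a sort-and-reduce.
    """
    dim = len(grad_out[0]) if grad_out else 0
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]
    for idx, grad_row in zip(indices, grad_out):
        if not 0 <= idx < num_embeddings:
            raise IndexError(f"index {idx} outside [0, {num_embeddings})")
        target = grad_weight[idx]
        for j, value in enumerate(grad_row):
            target[j] += value
    return grad_weight


# Hypothetical dispatch rule; only the two limits come from the issue text.
FWD_NUMEL_LIMIT = 400_000
BWD_NUMEL_LIMIT = 800_000


def prefer_miopen(input_shape: Sequence[int], split_dim: int,
                  backward: bool) -> bool:
    """True when both conditions listed in the issue hold.

    The issue does not define `split_dim` further, so it is a plain
    parameter here.
    """
    limit = BWD_NUMEL_LIMIT if backward else FWD_NUMEL_LIMIT
    return split_dim == 0 and prod(input_shape) < limit


if __name__ == "__main__":
    weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    idx = [2, 0, 2]
    print(embedding_forward(idx, weight))
    # [[5.0, 6.0], [1.0, 2.0], [5.0, 6.0]]
    print(embedding_backward(idx, [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], 3))
    # [[2.0, 2.0], [0.0, 0.0], [4.0, 4.0]]  (row 2 receives 1.0 + 3.0)
    print(prefer_miopen([40, 512], split_dim=0, backward=True))  # True
```

In the backward example, index 2 appears twice, so its gradient row is the sum of two output-gradient rows. This accumulation is the part that makes the backward kernel harder to parallelize than the forward gather.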
Average speedup over ROCm

| type | bwd |
|---|---|
| float16 | 2.10 |
| float32 | 2.75 |
| bfloat16 | 2.41 |
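The averages above appear to be arithmetic means of the per-configuration ratios (ROCm time / MIOpen time) from the detailed benchmark. The sketch below reproduces the float16 figure from the raw ROCm and MIOpen columns of the 20 float16 rows, in table order. It also counts the configurations where MIOpen is slower. The helper name `speedups` is illustrative only.

```python
from statistics import mean

# (ROCm, MIOpen) times for the 20 float16 bwd rows of the detailed table.
FLOAT16_BWD = [
    (600548, 423901), (579166, 682325), (893821, 466886), (916582, 695751),
    (897141, 180925), (931511, 488879), (798757, 356038), (886080, 540652),
    (430180, 274782), (511610, 420868), (434340, 140962), (392898, 138703),
    (373517, 150829), (486602, 225576), (672631, 317764), (650389, 408122),
    (753374, 421790), (848298, 603588), (172588, 57119), (194449, 77961),
]


def speedups(rows):
    """Per-configuration ratio ROCm / MIOpen; > 1 means MIOpen is faster."""
    return [rocm / miopen for rocm, miopen in rows]


ratios = speedups(FLOAT16_BWD)
avg = mean(ratios)
slower = sum(r < 1 for r in ratios)
print(f"arithmetic mean speedup: {avg:.2f}")   # 2.10
print(f"configurations where MIOpen is slower: {slower}")  # 1
```

An arithmetic mean of ratios is pulled up by outliers such as the 4.96 row. A geometric mean is a common alternative when summarizing speedups.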
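Detailed benchmark

The ROCm and MIOpen columns are measured times (lower is better). The last column is the ROCm time divided by the MIOpen time, so values above 1 mean MIOpen is faster.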
float16
| op_name | dtype | input size | weight size | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| Embedding | float16 | [40 512] | [1026 1024] | cont | bwd | 600548 | 423901 | 1.42 |
| Embedding | float16 | [40 512] | [1026 1024] | noncont | bwd | 579166 | 682325 | 0.85 |
| Embedding | float16 | [40 512] | [50000 768] | cont | bwd | 893821 | 466886 | 1.91 |
| Embedding | float16 | [40 512] | [50000 768] | noncont | bwd | 916582 | 695751 | 1.32 |
| Embedding | float16 | [40 512] | [2 768] | cont | bwd | 897141 | 180925 | 4.96 |
| Embedding | float16 | [40 512] | [2 768] | noncont | bwd | 931511 | 488879 | 1.91 |
| Embedding | float16 | [48 512] | [52100 512] | cont | bwd | 798757 | 356038 | 2.24 |
| Embedding | float16 | [48 512] | [52100 512] | noncont | bwd | 886080 | 540652 | 1.64 |
| Embedding | float16 | [32 512] | [514 768] | cont | bwd | 430180 | 274782 | 1.57 |
| Embedding | float16 | [32 512] | [514 768] | noncont | bwd | 511610 | 420868 | 1.22 |
| Embedding | float16 | [32 188] | [148 512] | cont | bwd | 434340 | 140962 | 3.08 |
| Embedding | float16 | [32 188] | [148 512] | noncont | bwd | 392898 | 138703 | 2.83 |
| Embedding | float16 | [16 512] | [512 768] | cont | bwd | 373517 | 150829 | 2.48 |
| Embedding | float16 | [16 512] | [512 768] | noncont | bwd | 486602 | 225576 | 2.16 |
| Embedding | float16 | [16 512] | [52000 768] | cont | bwd | 672631 | 317764 | 2.12 |
| Embedding | float16 | [16 512] | [52000 768] | noncont | bwd | 650389 | 408122 | 1.59 |
| Embedding | float16 | [16 1024] | [52000 768] | cont | bwd | 753374 | 421790 | 1.79 |
| Embedding | float16 | [16 1024] | [52000 768] | noncont | bwd | 848298 | 603588 | 1.41 |
| Embedding | float16 | [2 1024] | [1024 768] | cont | bwd | 172588 | 57119 | 3.02 |
| Embedding | float16 | [2 1024] | [1024 768] | noncont | bwd | 194449 | 77961 | 2.49 |
float32
| op_name | dtype | input size | weight size | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| Embedding | float32 | [40 512] | [1026 1024] | cont | bwd | 515964 | 260471 | 1.98 |
| Embedding | float32 | [40 512] | [1026 1024] | noncont | bwd | 587787 | 636332 | 0.92 |
| Embedding | float32 | [40 512] | [50000 768] | cont | bwd | 945843 | 414779 | 2.28 |
| Embedding | float32 | [40 512] | [50000 768] | noncont | bwd | 1030907 | 662515 | 1.56 |
| Embedding | float32 | [40 512] | [2 768] | cont | bwd | 849659 | 182595 | 4.65 |
| Embedding | float32 | [40 512] | [2 768] | noncont | bwd | 902241 | 493036 | 1.83 |
| Embedding | float32 | [48 512] | [52100 512] | cont | bwd | 804137 | 314753 | 2.55 |
| Embedding | float32 | [48 512] | [52100 512] | noncont | bwd | 969344 | 514590 | 1.88 |
| Embedding | float32 | [32 512] | [514 768] | cont | bwd | 461401 | 159905 | 2.89 |
| Embedding | float32 | [32 512] | [514 768] | noncont | bwd | 452040 | 391857 | 1.15 |
| Embedding | float32 | [32 188] | [148 512] | cont | bwd | 361496 | 65447 | 5.52 |
| Embedding | float32 | [32 188] | [148 512] | noncont | bwd | 889040 | 150777 | 5.90 |
| Embedding | float32 | [16 512] | [512 768] | cont | bwd | 449706 | 93224 | 4.82 |
| Embedding | float32 | [16 512] | [512 768] | noncont | bwd | 479844 | 210369 | 2.28 |
| Embedding | float32 | [16 512] | [52000 768] | cont | bwd | 712272 | 296758 | 2.40 |
| Embedding | float32 | [16 512] | [52000 768] | noncont | bwd | 838698 | 395938 | 2.12 |
| Embedding | float32 | [16 1024] | [52000 768] | cont | bwd | 802416 | 381329 | 2.10 |
| Embedding | float32 | [16 1024] | [52000 768] | noncont | bwd | 878280 | 577851 | 1.52 |
| Embedding | float32 | [2 1024] | [1024 768] | cont | bwd | 177068 | 45560 | 3.89 |
| Embedding | float32 | [2 1024] | [1024 768] | noncont | bwd | 210750 | 75880 | 2.78 |
bfloat16
| op_name | dtype | input size | weight size | contiguous | direction | ROCm | MIOpen | MIOpen vs ROCm |
|---|---|---|---|---|---|---|---|---|
| Embedding | bfloat16 | [40 512] | [1026 1024] | cont | bwd | 469609 | 319702 | 1.47 |
| Embedding | bfloat16 | [40 512] | [1026 1024] | noncont | bwd | 574466 | 664567 | 0.86 |
| Embedding | bfloat16 | [40 512] | [50000 768] | cont | bwd | 924102 | 427295 | 2.16 |
| Embedding | bfloat16 | [40 512] | [50000 768] | noncont | bwd | 970384 | 683101 | 1.42 |
| Embedding | bfloat16 | [40 512] | [2 768] | cont | bwd | 845138 | 178876 | 4.72 |
| Embedding | bfloat16 | [40 512] | [2 768] | noncont | bwd | 879140 | 492015 | 1.79 |
| Embedding | bfloat16 | [48 512] | [52100 512] | cont | bwd | 802816 | 324140 | 2.48 |
| Embedding | bfloat16 | [48 512] | [52100 512] | noncont | bwd | 965304 | 530431 | 1.82 |
| Embedding | bfloat16 | [32 512] | [514 768] | cont | bwd | 381746 | 204544 | 1.87 |
| Embedding | bfloat16 | [32 512] | [514 768] | noncont | bwd | 1047968 | 411612 | 2.55 |
| Embedding | bfloat16 | [32 188] | [148 512] | cont | bwd | 370737 | 72916 | 5.08 |
| Embedding | bfloat16 | [32 188] | [148 512] | noncont | bwd | 376197 | 122161 | 3.08 |
| Embedding | bfloat16 | [16 512] | [512 768] | cont | bwd | 365477 | 115845 | 3.15 |
| Embedding | bfloat16 | [16 512] | [512 768] | noncont | bwd | 351416 | 218495 | 1.61 |
| Embedding | bfloat16 | [16 512] | [52000 768] | cont | bwd | 688691 | 301789 | 2.28 |
| Embedding | bfloat16 | [16 512] | [52000 768] | noncont | bwd | 773815 | 402798 | 1.92 |
| Embedding | bfloat16 | [16 1024] | [52000 768] | cont | bwd | 805536 | 389987 | 2.07 |
| Embedding | bfloat16 | [16 1024] | [52000 768] | noncont | bwd | 818937 | 594882 | 1.38 |
| Embedding | bfloat16 | [2 1024] | [1024 768] | cont | bwd | 185588 | 49561 | 3.74 |
| Embedding | bfloat16 | [2 1024] | [1024 768] | noncont | bwd | 215590 | 75879 | 2.84 |