Implement EmbeddingBag
cognaiger9 opened this issue 10 months ago.
- Add the EmbeddingBag operation with forward kernels.
- Add a driver and gtests for the kernels.
- MIOpen is faster than ROCm when:
  - the mode is Max, or
  - the mode is Mean or Sum, the tensor type is float32, all tensors are contiguous, and the output has more than 2^19 (524,288) elements.
- Note:
  - The forward solver only supports 2D tensors.
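To make the expected forward result concrete for reviewers, here is a plain-Python reference sketch. It assumes PyTorch-style `nn.EmbeddingBag` semantics for a 2D input (each row of the index tensor is one bag of fixed length, no offsets), assumes every bag is non-empty, and does not model `padding_idx` or per-sample weights. The function name `embedding_bag_forward_2d` is hypothetical and is not part of the MIOpen API.

```python
from typing import List, Sequence


def embedding_bag_forward_2d(
    indices: Sequence[Sequence[int]],
    weight: Sequence[Sequence[float]],
    mode: str = "mean",
) -> List[List[float]]:
    """Hypothetical reference EmbeddingBag forward for a 2D index tensor.

    Each row of `indices` is one bag; output row b is the element-wise
    reduction (sum, mean or max) of the weight rows selected by bag b.
    Assumes every bag is non-empty; padding_idx and per-sample weights
    are not modelled.
    """
    if mode not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported mode: {mode!r}")
    embedding_dim = len(weight[0])
    output: List[List[float]] = []
    for bag in indices:
        rows = [weight[i] for i in bag]  # gather the embeddings in this bag
        if mode == "max":
            reduced = [max(row[d] for row in rows) for d in range(embedding_dim)]
        else:
            reduced = [sum(row[d] for row in rows) for d in range(embedding_dim)]
            if mode == "mean":
                reduced = [value / len(rows) for value in reduced]
        output.append(reduced)
    return output  # shape: [number of bags, embedding_dim]


if __name__ == "__main__":
    weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 embeddings of dim 2
    indices = [[0, 2], [1, 1]]  # 2 bags, each of length 2
    for mode in ("sum", "mean", "max"):
        print(mode, embedding_bag_forward_2d(indices, weight, mode))
```

A host-side reference like this is the kind of thing a gtest can compare GPU output against, although the actual tests in this PR may be structured differently.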
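The performance conditions listed above can be restated as a small predicate, for example to decide in a benchmark script which cases should favour MIOpen. This is only a sketch: `miopen_expected_faster` and its string-typed arguments are hypothetical, and the function simply encodes the thresholds quoted in this issue rather than MIOpen's actual solver-applicability logic.

```python
MEAN_SUM_MIN_OUTPUT_NUMEL = 2 ** 19  # threshold quoted in this issue (524,288)


def miopen_expected_faster(mode: str, dtype: str, all_contiguous: bool,
                           output_numel: int) -> bool:
    """Hypothetical helper restating when MIOpen was reported faster than ROCm."""
    if mode == "max":
        return True  # Max mode was faster in every benchmarked case
    if mode in ("mean", "sum"):
        # Mean/Sum: float32, fully contiguous, output strictly above 2^19 elements
        return (dtype == "float32" and all_contiguous
                and output_numel > MEAN_SUM_MIN_OUTPUT_NUMEL)
    raise ValueError(f"unknown mode: {mode!r}")


if __name__ == "__main__":
    print(miopen_expected_faster("max", "float16", False, 1024))
    print(miopen_expected_faster("sum", "float32", True, 2 ** 20))
```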
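Average improvement of MIOpen over ROCm (Improvement = ROCm result / MIOpen result; values above 1 mean MIOpen is faster)

| dtype | fwd |
|---|---|
| float16 | 1.38 |
| float32 | 1.3 |
| bfloat16 | 1.4 |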
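The Improvement column in the detailed tables below is the ROCm number divided by the MIOpen number. As a sketch, assuming each per-dtype average above is the arithmetic mean of those per-case ratios, the bfloat16 figure can be reproduced from the rows of the bfloat16 (mode max) table. The names `BF16_MAX_TIMINGS` and `improvement` are illustrative only.

```python
from statistics import mean

# (ROCm, MIOpen) pairs copied from the "bfloat16 (mode max)" table below.
BF16_MAX_TIMINGS = [
    (1188750, 834043),
    (1188286, 1010610),
    (2360604, 1491550),
    (2360636, 1816100),
    (2373468, 1527450),
    (4721576, 2911920),
    (4735288, 4048420),
    (637519, 525137),
    (2662716, 1746410),
    (2681467, 2100390),
    (5294503, 3315640),
    (5323143, 4110299),
    (600703, 416995),
]


def improvement(rocm: float, miopen: float) -> float:
    """Improvement as used in the tables: ROCm result / MIOpen result."""
    return rocm / miopen


speedups = [improvement(rocm, miopen) for rocm, miopen in BF16_MAX_TIMINGS]
average = mean(speedups)
print(f"bfloat16 average improvement: {average:.2f}")
```

The mean of the 13 ratios is about 1.398, which is consistent with the 1.4 reported for bfloat16.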
Detailed benchmark
float16 (mode max)

| op_name | dtype | input size | weight size | cont | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| EmbeddingBag | float16 | [32 655] | [100 256] | cont | fwd | 320271 | 253759 | 1.26 |
| EmbeddingBag | float16 | [512 512] | [256 1024] | cont | fwd | 1198653 | 846078 | 1.42 |
| EmbeddingBag | float16 | [512 512] | [256 1024] | noncont | fwd | 1210990 | 1012089 | 1.20 |
| EmbeddingBag | float16 | [512 512] | [512 2048] | cont | fwd | 2385852 | 1520730 | 1.57 |
| EmbeddingBag | float16 | [512 512] | [512 2048] | noncont | fwd | 2394267 | 1838890 | 1.30 |
| EmbeddingBag | float16 | [1024 512] | [256 1024] | cont | fwd | 2391292 | 1555550 | 1.54 |
| EmbeddingBag | float16 | [1024 512] | [512 2048] | cont | fwd | 4765832 | 2955600 | 1.61 |
| EmbeddingBag | float16 | [1024 512] | [512 2048] | noncont | fwd | 4791704 | 3989040 | 1.20 |
| EmbeddingBag | float16 | [768 768] | [256 1024] | cont | fwd | 2684508 | 1778770 | 1.51 |
| EmbeddingBag | float16 | [768 768] | [256 1024] | noncont | fwd | 2706427 | 2111200 | 1.28 |
| EmbeddingBag | float16 | [768 768] | [512 2048] | cont | fwd | 5352231 | 3353980 | 1.60 |
| EmbeddingBag | float16 | [768 768] | [512 2048] | noncont | fwd | 5376823 | 4098639 | 1.31 |
| EmbeddingBag | float16 | [256 256] | [512 2048] | cont | fwd | 607119 | 422861 | 1.44 |
| EmbeddingBag | float16 | [256 256] | [512 2048] | noncont | fwd | 612143 | 508318 | 1.20 |
| EmbeddingBag | float16 | [16 255] | [100 256] | cont | fwd | 129472 | 108169 | 1.20 |
float32 (mode max)

| op_name | dtype | input size | weight size | cont | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| EmbeddingBag | float32 | [512 512] | [256 1024] | cont | fwd | 1191598 | 863910 | 1.38 |
| EmbeddingBag | float32 | [512 512] | [256 1024] | noncont | fwd | 1191582 | 1029700 | 1.16 |
| EmbeddingBag | float32 | [512 512] | [512 2048] | cont | fwd | 2364124 | 1556490 | 1.52 |
| EmbeddingBag | float32 | [512 512] | [512 2048] | noncont | fwd | 2376844 | 1863040 | 1.28 |
| EmbeddingBag | float32 | [1024 512] | [256 1024] | cont | fwd | 2381244 | 1587020 | 1.50 |
| EmbeddingBag | float32 | [1024 512] | [512 2048] | cont | fwd | 4769672 | 3037880 | 1.57 |
| EmbeddingBag | float32 | [1024 512] | [512 2048] | noncont | fwd | 4799640 | 3896120 | 1.23 |
| EmbeddingBag | float32 | [768 768] | [256 1024] | cont | fwd | 2664284 | 1819250 | 1.46 |
| EmbeddingBag | float32 | [768 768] | [256 1024] | noncont | fwd | 2684236 | 2146680 | 1.25 |
| EmbeddingBag | float32 | [768 768] | [512 2048] | cont | fwd | 5335863 | 3436330 | 1.55 |
| EmbeddingBag | float32 | [768 768] | [512 2048] | noncont | fwd | 5371095 | 4152190 | 1.29 |
| EmbeddingBag | float32 | [256 256] | [512 2048] | cont | fwd | 602367 | 435573 | 1.38 |
| EmbeddingBag | float32 | [256 256] | [512 2048] | noncont | fwd | 607999 | 517136 | 1.18 |
| EmbeddingBag | float32 | [128 512] | [512 2048] | cont | fwd | 600879 | 528356 | 1.14 |
bfloat16 (mode max)

| op_name | dtype | input size | weight size | cont | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| EmbeddingBag | bfloat16 | [512 512] | [256 1024] | cont | fwd | 1188750 | 834043 | 1.43 |
| EmbeddingBag | bfloat16 | [512 512] | [256 1024] | noncont | fwd | 1188286 | 1010610 | 1.18 |
| EmbeddingBag | bfloat16 | [512 512] | [512 2048] | cont | fwd | 2360604 | 1491550 | 1.58 |
| EmbeddingBag | bfloat16 | [512 512] | [512 2048] | noncont | fwd | 2360636 | 1816100 | 1.30 |
| EmbeddingBag | bfloat16 | [1024 512] | [256 1024] | cont | fwd | 2373468 | 1527450 | 1.55 |
| EmbeddingBag | bfloat16 | [1024 512] | [512 2048] | cont | fwd | 4721576 | 2911920 | 1.62 |
| EmbeddingBag | bfloat16 | [1024 512] | [512 2048] | noncont | fwd | 4735288 | 4048420 | 1.17 |
| EmbeddingBag | bfloat16 | [768 768] | [100 256] | cont | fwd | 637519 | 525137 | 1.21 |
| EmbeddingBag | bfloat16 | [768 768] | [256 1024] | cont | fwd | 2662716 | 1746410 | 1.52 |
| EmbeddingBag | bfloat16 | [768 768] | [256 1024] | noncont | fwd | 2681467 | 2100390 | 1.28 |
| EmbeddingBag | bfloat16 | [768 768] | [512 2048] | cont | fwd | 5294503 | 3315640 | 1.60 |
| EmbeddingBag | bfloat16 | [768 768] | [512 2048] | noncont | fwd | 5323143 | 4110299 | 1.30 |
| EmbeddingBag | bfloat16 | [256 256] | [512 2048] | cont | fwd | 600703 | 416995 | 1.44 |
float32 (mode mean)

| op_name | dtype | input size | weight size | cont | direction | ROCm | MIOpen | Improvement |
|---|---|---|---|---|---|---|---|---|
| EmbeddingBag | float32 | [512 512] | [256 1024] | cont | fwd | 1154688 | 955314 | 1.21 |
| EmbeddingBag | float32 | [512 512] | [512 2048] | cont | fwd | 2307201 | 1764120 | 1.31 |
| EmbeddingBag | float32 | [1024 512] | [256 1024] | cont | fwd | 2303857 | 1798470 | 1.28 |
| EmbeddingBag | float32 | [1024 512] | [512 2048] | cont | fwd | 4597474 | 3431640 | 1.34 |
| EmbeddingBag | float32 | [768 768] | [512 2048] | cont | fwd | 2588381 | 2049459 | 1.26 |
| EmbeddingBag | float32 | [768 768] | [512 2048] | cont | fwd | 5169386 | 3869800 | 1.34 |
| EmbeddingBag | float32 | [256 256] | [512 2048] | cont | fwd | 589464 | 481217 | 1.22 |
| EmbeddingBag | float32 | [16 255] | [512 2048] | cont | fwd | 125583 | 114259 | 1.10 |
| EmbeddingBag | float32 | [16 255] | [512 2048] | cont | fwd | 142318 | 126775 | 1.12 |
| EmbeddingBag | float32 | [128 128] | [100 256] | cont | fwd | 72623 | 65636 | 1.11 |