MIOpen

Implement EmbeddingBag

Status: Open. cognaiger9 opened this issue 10 months ago (0 comments).

  • Add EmbeddingBag operation with forward kernels.
  • Add a driver and gtests for the kernels.
  • MIOpen outperforms ROCm when:
    • Mode is Max, or
    • Mode is Mean or Sum, the tensor type is float, all tensors are contiguous, and the number of output elements exceeds 2^19.
  • Note:
    • The forward solver supports only 2D tensors.
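
For reference, the forward semantics being implemented can be sketched in plain Python. This is a minimal, hypothetical reference (not the MIOpen kernel), assuming PyTorch-style EmbeddingBag behavior for a 2D index tensor: each row is one bag of fixed length, with no offsets, per-sample weights, or padding index. The function name `embedding_bag_forward` is illustrative only.

```python
from typing import List, Sequence


def embedding_bag_forward(indices: Sequence[Sequence[int]],
                          weight: Sequence[Sequence[float]],
                          mode: str = "sum") -> List[List[float]]:
    """Hypothetical reference EmbeddingBag forward for a 2D index tensor.

    indices: [num_bags, bag_len], each row is one non-empty bag.
    weight:  [num_embeddings, embedding_dim].
    Returns: [num_bags, embedding_dim].
    """
    if mode not in ("sum", "mean", "max"):
        raise ValueError(f"unsupported mode: {mode}")
    embedding_dim = len(weight[0])
    out = []
    for bag in indices:
        # Gather the embedding rows selected by this bag.
        rows = [weight[i] for i in bag]
        if mode == "max":
            # Element-wise maximum across the bag's rows.
            reduced = [max(r[d] for r in rows) for d in range(embedding_dim)]
        else:
            # Element-wise sum; mean additionally divides by the bag length.
            reduced = [sum(r[d] for r in rows) for d in range(embedding_dim)]
            if mode == "mean":
                reduced = [v / len(bag) for v in reduced]
        out.append(reduced)
    return out


weight = [[1.0, 2.0], [3.0, 0.0], [-1.0, 5.0]]  # 3 embeddings, dim 2
bags = [[0, 1], [2, 2]]                          # 2 bags of length 2
print(embedding_bag_forward(bags, weight, "max"))  # [[3.0, 2.0], [-1.0, 5.0]]
```

The output has one row per bag, so its element count is `num_bags * embedding_dim`, which is the quantity compared against 2^19 in the Mean/Sum condition.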

Average forward improvement over ROCm (higher is better)

dtype     fwd
float16   1.38
float     1.30
bfloat16  1.40
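
Each Improvement value in the detailed tables equals ROCm / MIOpen for that row, and the averages above are consistent with the arithmetic mean of those per-row values (for float, the max-mode and mean-mode float32 tables taken together). A small check for float16, using the numbers copied from the float16 (mode max) table:

```python
from statistics import mean

# Improvement = ROCm time / MIOpen time (first float16 max-mode row).
rocm, miopen = 320271, 253759
row_improvement = round(rocm / miopen, 2)

# Improvement column of the float16 (mode max) table, in row order.
fp16_improvements = [1.26, 1.42, 1.20, 1.57, 1.30, 1.54, 1.61, 1.20,
                     1.51, 1.28, 1.60, 1.31, 1.44, 1.20, 1.20]
fp16_average = round(mean(fp16_improvements), 2)

print(row_improvement)  # 1.26
print(fp16_average)     # 1.38
```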

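Detailed Benchmark

ROCm and MIOpen columns are forward timings (lower is better); Improvement = ROCm / MIOpen. cont/noncont indicates contiguous or non-contiguous tensors.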

float16 (mode max)
op_name dtype input_size weight_size cont direction ROCm MIOpen Improvement
EmbeddingBag float16 [32 655] [100 256] cont fwd 320271 253759 1.26
EmbeddingBag float16 [512 512] [256 1024] cont fwd 1198653 846078 1.42
EmbeddingBag float16 [512 512] [256 1024] noncont fwd 1210990 1012089 1.20
EmbeddingBag float16 [512 512] [512 2048] cont fwd 2385852 1520730 1.57
EmbeddingBag float16 [512 512] [512 2048] noncont fwd 2394267 1838890 1.30
EmbeddingBag float16 [1024 512] [256 1024] cont fwd 2391292 1555550 1.54
EmbeddingBag float16 [1024 512] [512 2048] cont fwd 4765832 2955600 1.61
EmbeddingBag float16 [1024 512] [512 2048] noncont fwd 4791704 3989040 1.20
EmbeddingBag float16 [768 768] [256 1024] cont fwd 2684508 1778770 1.51
EmbeddingBag float16 [768 768] [256 1024] noncont fwd 2706427 2111200 1.28
EmbeddingBag float16 [768 768] [512 2048] cont fwd 5352231 3353980 1.60
EmbeddingBag float16 [768 768] [512 2048] noncont fwd 5376823 4098639 1.31
EmbeddingBag float16 [256 256] [512 2048] cont fwd 607119 422861 1.44
EmbeddingBag float16 [256 256] [512 2048] noncont fwd 612143 508318 1.20
EmbeddingBag float16 [16 255] [100 256] cont fwd 129472 108169 1.20

float32 (mode max)
op_name dtype input_size weight_size cont direction ROCm MIOpen Improvement
EmbeddingBag float32 [512 512] [256 1024] cont fwd 1191598 863910 1.38
EmbeddingBag float32 [512 512] [256 1024] noncont fwd 1191582 1029700 1.16
EmbeddingBag float32 [512 512] [512 2048] cont fwd 2364124 1556490 1.52
EmbeddingBag float32 [512 512] [512 2048] noncont fwd 2376844 1863040 1.28
EmbeddingBag float32 [1024 512] [256 1024] cont fwd 2381244 1587020 1.50
EmbeddingBag float32 [1024 512] [512 2048] cont fwd 4769672 3037880 1.57
EmbeddingBag float32 [1024 512] [512 2048] noncont fwd 4799640 3896120 1.23
EmbeddingBag float32 [768 768] [256 1024] cont fwd 2664284 1819250 1.46
EmbeddingBag float32 [768 768] [256 1024] noncont fwd 2684236 2146680 1.25
EmbeddingBag float32 [768 768] [512 2048] cont fwd 5335863 3436330 1.55
EmbeddingBag float32 [768 768] [512 2048] noncont fwd 5371095 4152190 1.29
EmbeddingBag float32 [256 256] [512 2048] cont fwd 602367 435573 1.38
EmbeddingBag float32 [256 256] [512 2048] noncont fwd 607999 517136 1.18
EmbeddingBag float32 [128 512] [512 2048] cont fwd 600879 528356 1.14

bfloat16 (mode max)
op_name dtype input_size weight_size cont direction ROCm MIOpen Improvement
EmbeddingBag bfloat16 [512 512] [256 1024] cont fwd 1188750 834043 1.43
EmbeddingBag bfloat16 [512 512] [256 1024] noncont fwd 1188286 1010610 1.18
EmbeddingBag bfloat16 [512 512] [512 2048] cont fwd 2360604 1491550 1.58
EmbeddingBag bfloat16 [512 512] [512 2048] noncont fwd 2360636 1816100 1.30
EmbeddingBag bfloat16 [1024 512] [256 1024] cont fwd 2373468 1527450 1.55
EmbeddingBag bfloat16 [1024 512] [512 2048] cont fwd 4721576 2911920 1.62
EmbeddingBag bfloat16 [1024 512] [512 2048] noncont fwd 4735288 4048420 1.17
EmbeddingBag bfloat16 [768 768] [100 256] cont fwd 637519 525137 1.21
EmbeddingBag bfloat16 [768 768] [256 1024] cont fwd 2662716 1746410 1.52
EmbeddingBag bfloat16 [768 768] [256 1024] noncont fwd 2681467 2100390 1.28
EmbeddingBag bfloat16 [768 768] [512 2048] cont fwd 5294503 3315640 1.60
EmbeddingBag bfloat16 [768 768] [512 2048] noncont fwd 5323143 4110299 1.30
EmbeddingBag bfloat16 [256 256] [512 2048] cont fwd 600703 416995 1.44

float32 (mode mean)
op_name dtype input_size weight_size cont direction ROCm MIOpen Improvement
EmbeddingBag float32 [512 512] [256 1024] cont fwd 1154688 955314 1.21
EmbeddingBag float32 [512 512] [512 2048] cont fwd 2307201 1764120 1.31
EmbeddingBag float32 [1024 512] [256 1024] cont fwd 2303857 1798470 1.28
EmbeddingBag float32 [1024 512] [512 2048] cont fwd 4597474 3431640 1.34
EmbeddingBag float32 [768 768] [512 2048] cont fwd 2588381 2049459 1.26
EmbeddingBag float32 [768 768] [512 2048] cont fwd 5169386 3869800 1.34
EmbeddingBag float32 [256 256] [512 2048] cont fwd 589464 481217 1.22
EmbeddingBag float32 [16 255] [512 2048] cont fwd 125583 114259 1.10
EmbeddingBag float32 [16 255] [512 2048] cont fwd 142318 126775 1.12
EmbeddingBag float32 [128 128] [100 256] cont fwd 72623 65636 1.11
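
To relate the benchmark shapes to the conditions stated in the PR description, the rule can be written as a predicate. This is a hypothetical sketch, not MIOpen's solver-selection code; it assumes "float" means float32, that input size is [num_bags, bag_len] and weight size is [num_embeddings, embedding_dim], so the output has num_bags * embedding_dim elements. The names `output_numel` and `miopen_expected_faster` are illustrative.

```python
from typing import Tuple

SIZE_THRESHOLD = 2 ** 19  # output-element threshold stated for Mean/Sum


def output_numel(input_size: Tuple[int, int],
                 weight_size: Tuple[int, int]) -> int:
    # Assumed layout: output is [num_bags, embedding_dim].
    num_bags, _bag_len = input_size
    _num_embeddings, embedding_dim = weight_size
    return num_bags * embedding_dim


def miopen_expected_faster(mode: str, dtype: str, all_contiguous: bool,
                           numel: int) -> bool:
    # Hypothetical encoding of the conditions listed in the PR description.
    if mode == "max":
        return True
    if mode in ("mean", "sum"):
        return (dtype == "float32" and all_contiguous
                and numel > SIZE_THRESHOLD)
    raise ValueError(f"unsupported mode: {mode}")


n = output_numel((1024, 512), (512, 2048))
print(n)                                                # 2097152
print(miopen_expected_faster("mean", "float32", True, n))  # True
```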

cognaiger9, Mar 11 '25 10:03