Implement embedding

Open cognaiger9 opened this issue 10 months ago • 0 comments

  • Add Embedding operation with backward kernels.
  • Add driver and gtest for kernels.
  • MIOpen performs better if:
    • Split dimension is 0
    • Number of elements in the input tensor is less than 400,000 (forward case) or less than 800,000 (backward case)
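For reference, the semantics that the forward and backward kernels are expected to reproduce can be sketched in plain Python. This is a minimal sketch assuming PyTorch-style embedding behaviour (row lookup on the forward pass, scatter-add of output gradients into the weight gradient on the backward pass). The function names `embedding_forward` and `embedding_backward` are hypothetical and are not MIOpen API; options such as a padding index or frequency scaling are not mentioned in this issue and are left out.

```python
def embedding_forward(indices, weight):
    """Return one copy of a weight row per index (indices is a flat list of ints)."""
    return [list(weight[i]) for i in indices]


def embedding_backward(indices, grad_output, num_embeddings):
    """Accumulate each output-gradient row into the weight row it was read from.

    Hypothetical reference only: repeated indices add up, which is why a
    GPU kernel for this step needs atomic adds or some form of reduction.
    """
    dim = len(grad_output[0]) if grad_output else 0
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]
    for idx, grad_row in zip(indices, grad_output):
        for j, g in enumerate(grad_row):
            grad_weight[idx][j] += g
    return grad_weight


# Tiny example: 3 embeddings of width 2, with index 2 used twice.
weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
indices = [0, 2, 2, 1]
output = embedding_forward(indices, weight)
grad_output = [[1.0, 1.0] for _ in indices]
grad_weight = embedding_backward(indices, grad_output, num_embeddings=len(weight))
```

In the example, row 2 of `grad_weight` receives two contributions because index 2 appears twice. A gtest could compare kernel output against a reference of this shape.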

Average improvement over ROCm (mean of the per-row MIOpen vs ROCm ratios in the detailed benchmark)

dtype bwd
float16 2.1
float32 2.75
bfloat16 2.41
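The averages above appear to be the arithmetic mean of the per-row ratio ROCm / MIOpen. The sketch below recomputes the float16 figure from the float16 rows of the detailed benchmark. The pairing of the columns as (ROCm, MIOpen) and the helper names are assumptions made for illustration; the issue does not state the unit of the measurements.

```python
# (ROCm, MIOpen) backward measurements for float16, copied row by row from
# the detailed benchmark. The unit is not stated in the issue.
FLOAT16_BWD = [
    (600548, 423901), (579166, 682325),
    (893821, 466886), (916582, 695751),
    (897141, 180925), (931511, 488879),
    (798757, 356038), (886080, 540652),
    (430180, 274782), (511610, 420868),
    (434340, 140962), (392898, 138703),
    (373517, 150829), (486602, 225576),
    (672631, 317764), (650389, 408122),
    (753374, 421790), (848298, 603588),
    (172588, 57119), (194449, 77961),
]


def speedups(pairs):
    """Per-row ratio ROCm / MIOpen; values above 1 favor MIOpen."""
    return [rocm / miopen for rocm, miopen in pairs]


def mean(values):
    """Plain arithmetic mean."""
    return sum(values) / len(values)


ratios = speedups(FLOAT16_BWD)
average = mean(ratios)
# Rows where the ratio drops below 1, i.e. MIOpen is slower than ROCm.
regressions = [r for r in ratios if r < 1.0]

print(f"average speedup: {average:.2f}")
print(f"rows where MIOpen is slower: {len(regressions)}")
```

The result is consistent with the 2.1 reported for float16, and it flags the single float16 row (non-contiguous, input [40 512], weight [1026 1024]) where MIOpen is slower than ROCm.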

Detailed Benchmark

float16
op_name dtype input size weight size contiguous direction ROCm MIOpen MIOpen vs ROCm
Embedding float16 [40 512] [1026 1024] cont bwd 600548 423901 1.42
Embedding float16 [40 512] [1026 1024] noncont bwd 579166 682325 0.85
Embedding float16 [40 512] [50000 768] cont bwd 893821 466886 1.91
Embedding float16 [40 512] [50000 768] noncont bwd 916582 695751 1.32
Embedding float16 [40 512] [2 768] cont bwd 897141 180925 4.96
Embedding float16 [40 512] [2 768] noncont bwd 931511 488879 1.91
Embedding float16 [48 512] [52100 512] cont bwd 798757 356038 2.24
Embedding float16 [48 512] [52100 512] noncont bwd 886080 540652 1.64
Embedding float16 [32 512] [514 768] cont bwd 430180 274782 1.57
Embedding float16 [32 512] [514 768] noncont bwd 511610 420868 1.22
Embedding float16 [32 188] [148 512] cont bwd 434340 140962 3.08
Embedding float16 [32 188] [148 512] noncont bwd 392898 138703 2.83
Embedding float16 [16 512] [512 768] cont bwd 373517 150829 2.48
Embedding float16 [16 512] [512 768] noncont bwd 486602 225576 2.16
Embedding float16 [16 512] [52000 768] cont bwd 672631 317764 2.12
Embedding float16 [16 512] [52000 768] noncont bwd 650389 408122 1.59
Embedding float16 [16 1024] [52000 768] cont bwd 753374 421790 1.79
Embedding float16 [16 1024] [52000 768] noncont bwd 848298 603588 1.41
Embedding float16 [2 1024] [1024 768] cont bwd 172588 57119 3.02
Embedding float16 [2 1024] [1024 768] noncont bwd 194449 77961 2.49

float32
op_name dtype input size weight size contiguous direction ROCm MIOpen MIOpen vs ROCm
Embedding float32 [40 512] [1026 1024] cont bwd 515964 260471 1.98
Embedding float32 [40 512] [1026 1024] noncont bwd 587787 636332 0.92
Embedding float32 [40 512] [50000 768] cont bwd 945843 414779 2.28
Embedding float32 [40 512] [50000 768] noncont bwd 1030907 662515 1.56
Embedding float32 [40 512] [2 768] cont bwd 849659 182595 4.65
Embedding float32 [40 512] [2 768] noncont bwd 902241 493036 1.83
Embedding float32 [48 512] [52100 512] cont bwd 804137 314753 2.55
Embedding float32 [48 512] [52100 512] noncont bwd 969344 514590 1.88
Embedding float32 [32 512] [514 768] cont bwd 461401 159905 2.89
Embedding float32 [32 512] [514 768] noncont bwd 452040 391857 1.15
Embedding float32 [32 188] [148 512] cont bwd 361496 65447 5.52
Embedding float32 [32 188] [148 512] noncont bwd 889040 150777 5.90
Embedding float32 [16 512] [512 768] cont bwd 449706 93224 4.82
Embedding float32 [16 512] [512 768] noncont bwd 479844 210369 2.28
Embedding float32 [16 512] [52000 768] cont bwd 712272 296758 2.40
Embedding float32 [16 512] [52000 768] noncont bwd 838698 395938 2.12
Embedding float32 [16 1024] [52000 768] cont bwd 802416 381329 2.10
Embedding float32 [16 1024] [52000 768] noncont bwd 878280 577851 1.52
Embedding float32 [2 1024] [1024 768] cont bwd 177068 45560 3.89
Embedding float32 [2 1024] [1024 768] noncont bwd 210750 75880 2.78

bfloat16
op_name dtype input size weight size contiguous direction ROCm MIOpen MIOpen vs ROCm
Embedding bfloat16 [40 512] [1026 1024] cont bwd 469609 319702 1.47
Embedding bfloat16 [40 512] [1026 1024] noncont bwd 574466 664567 0.86
Embedding bfloat16 [40 512] [50000 768] cont bwd 924102 427295 2.16
Embedding bfloat16 [40 512] [50000 768] noncont bwd 970384 683101 1.42
Embedding bfloat16 [40 512] [2 768] cont bwd 845138 178876 4.72
Embedding bfloat16 [40 512] [2 768] noncont bwd 879140 492015 1.79
Embedding bfloat16 [48 512] [52100 512] cont bwd 802816 324140 2.48
Embedding bfloat16 [48 512] [52100 512] noncont bwd 965304 530431 1.82
Embedding bfloat16 [32 512] [514 768] cont bwd 381746 204544 1.87
Embedding bfloat16 [32 512] [514 768] noncont bwd 1047968 411612 2.55
Embedding bfloat16 [32 188] [148 512] cont bwd 370737 72916 5.08
Embedding bfloat16 [32 188] [148 512] noncont bwd 376197 122161 3.08
Embedding bfloat16 [16 512] [512 768] cont bwd 365477 115845 3.15
Embedding bfloat16 [16 512] [512 768] noncont bwd 351416 218495 1.61
Embedding bfloat16 [16 512] [52000 768] cont bwd 688691 301789 2.28
Embedding bfloat16 [16 512] [52000 768] noncont bwd 773815 402798 1.92
Embedding bfloat16 [16 1024] [52000 768] cont bwd 805536 389987 2.07
Embedding bfloat16 [16 1024] [52000 768] noncont bwd 818937 594882 1.38
Embedding bfloat16 [2 1024] [1024 768] cont bwd 185588 49561 3.74
Embedding bfloat16 [2 1024] [1024 768] noncont bwd 215590 75879 2.84

cognaiger9, Feb 20 '25 03:02