MIOpen

Enhancement rotary position embedding

Open · seungmanhan opened this issue · 5 comments

  • Added the RoPE (rotary position embedding) operation, its kernel, and a solver
  • Added a driver test and a gtest for RoPE
  • Compared with ROCm PyTorch (no fusion), MIOpen's RoPE is faster in every measured case, by 4.00x to 44.42x in op time:
RoPE float16
op_name dtype N S NH D_KV model direction ROCm_PyTorch_op_time MIOpen_HIP_op_time Improvement(x)
RoPE float16 4 512 6 64 T5 fwd 187006 12977 14.41
RoPE float16 8 512 6 64 T5 fwd 212990 19715 10.80
RoPE float16 16 512 6 64 T5 fwd 260813 34257 7.61
RoPE float16 32 512 6 64 T5 fwd 389676 63999 6.09
RoPE float16 4 256 6 64 T5 fwd 164878 8497 19.40
RoPE float16 4 128 6 64 T5 fwd 175062 6328 27.66
RoPE float16 4 64 6 64 T5 fwd 164286 6044 27.18
RoPE float16 4 512 6 128 T5 fwd 210174 21155 9.93
RoPE float16 4 512 6 256 T5 fwd 259453 47359 5.48
RoPE float16 4 512 6 512 T5 fwd 421932 86630 4.87
RoPE float16 4 512 3 384 T5 fwd 230782 32266 7.15
RoPE float16 8 512 3 384 T5 fwd 329261 60123 5.48
RoPE float16 16 512 3 384 T5 fwd 545467 119803 4.55
RoPE float16 32 512 3 384 T5 fwd 978406 225011 4.35
RoPE float16 32 256 3 384 T5 fwd 530474 95572 5.55
RoPE float16 32 128 3 384 T5 fwd 324925 50221 6.47
RoPE float16 32 64 3 384 T5 fwd 235982 26915 8.77
RoPE float16 32 512 3 192 T5 fwd 532891 95608 5.57
RoPE float16 32 512 3 768 T5 fwd 2002588 500723 4.00
op_name dtype N S NH D_KV model direction ROCm_PyTorch_op_time MIOpen_HIP_op_time Improvement(x)
RoPE float16 4 512 6 64 T5 bwd 271005 13244 20.46
RoPE float16 8 512 6 64 T5 bwd 292909 20213 14.49
RoPE float16 16 512 6 64 T5 bwd 374060 34968 10.70
RoPE float16 32 512 6 64 T5 bwd 536522 65119 8.24
RoPE float16 4 256 6 64 T5 bwd 272813 8604 31.71
RoPE float16 4 128 6 64 T5 bwd 274173 6275 43.69
RoPE float16 4 64 6 64 T5 bwd 230638 5866 39.32
RoPE float16 4 512 6 128 T5 bwd 291709 21795 13.38
RoPE float16 4 512 6 256 T5 bwd 372716 47412 7.86
RoPE float16 4 512 6 512 T5 bwd 564348 87856 6.42
RoPE float16 4 512 3 384 T5 bwd 339036 32995 10.28
RoPE float16 8 512 3 384 T5 bwd 456619 61474 7.43
RoPE float16 16 512 3 384 T5 bwd 737305 119999 6.14
RoPE float16 32 512 3 384 T5 bwd 1295779 228389 5.67
RoPE float16 32 256 3 384 T5 bwd 757769 97332 7.79
RoPE float16 32 128 3 384 T5 bwd 547466 50968 10.74
RoPE float16 32 64 3 384 T5 bwd 491483 27182 18.08
RoPE float16 32 512 3 192 T5 bwd 756409 97261 7.78
RoPE float16 32 512 3 768 T5 bwd 2583414 504599 5.12
RoPE float32
op_name dtype N S NH D_KV model direction ROCm_PyTorch_op_time MIOpen_HIP_op_time Improvement(x)
RoPE float32 4 512 6 64 T5 fwd 178942 11377 15.73
RoPE float32 8 512 6 64 T5 fwd 203790 17457 11.67
RoPE float32 16 512 6 64 T5 fwd 254013 30719 8.27
RoPE float32 32 512 6 64 T5 fwd 386796 57848 6.69
RoPE float32 4 256 6 64 T5 fwd 165295 7893 20.94
RoPE float32 4 128 6 64 T5 fwd 164350 5795 28.36
RoPE float32 4 64 6 64 T5 fwd 164446 5475 30.04
RoPE float32 4 512 6 128 T5 fwd 206510 17759 11.63
RoPE float32 4 512 6 256 T5 fwd 273021 32106 8.50
RoPE float32 4 512 6 512 T5 fwd 427372 63964 6.68
RoPE float32 4 512 3 384 T5 fwd 231870 24906 9.31
RoPE float32 8 512 3 384 T5 fwd 329357 45706 7.21
RoPE float32 16 512 3 384 T5 fwd 550858 85456 6.45
RoPE float32 32 512 3 384 T5 fwd 977030 167038 5.85
RoPE float32 32 256 3 384 T5 fwd 528571 85154 6.21
RoPE float32 32 128 3 384 T5 fwd 314973 45013 7.00
RoPE float32 32 64 3 384 T5 fwd 225390 24248 9.30
RoPE float32 32 512 3 192 T5 fwd 528506 84941 6.22
RoPE float32 32 512 3 768 T5 fwd 1997996 358574 5.57
op_name dtype N S NH D_KV model direction ROCm_PyTorch_op_time MIOpen_HIP_op_time Improvement(x)
RoPE float32 4 512 6 64 T5 bwd 264190 11768 22.45
RoPE float32 8 512 6 64 T5 bwd 302333 17973 16.82
RoPE float32 16 512 6 64 T5 bwd 372252 31661 11.76
RoPE float32 32 512 6 64 T5 bwd 531979 59626 8.92
RoPE float32 4 256 6 64 T5 bwd 259869 7999 32.49
RoPE float32 4 128 6 64 T5 bwd 255885 5920 43.22
RoPE float32 4 64 6 64 T5 bwd 230574 5475 42.11
RoPE float32 4 512 6 128 T5 bwd 290189 18275 15.88
RoPE float32 4 512 6 256 T5 bwd 375452 32817 11.44
RoPE float32 4 512 6 512 T5 bwd 562874 65972 8.53
RoPE float32 4 512 3 384 T5 bwd 335085 25688 13.04
RoPE float32 8 512 3 384 T5 bwd 463931 47128 9.84
RoPE float32 16 512 3 384 T5 bwd 738409 87732 8.42
RoPE float32 32 512 3 384 T5 bwd 1296787 172585 7.51
RoPE float32 32 256 3 384 T5 bwd 758984 87554 8.67
RoPE float32 32 128 3 384 T5 bwd 556698 46328 12.02
RoPE float32 32 64 3 384 T5 bwd 495243 24533 20.19
RoPE float32 32 512 3 192 T5 bwd 761865 87554 8.70
RoPE float32 32 512 3 768 T5 bwd 2585142 369632 6.99
RoPE bfloat16
op_name dtype N S NH D_KV model direction ROCm_PyTorch_op_time MIOpen_HIP_op_time Improvement(x)
RoPE bfloat16 4 512 6 64 T5 fwd 177566 11662 15.23
RoPE bfloat16 8 512 6 64 T5 fwd 203870 17955 11.35
RoPE bfloat16 16 512 6 64 T5 fwd 249373 31413 7.94
RoPE bfloat16 32 512 6 64 T5 fwd 394076 59199 6.66
RoPE bfloat16 4 256 6 64 T5 fwd 170190 8017 21.23
RoPE bfloat16 4 128 6 64 T5 fwd 171646 5955 28.82
RoPE bfloat16 4 64 6 64 T5 fwd 166862 5564 29.99
RoPE bfloat16 4 512 6 128 T5 fwd 205342 18151 11.31
RoPE bfloat16 4 512 6 256 T5 fwd 259933 32550 7.99
RoPE bfloat16 4 512 6 512 T5 fwd 423948 64408 6.58
RoPE bfloat16 4 512 3 384 T5 fwd 224830 25315 8.88
RoPE bfloat16 8 512 3 384 T5 fwd 328749 46524 7.07
RoPE bfloat16 16 512 3 384 T5 fwd 544490 87021 6.26
RoPE bfloat16 32 512 3 384 T5 fwd 977654 171269 5.71
RoPE bfloat16 32 256 3 384 T5 fwd 543323 87110 6.24
RoPE bfloat16 32 128 3 384 T5 fwd 321101 46026 6.98
RoPE bfloat16 32 64 3 384 T5 fwd 226014 24639 9.17
RoPE bfloat16 32 512 3 192 T5 fwd 533611 87092 6.13
RoPE bfloat16 32 512 3 768 T5 fwd 2001612 361241 5.54
op_name dtype N S NH D_KV model direction ROCm_PyTorch_op_time MIOpen_HIP_op_time Improvement(x)
RoPE bfloat16 4 512 6 64 T5 bwd 264669 11822 22.39
RoPE bfloat16 8 512 6 64 T5 bwd 293693 18630 15.76
RoPE bfloat16 16 512 6 64 T5 bwd 371452 32444 11.45
RoPE bfloat16 32 512 6 64 T5 bwd 541386 61386 8.82
RoPE bfloat16 4 256 6 64 T5 bwd 265982 8195 32.46
RoPE bfloat16 4 128 6 64 T5 bwd 266093 5990 44.42
RoPE bfloat16 4 64 6 64 T5 bwd 239422 5582 42.89
RoPE bfloat16 4 512 6 128 T5 bwd 295149 18666 15.81
RoPE bfloat16 4 512 6 256 T5 bwd 372524 33653 11.07
RoPE bfloat16 4 512 6 512 T5 bwd 570586 66914 8.53
RoPE bfloat16 4 512 3 384 T5 bwd 333709 26293 12.69
RoPE bfloat16 8 512 3 384 T5 bwd 454619 48141 9.44
RoPE bfloat16 16 512 3 384 T5 bwd 767993 90292 8.51
RoPE bfloat16 32 512 3 384 T5 bwd 1295059 179180 7.23
RoPE bfloat16 32 256 3 384 T5 bwd 767672 90274 8.50
RoPE bfloat16 32 128 3 384 T5 bwd 543321 47590 11.42
RoPE bfloat16 32 64 3 384 T5 bwd 483483 25386 19.05
RoPE bfloat16 32 512 3 192 T5 bwd 760616 90185 8.43
RoPE bfloat16 32 512 3 768 T5 bwd 2585494 375623 6.88
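Each Improvement value in the tables above equals the ROCm PyTorch op time divided by the MIOpen HIP op time. The summary that follows reports the geometric mean of those ratios for each dtype and direction. The short Python check below uses the float16 fwd rows copied from the first table and reproduces both numbers. The helper names are illustrative only, and the PR does not state the time unit.

```python
import math

# (ROCm PyTorch op time, MIOpen HIP op time) for the float16 fwd rows, copied from the table.
FP16_FWD = [
    (187006, 12977), (212990, 19715), (260813, 34257), (389676, 63999),
    (164878, 8497), (175062, 6328), (164286, 6044), (210174, 21155),
    (259453, 47359), (421932, 86630), (230782, 32266), (329261, 60123),
    (545467, 119803), (978406, 225011), (530474, 95572), (324925, 50221),
    (235982, 26915), (532891, 95608), (2002588, 500723),
]


def improvement(baseline_time, miopen_time):
    """Speedup ratio: baseline time / MIOpen time (both in the same, unstated unit)."""
    return baseline_time / miopen_time


def geomean(values):
    """Geometric mean, computed as exp of the mean of the logarithms."""
    return math.exp(sum(math.log(v) for v in values) / len(values))


ratios = [improvement(b, m) for b, m in FP16_FWD]
print(f"first row: {ratios[0]:.2f}")    # first row: 14.41
print(f"geomean:   {geomean(ratios):.2f}")  # geomean:   8.02
```

A geometric mean suits speedup ratios because it does not let the few very large ratios (the small-S rows) dominate the way an arithmetic mean would.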
  • Geometric mean of the improvement over all cases
Op Direction dtype geomean
RoPE fwd float16 8.02
RoPE bwd float16 11.5
RoPE fwd float32 9.56
RoPE bwd float32 13.67
RoPE fwd bfloat16 9.39
RoPE bwd bfloat16 13.40
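The PR adds forward and backward RoPE but does not spell out the math. The minimal pure-Python sketch below is a rough reference for what such kernels compute. It is not MIOpen's implementation. None of its assumptions are confirmed by the PR:

- the input is laid out as (N, S, NH, D_KV), matching the table columns;
- cos and sin are precomputed (S, D_KV) tables. The RoPEForward signature quoted later in this thread takes x, cos, and sin descriptors and writes one output, which suggests but does not confirm this;
- feature pairs (2i, 2i+1) are rotated with base 10000. Some implementations instead pair feature i with feature i + D/2.

All function names are hypothetical.

```python
import math

# Hypothetical reference helpers; not MIOpen's API or kernel.
# Assumed layout: x[n][s][h][d] with shape (N, S, NH, D_KV), D_KV even.
# Assumed convention: feature pair (2i, 2i+1) at position s uses angle s * base**(-2i/D).


def rope_cos_sin(seq_len, dim, base=10000.0):
    """Precompute (seq_len, dim) cos/sin tables; both members of a pair share an angle."""
    cos = [[0.0] * dim for _ in range(seq_len)]
    sin = [[0.0] * dim for _ in range(seq_len)]
    for s in range(seq_len):
        for i in range(dim // 2):
            angle = s * base ** (-2.0 * i / dim)
            for d in (2 * i, 2 * i + 1):
                cos[s][d] = math.cos(angle)
                sin[s][d] = math.sin(angle)
    return cos, sin


def rotate_pairs(vec):
    """Map every pair (a, b) to (-b, a), i.e. rotate it by +90 degrees."""
    out = [0.0] * len(vec)
    for i in range(0, len(vec), 2):
        out[i] = -vec[i + 1]
        out[i + 1] = vec[i]
    return out


def _rope_apply(t, cos, sin, sign):
    """Compute t * cos + sign * rotate_pairs(t) * sin for every (n, s, h) vector."""
    result = []
    for seq in t:                          # N
        out_seq = []
        for s, heads in enumerate(seq):    # S: the position selects the cos/sin row
            out_heads = []
            for vec in heads:              # NH: vec holds D_KV features
                rot = rotate_pairs(vec)
                out_heads.append([v * c + sign * r * sn
                                  for v, r, c, sn in zip(vec, rot, cos[s], sin[s])])
            out_seq.append(out_heads)
        result.append(out_seq)
    return result


def rope_forward(x, cos, sin):
    """y = x * cos + rotate_pairs(x) * sin: rotates each pair by its angle."""
    return _rope_apply(x, cos, sin, +1.0)


def rope_backward(dy, cos, sin):
    """dx for rope_forward: each pair saw an orthogonal 2x2 rotation, so the
    gradient is its transpose, i.e. the same formula with sin negated."""
    return _rope_apply(dy, cos, sin, -1.0)
```

For example, `rope_backward(rope_forward(x, cos, sin), cos, sin)` returns `x` up to rounding error. This holds because each pair is multiplied by an orthogonal rotation, whose transpose is the rotation by the negative angle. In this sketch, the backward pass therefore needs only the same cos/sin tables as the forward pass.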

seungmanhan, May 31 '24

@iq136boy, could you help review this PR?

JehandadKhan, Jun 05 '24

@seungmanhan, same for this PR: could you help resolve the conflict? CI is passing, so we should be able to merge soon. Thanks!

junliume, Jun 28 '24

@seungmanhan @apwojcik, similar situation here: the feature fails the Windows build:

lld-link: error: undefined symbol: enum miopenStatus_t __cdecl miopen::RoPEForward(struct miopen::Handle &, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void *)

junliume, Jul 24 '24

@seungmanhan, could you help resolve the conflicts? We can merge once that is done and CI passes again :) Thanks!

junliume, Jul 24 '24

@junliume, the GitHub Actions checks passed.

seungmanhan, Jul 31 '24