Enhancement rotary position embedding
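- Added a RoPE (rotary position embedding) operation and kernel, with a solver
- Added a driver test and gtest for RoPE
- Compared to ROCm PyTorch (no fusion), the MIOpen HIP kernel is faster in every benchmarked case below; the Improvement column is the ROCm PyTorch op time divided by the MIOpen HIP op time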
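For readers unfamiliar with the operation, here is a minimal pure-Python sketch of one common RoPE formulation (rotating adjacent element pairs) and its backward pass. It is an illustration only: the pairing scheme, the `base` value, and the helper names `rope_tables`, `rope_forward`, and `rope_backward` are assumptions, not the MIOpen kernel or its API.

```python
import math

def rope_tables(position, dim, base=10000.0):
    """Build per-element cos/sin tables for one position (assumed frequency scheme)."""
    cos, sin = [], []
    for i in range(dim):
        # Both elements of a pair (2k, 2k+1) share the same angle.
        theta = position * base ** (-2 * (i // 2) / dim)
        cos.append(math.cos(theta))
        sin.append(math.sin(theta))
    return cos, sin

def rope_forward(x, cos, sin):
    """Rotate each adjacent pair (x[2k], x[2k+1]) by its angle."""
    y = [0.0] * len(x)
    for i in range(0, len(x), 2):
        y[i] = x[i] * cos[i] - x[i + 1] * sin[i]
        y[i + 1] = x[i + 1] * cos[i + 1] + x[i] * sin[i + 1]
    return y

def rope_backward(dy, cos, sin):
    """Gradient w.r.t. x: apply the transposed (inverse) rotation to dy."""
    dx = [0.0] * len(dy)
    for i in range(0, len(dy), 2):
        dx[i] = dy[i] * cos[i] + dy[i + 1] * sin[i + 1]
        dx[i + 1] = dy[i + 1] * cos[i + 1] - dy[i] * sin[i]
    return dx

if __name__ == "__main__":
    cos, sin = rope_tables(position=3, dim=8)
    x = [0.5, -1.0, 2.0, 0.25, -0.75, 1.5, 3.0, -2.0]
    y = rope_forward(x, cos, sin)
    # A rotation preserves the vector norm, and backward undoes forward.
    print(sum(v * v for v in x), sum(v * v for v in y))
    print(rope_backward(y, cos, sin))
```

Because the forward pass is a pure rotation, the backward pass is the same elementwise pattern with the sign of the sine term flipped, which is why both directions are simple elementwise work.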
rope float16
| op_name | dtype | N | S | NH | D_KV | model | direction | ROCm PyTorch (op time) | MIOpen HIP (op time) | Improvement (x) |
|---|---|---|---|---|---|---|---|---|---|---|
| RoPE | float16 | 4 | 512 | 6 | 64 | T5 | fwd | 187006 | 12977 | 14.41 |
| RoPE | float16 | 8 | 512 | 6 | 64 | T5 | fwd | 212990 | 19715 | 10.80 |
| RoPE | float16 | 16 | 512 | 6 | 64 | T5 | fwd | 260813 | 34257 | 7.61 |
| RoPE | float16 | 32 | 512 | 6 | 64 | T5 | fwd | 389676 | 63999 | 6.09 |
| RoPE | float16 | 4 | 256 | 6 | 64 | T5 | fwd | 164878 | 8497 | 19.40 |
| RoPE | float16 | 4 | 128 | 6 | 64 | T5 | fwd | 175062 | 6328 | 27.66 |
| RoPE | float16 | 4 | 64 | 6 | 64 | T5 | fwd | 164286 | 6044 | 27.18 |
| RoPE | float16 | 4 | 512 | 6 | 128 | T5 | fwd | 210174 | 21155 | 9.93 |
| RoPE | float16 | 4 | 512 | 6 | 256 | T5 | fwd | 259453 | 47359 | 5.48 |
| RoPE | float16 | 4 | 512 | 6 | 512 | T5 | fwd | 421932 | 86630 | 4.87 |
| RoPE | float16 | 4 | 512 | 3 | 384 | T5 | fwd | 230782 | 32266 | 7.15 |
| RoPE | float16 | 8 | 512 | 3 | 384 | T5 | fwd | 329261 | 60123 | 5.48 |
| RoPE | float16 | 16 | 512 | 3 | 384 | T5 | fwd | 545467 | 119803 | 4.55 |
| RoPE | float16 | 32 | 512 | 3 | 384 | T5 | fwd | 978406 | 225011 | 4.35 |
| RoPE | float16 | 32 | 256 | 3 | 384 | T5 | fwd | 530474 | 95572 | 5.55 |
| RoPE | float16 | 32 | 128 | 3 | 384 | T5 | fwd | 324925 | 50221 | 6.47 |
| RoPE | float16 | 32 | 64 | 3 | 384 | T5 | fwd | 235982 | 26915 | 8.77 |
| RoPE | float16 | 32 | 512 | 3 | 192 | T5 | fwd | 532891 | 95608 | 5.57 |
| RoPE | float16 | 32 | 512 | 3 | 768 | T5 | fwd | 2002588 | 500723 | 4.00 |
| op_name | dtype | N | S | NH | D_KV | model | direction | ROCm PyTorch (op time) | MIOpen HIP (op time) | Improvement (x) |
|---|---|---|---|---|---|---|---|---|---|---|
| RoPE | float16 | 4 | 512 | 6 | 64 | T5 | bwd | 271005 | 13244 | 20.46 |
| RoPE | float16 | 8 | 512 | 6 | 64 | T5 | bwd | 292909 | 20213 | 14.49 |
| RoPE | float16 | 16 | 512 | 6 | 64 | T5 | bwd | 374060 | 34968 | 10.70 |
| RoPE | float16 | 32 | 512 | 6 | 64 | T5 | bwd | 536522 | 65119 | 8.24 |
| RoPE | float16 | 4 | 256 | 6 | 64 | T5 | bwd | 272813 | 8604 | 31.71 |
| RoPE | float16 | 4 | 128 | 6 | 64 | T5 | bwd | 274173 | 6275 | 43.69 |
| RoPE | float16 | 4 | 64 | 6 | 64 | T5 | bwd | 230638 | 5866 | 39.32 |
| RoPE | float16 | 4 | 512 | 6 | 128 | T5 | bwd | 291709 | 21795 | 13.38 |
| RoPE | float16 | 4 | 512 | 6 | 256 | T5 | bwd | 372716 | 47412 | 7.86 |
| RoPE | float16 | 4 | 512 | 6 | 512 | T5 | bwd | 564348 | 87856 | 6.42 |
| RoPE | float16 | 4 | 512 | 3 | 384 | T5 | bwd | 339036 | 32995 | 10.28 |
| RoPE | float16 | 8 | 512 | 3 | 384 | T5 | bwd | 456619 | 61474 | 7.43 |
| RoPE | float16 | 16 | 512 | 3 | 384 | T5 | bwd | 737305 | 119999 | 6.14 |
| RoPE | float16 | 32 | 512 | 3 | 384 | T5 | bwd | 1295779 | 228389 | 5.67 |
| RoPE | float16 | 32 | 256 | 3 | 384 | T5 | bwd | 757769 | 97332 | 7.79 |
| RoPE | float16 | 32 | 128 | 3 | 384 | T5 | bwd | 547466 | 50968 | 10.74 |
| RoPE | float16 | 32 | 64 | 3 | 384 | T5 | bwd | 491483 | 27182 | 18.08 |
| RoPE | float16 | 32 | 512 | 3 | 192 | T5 | bwd | 756409 | 97261 | 7.78 |
| RoPE | float16 | 32 | 512 | 3 | 768 | T5 | bwd | 2583414 | 504599 | 5.12 |
rope float32
| op_name | dtype | N | S | NH | D_KV | model | direction | ROCm PyTorch (op time) | MIOpen HIP (op time) | Improvement (x) |
|---|---|---|---|---|---|---|---|---|---|---|
| RoPE | float32 | 4 | 512 | 6 | 64 | T5 | fwd | 178942 | 11377 | 15.73 |
| RoPE | float32 | 8 | 512 | 6 | 64 | T5 | fwd | 203790 | 17457 | 11.67 |
| RoPE | float32 | 16 | 512 | 6 | 64 | T5 | fwd | 254013 | 30719 | 8.27 |
| RoPE | float32 | 32 | 512 | 6 | 64 | T5 | fwd | 386796 | 57848 | 6.69 |
| RoPE | float32 | 4 | 256 | 6 | 64 | T5 | fwd | 165295 | 7893 | 20.94 |
| RoPE | float32 | 4 | 128 | 6 | 64 | T5 | fwd | 164350 | 5795 | 28.36 |
| RoPE | float32 | 4 | 64 | 6 | 64 | T5 | fwd | 164446 | 5475 | 30.04 |
| RoPE | float32 | 4 | 512 | 6 | 128 | T5 | fwd | 206510 | 17759 | 11.63 |
| RoPE | float32 | 4 | 512 | 6 | 256 | T5 | fwd | 273021 | 32106 | 8.50 |
| RoPE | float32 | 4 | 512 | 6 | 512 | T5 | fwd | 427372 | 63964 | 6.68 |
| RoPE | float32 | 4 | 512 | 3 | 384 | T5 | fwd | 231870 | 24906 | 9.31 |
| RoPE | float32 | 8 | 512 | 3 | 384 | T5 | fwd | 329357 | 45706 | 7.21 |
| RoPE | float32 | 16 | 512 | 3 | 384 | T5 | fwd | 550858 | 85456 | 6.45 |
| RoPE | float32 | 32 | 512 | 3 | 384 | T5 | fwd | 977030 | 167038 | 5.85 |
| RoPE | float32 | 32 | 256 | 3 | 384 | T5 | fwd | 528571 | 85154 | 6.21 |
| RoPE | float32 | 32 | 128 | 3 | 384 | T5 | fwd | 314973 | 45013 | 7.00 |
| RoPE | float32 | 32 | 64 | 3 | 384 | T5 | fwd | 225390 | 24248 | 9.30 |
| RoPE | float32 | 32 | 512 | 3 | 192 | T5 | fwd | 528506 | 84941 | 6.22 |
| RoPE | float32 | 32 | 512 | 3 | 768 | T5 | fwd | 1997996 | 358574 | 5.57 |
| op_name | dtype | N | S | NH | D_KV | model | direction | ROCm PyTorch (op time) | MIOpen HIP (op time) | Improvement (x) |
|---|---|---|---|---|---|---|---|---|---|---|
| RoPE | float32 | 4 | 512 | 6 | 64 | T5 | bwd | 264190 | 11768 | 22.45 |
| RoPE | float32 | 8 | 512 | 6 | 64 | T5 | bwd | 302333 | 17973 | 16.82 |
| RoPE | float32 | 16 | 512 | 6 | 64 | T5 | bwd | 372252 | 31661 | 11.76 |
| RoPE | float32 | 32 | 512 | 6 | 64 | T5 | bwd | 531979 | 59626 | 8.92 |
| RoPE | float32 | 4 | 256 | 6 | 64 | T5 | bwd | 259869 | 7999 | 32.49 |
| RoPE | float32 | 4 | 128 | 6 | 64 | T5 | bwd | 255885 | 5920 | 43.22 |
| RoPE | float32 | 4 | 64 | 6 | 64 | T5 | bwd | 230574 | 5475 | 42.11 |
| RoPE | float32 | 4 | 512 | 6 | 128 | T5 | bwd | 290189 | 18275 | 15.88 |
| RoPE | float32 | 4 | 512 | 6 | 256 | T5 | bwd | 375452 | 32817 | 11.44 |
| RoPE | float32 | 4 | 512 | 6 | 512 | T5 | bwd | 562874 | 65972 | 8.53 |
| RoPE | float32 | 4 | 512 | 3 | 384 | T5 | bwd | 335085 | 25688 | 13.04 |
| RoPE | float32 | 8 | 512 | 3 | 384 | T5 | bwd | 463931 | 47128 | 9.84 |
| RoPE | float32 | 16 | 512 | 3 | 384 | T5 | bwd | 738409 | 87732 | 8.42 |
| RoPE | float32 | 32 | 512 | 3 | 384 | T5 | bwd | 1296787 | 172585 | 7.51 |
| RoPE | float32 | 32 | 256 | 3 | 384 | T5 | bwd | 758984 | 87554 | 8.67 |
| RoPE | float32 | 32 | 128 | 3 | 384 | T5 | bwd | 556698 | 46328 | 12.02 |
| RoPE | float32 | 32 | 64 | 3 | 384 | T5 | bwd | 495243 | 24533 | 20.19 |
| RoPE | float32 | 32 | 512 | 3 | 192 | T5 | bwd | 761865 | 87554 | 8.70 |
| RoPE | float32 | 32 | 512 | 3 | 768 | T5 | bwd | 2585142 | 369632 | 6.99 |
rope bfloat16
| op_name | dtype | N | S | NH | D_KV | model | direction | ROCm PyTorch (op time) | MIOpen HIP (op time) | Improvement (x) |
|---|---|---|---|---|---|---|---|---|---|---|
| RoPE | bfloat16 | 4 | 512 | 6 | 64 | T5 | fwd | 177566 | 11662 | 15.23 |
| RoPE | bfloat16 | 8 | 512 | 6 | 64 | T5 | fwd | 203870 | 17955 | 11.35 |
| RoPE | bfloat16 | 16 | 512 | 6 | 64 | T5 | fwd | 249373 | 31413 | 7.94 |
| RoPE | bfloat16 | 32 | 512 | 6 | 64 | T5 | fwd | 394076 | 59199 | 6.66 |
| RoPE | bfloat16 | 4 | 256 | 6 | 64 | T5 | fwd | 170190 | 8017 | 21.23 |
| RoPE | bfloat16 | 4 | 128 | 6 | 64 | T5 | fwd | 171646 | 5955 | 28.82 |
| RoPE | bfloat16 | 4 | 64 | 6 | 64 | T5 | fwd | 166862 | 5564 | 29.99 |
| RoPE | bfloat16 | 4 | 512 | 6 | 128 | T5 | fwd | 205342 | 18151 | 11.31 |
| RoPE | bfloat16 | 4 | 512 | 6 | 256 | T5 | fwd | 259933 | 32550 | 7.99 |
| RoPE | bfloat16 | 4 | 512 | 6 | 512 | T5 | fwd | 423948 | 64408 | 6.58 |
| RoPE | bfloat16 | 4 | 512 | 3 | 384 | T5 | fwd | 224830 | 25315 | 8.88 |
| RoPE | bfloat16 | 8 | 512 | 3 | 384 | T5 | fwd | 328749 | 46524 | 7.07 |
| RoPE | bfloat16 | 16 | 512 | 3 | 384 | T5 | fwd | 544490 | 87021 | 6.26 |
| RoPE | bfloat16 | 32 | 512 | 3 | 384 | T5 | fwd | 977654 | 171269 | 5.71 |
| RoPE | bfloat16 | 32 | 256 | 3 | 384 | T5 | fwd | 543323 | 87110 | 6.24 |
| RoPE | bfloat16 | 32 | 128 | 3 | 384 | T5 | fwd | 321101 | 46026 | 6.98 |
| RoPE | bfloat16 | 32 | 64 | 3 | 384 | T5 | fwd | 226014 | 24639 | 9.17 |
| RoPE | bfloat16 | 32 | 512 | 3 | 192 | T5 | fwd | 533611 | 87092 | 6.13 |
| RoPE | bfloat16 | 32 | 512 | 3 | 768 | T5 | fwd | 2001612 | 361241 | 5.54 |
| op_name | dtype | N | S | NH | D_KV | model | direction | ROCm PyTorch (op time) | MIOpen HIP (op time) | Improvement (x) |
|---|---|---|---|---|---|---|---|---|---|---|
| RoPE | bfloat16 | 4 | 512 | 6 | 64 | T5 | bwd | 264669 | 11822 | 22.39 |
| RoPE | bfloat16 | 8 | 512 | 6 | 64 | T5 | bwd | 293693 | 18630 | 15.76 |
| RoPE | bfloat16 | 16 | 512 | 6 | 64 | T5 | bwd | 371452 | 32444 | 11.45 |
| RoPE | bfloat16 | 32 | 512 | 6 | 64 | T5 | bwd | 541386 | 61386 | 8.82 |
| RoPE | bfloat16 | 4 | 256 | 6 | 64 | T5 | bwd | 265982 | 8195 | 32.46 |
| RoPE | bfloat16 | 4 | 128 | 6 | 64 | T5 | bwd | 266093 | 5990 | 44.42 |
| RoPE | bfloat16 | 4 | 64 | 6 | 64 | T5 | bwd | 239422 | 5582 | 42.89 |
| RoPE | bfloat16 | 4 | 512 | 6 | 128 | T5 | bwd | 295149 | 18666 | 15.81 |
| RoPE | bfloat16 | 4 | 512 | 6 | 256 | T5 | bwd | 372524 | 33653 | 11.07 |
| RoPE | bfloat16 | 4 | 512 | 6 | 512 | T5 | bwd | 570586 | 66914 | 8.53 |
| RoPE | bfloat16 | 4 | 512 | 3 | 384 | T5 | bwd | 333709 | 26293 | 12.69 |
| RoPE | bfloat16 | 8 | 512 | 3 | 384 | T5 | bwd | 454619 | 48141 | 9.44 |
| RoPE | bfloat16 | 16 | 512 | 3 | 384 | T5 | bwd | 767993 | 90292 | 8.51 |
| RoPE | bfloat16 | 32 | 512 | 3 | 384 | T5 | bwd | 1295059 | 179180 | 7.23 |
| RoPE | bfloat16 | 32 | 256 | 3 | 384 | T5 | bwd | 767672 | 90274 | 8.50 |
| RoPE | bfloat16 | 32 | 128 | 3 | 384 | T5 | bwd | 543321 | 47590 | 11.42 |
| RoPE | bfloat16 | 32 | 64 | 3 | 384 | T5 | bwd | 483483 | 25386 | 19.05 |
| RoPE | bfloat16 | 32 | 512 | 3 | 192 | T5 | bwd | 760616 | 90185 | 8.43 |
| RoPE | bfloat16 | 32 | 512 | 3 | 768 | T5 | bwd | 2585494 | 375623 | 6.88 |
- Geometric mean of the improvement over all cases
| Op | Direction | Type | Geomean improvement (x) |
|---|---|---|---|
| RoPE | fwd | float16 | 8.02 |
| RoPE | bwd | float16 | 11.5 |
| RoPE | fwd | float32 | 9.56 |
| RoPE | bwd | float32 | 13.67 |
| RoPE | fwd | bfloat16 | 9.39 |
| RoPE | bwd | bfloat16 | 13.40 |
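As a sanity check on how the summary numbers relate to the per-case tables, the following sketch recomputes the per-row improvement and its geometric mean for the float16 fwd rows. The timings are copied from the table above; the variable and function names are illustrative, and the assumption is that each improvement is simply the ROCm PyTorch op time divided by the MIOpen HIP op time.

```python
from statistics import geometric_mean

# (ROCm PyTorch op time, MIOpen HIP op time) for the float16 fwd rows above.
FP16_FWD = [
    (187006, 12977), (212990, 19715), (260813, 34257), (389676, 63999),
    (164878, 8497), (175062, 6328), (164286, 6044), (210174, 21155),
    (259453, 47359), (421932, 86630), (230782, 32266), (329261, 60123),
    (545467, 119803), (978406, 225011), (530474, 95572), (324925, 50221),
    (235982, 26915), (532891, 95608), (2002588, 500723),
]

def improvement(baseline_time, optimized_time):
    """Speedup ratio: how many times faster the optimized op is."""
    return baseline_time / optimized_time

speedups = [improvement(b, o) for b, o in FP16_FWD]
geomean = geometric_mean(speedups)

print(f"first row improvement: {speedups[0]:.2f}")
print(f"float16 fwd geomean:   {geomean:.2f}")
```

The result should land close to the 8.02 reported for float16 fwd. A geometric mean is the usual choice for averaging ratios, since a few very large speedups on small shapes do not dominate it the way they would an arithmetic mean.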
@iq136boy could you help review this PR?
@seungmanhan same as in the other PR: could you help resolve the conflict? CI is passing, so we should be able to merge soon. Thanks!
@seungmanhan @apwojcik similarly in this case, the feature fails to link in the Windows build:
`lld-link: error: undefined symbol: enum miopenStatus_t __cdecl miopen::RoPEForward(struct miopen::Handle &, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void const *, struct miopen::TensorDescriptor const &, void *)`
@seungmanhan could you help resolve the conflicts? We can merge once that is done and CI passes again :) Thanks!
@junliume the GitHub Actions run passed.