Implement Kthvalue operation
- [x] Added Kthvalue operation with forward kernels.
- [x] Added driver test and gtest.
- [x] Compared with ROCm.
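For reviewers unfamiliar with the operation, here is a minimal reference sketch of what a Kthvalue forward pass computes: for each row it returns the k-th smallest value and its index, with k being 1-based as in `torch.kthvalue`. This is not the kernel added in this PR. The function name `kthvalue_rows` and the tie-breaking rule (lower index wins) are assumptions made for illustration only.

```python
def kthvalue_rows(rows, k):
    """Return (values, indices) of the k-th smallest element of each row.

    k is 1-based. Ties are broken by the lower index, which is an
    assumption of this sketch and not necessarily what the kernel does.
    """
    values, indices = [], []
    for row in rows:
        if not 1 <= k <= len(row):
            raise ValueError("k must be in [1, len(row)]")
        # Sort positions by (value, position) so the result is deterministic.
        order = sorted(range(len(row)), key=lambda i: (row[i], i))
        idx = order[k - 1]
        values.append(row[idx])
        indices.append(idx)
    return values, indices


# Second-smallest element of each row.
vals, idxs = kthvalue_rows([[3.0, 1.0, 2.0], [9.0, 7.0, 8.0]], k=2)
print(vals, idxs)
```

A full sort is used here only for clarity; a selection algorithm would avoid the O(n log n) cost.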
Compare to ROCm
The kernel is only 20% faster than ROCm if the following constraints are applied:
- tensor dim num >= 2.
- selected dim size >= 300.
- selected dim has stride = 1.
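The three constraints above can be written as a simple predicate over a tensor layout. This is only a sketch: `matches_listed_constraints` and `row_major_strides` are hypothetical helpers, not functions from this PR, and the strides in the usage lines assume a densely packed row-major tensor.

```python
def row_major_strides(shape):
    """Element strides of a densely packed row-major tensor (assumed layout)."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


def matches_listed_constraints(shape, strides, dim):
    """Check the three listed constraints for the selected dimension."""
    return (
        len(shape) >= 2          # tensor dim num >= 2
        and shape[dim] >= 300    # selected dim size >= 300
        and strides[dim] == 1    # selected dim has stride = 1
    )


# Two shapes taken from the benchmark tables, with assumed row-major strides.
case_a = matches_listed_constraints([250, 512], row_major_strides([250, 512]), dim=1)
case_b = matches_listed_constraints(
    [512, 25, 5, 2], row_major_strides([512, 25, 5, 2]), dim=0
)
print(case_a, case_b)
```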
| type | Forward |
|---|---|
| float32 | 3.13 |
| float16 | 3.52 |
BFloat16 is currently not supported by ROCm, so this is always a winning case.
Detailed benchmark
Float32
| dtype | size | dim | k | contiguous | direction | rocm_op_avg | kernel_duration | improvement |
|---|---|---|---|---|---|---|---|---|
| float32 | [250 512] | 1 | 12 | TRUE | fwd | 43059 | 41215 | 1.044740992 |
| float32 | [250 1024] | 1 | 999 | TRUE | fwd | 88646 | 48700 | 1.820246407 |
| float32 | [250 2048] | 1 | 89 | TRUE | fwd | 95638 | 78909 | 1.2120037 |
| float32 | [250 4096] | 1 | 15 | TRUE | fwd | 135641 | 117475 | 1.154637157 |
| float32 | [250 8192] | 1 | 8170 | TRUE | fwd | 194749 | 176968 | 1.100475792 |
| float32 | [250 16384] | 1 | 12345 | TRUE | fwd | 417932 | 345543 | 1.209493464 |
| float32 | [250 32768] | 1 | 3 | TRUE | fwd | 596200 | 687708 | 0.866937712 |
| float32 | [500 512] | 1 | 12 | TRUE | fwd | 63252 | 49287 | 1.283340435 |
| float32 | [500 1024] | 1 | 999 | TRUE | fwd | 147242 | 62657 | 2.349968878 |
| float32 | [500 2048] | 1 | 89 | TRUE | fwd | 160347 | 89168 | 1.798257222 |
| float32 | [500 4096] | 1 | 15 | TRUE | fwd | 225871 | 164361 | 1.374237197 |
| float32 | [500 8192] | 1 | 8170 | TRUE | fwd | 328646 | 294282 | 1.116772348 |
| float32 | [500 16384] | 1 | 12345 | TRUE | fwd | 719136 | 575904 | 1.248708118 |
| float32 | [500 32768] | 1 | 3 | TRUE | fwd | 1053798 | 1203350 | 0.875720281 |
| float32 | [1000 512] | 1 | 12 | TRUE | fwd | 106807 | 83762 | 1.275124758 |
| float32 | [1000 1024] | 1 | 999 | TRUE | fwd | 275938 | 105188 | 2.623284025 |
| float32 | [1000 2048] | 1 | 89 | TRUE | fwd | 304612 | 149585 | 2.036380653 |
| float32 | [1000 4096] | 1 | 15 | TRUE | fwd | 428140 | 271237 | 1.578471964 |
| float32 | [1000 8192] | 1 | 8170 | TRUE | fwd | 623817 | 548661 | 1.136980759 |
| float32 | [1000 16384] | 1 | 12345 | TRUE | fwd | 1396556 | 1010450 | 1.38211292 |
| float32 | [1000 32768] | 1 | 3 | TRUE | fwd | 2035702 | 2144470 | 0.949279775 |
| float32 | [2000 512] | 1 | 12 | TRUE | fwd | 193501 | 135947 | 1.423356161 |
| float32 | [2000 1024] | 1 | 999 | TRUE | fwd | 532867 | 169996 | 3.13458552 |
| float32 | [2000 2048] | 1 | 89 | TRUE | fwd | 584086 | 255855 | 2.282878974 |
| float32 | [2000 4096] | 1 | 15 | TRUE | fwd | 828038 | 470124 | 1.76131829 |
| float32 | [2000 8192] | 1 | 8170 | TRUE | fwd | 1212079 | 1032840 | 1.173539948 |
| float32 | [2000 16384] | 1 | 12345 | TRUE | fwd | 2750676 | 1902170 | 1.446072643 |
| float32 | [2000 32768] | 1 | 3 | TRUE | fwd | 4002068 | 3939170 | 1.015967323 |
| float32 | [512 25 5 2] | 0 | 12 | FALSE | fwd | 70594 | 42062 | 1.678331986 |
| float32 | [1024 25 5 2] | 0 | 999 | FALSE | fwd | 149429 | 48639 | 3.072205432 |
| float32 | [2048 25 5 2] | 0 | 89 | FALSE | fwd | 236871 | 79448 | 2.981459571 |
| float32 | [4096 25 5 2] | 0 | 15 | FALSE | fwd | 395756 | 122578 | 3.228605459 |
| float32 | [8192 25 5 2] | 0 | 8170 | FALSE | fwd | 738614 | 176604 | 4.182317501 |
| float32 | [16384 25 5 2] | 0 | 12345 | FALSE | fwd | 1629474 | 336355 | 4.844506548 |
| float32 | [32768 25 5 2] | 0 | 3 | FALSE | fwd | 2670465 | 703696 | 3.79491286 |
| float32 | [512 10 5 10] | 0 | 12 | FALSE | fwd | 112596 | 49261 | 2.285702686 |
| float32 | [1024 10 5 10] | 0 | 999 | FALSE | fwd | 252072 | 62701 | 4.020222963 |
| float32 | [2048 10 5 10] | 0 | 89 | FALSE | fwd | 417629 | 90062 | 4.637127756 |
| float32 | [4096 10 5 10] | 0 | 15 | FALSE | fwd | 722886 | 159680 | 4.527091683 |
| float32 | [8192 10 5 10] | 0 | 8170 | FALSE | fwd | 1516462 | 309848 | 4.894212646 |
| float32 | [16384 10 5 10] | 0 | 12345 | FALSE | fwd | 3709584 | 578470 | 6.412750877 |
| float32 | [32768 10 5 10] | 0 | 3 | FALSE | fwd | 6020662 | 1242630 | 4.845096288 |
| float32 | [512 4 10 25] | 0 | 12 | FALSE | fwd | 203270 | 79946 | 2.542591249 |
| float32 | [1024 4 10 25] | 0 | 999 | FALSE | fwd | 467342 | 102187 | 4.573399748 |
| float32 | [2048 4 10 25] | 0 | 89 | FALSE | fwd | 774664 | 149617 | 5.177646925 |
| float32 | [4096 4 10 25] | 0 | 15 | FALSE | fwd | 1520254 | 271821 | 5.592849706 |
| float32 | [8192 4 10 25] | 0 | 8170 | FALSE | fwd | 3187792 | 548248 | 5.814507303 |
| float32 | [16384 4 10 25] | 0 | 12345 | FALSE | fwd | 7293628 | 1024119 | 7.121855956 |
| float32 | [32768 4 10 25] | 0 | 3 | FALSE | fwd | 12884453 | 2152330 | 5.986281379 |
| float32 | [512 5 4 100] | 0 | 12 | FALSE | fwd | 355643 | 134525 | 2.643694481 |
| float32 | [1024 5 4 100] | 0 | 999 | FALSE | fwd | 897099 | 169670 | 5.287316556 |
| float32 | [2048 5 4 100] | 0 | 89 | FALSE | fwd | 1496765 | 253687 | 5.90004612 |
| float32 | [4096 5 4 100] | 0 | 15 | FALSE | fwd | 2821093 | 473393 | 5.959304426 |
| float32 | [8192 5 4 100] | 0 | 8170 | FALSE | fwd | 6715658 | 1052180 | 6.382613241 |
| float32 | [16384 5 4 100] | 0 | 12345 | FALSE | fwd | 15898974 | 1963890 | 8.095654034 |
| float32 | [32768 5 4 100] | 0 | 3 | FALSE | fwd | 28610523 | 3981670 | 7.185558572 |
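The `improvement` column in the tables is consistent with `rocm_op_avg / kernel_duration`, so values above 1 mean this kernel took less time than ROCm and values below 1 mean it took longer. A small sketch that recomputes the column for two of the Float32 rows above (the helper name `improvement` is chosen here for illustration):

```python
def improvement(rocm_op_avg, kernel_duration):
    """Speedup ratio: ROCm duration divided by this kernel's duration."""
    return rocm_op_avg / kernel_duration


# (size, rocm_op_avg, kernel_duration) copied from the Float32 table.
rows = [
    ("[250 512]", 43059, 41215),
    ("[250 32768]", 596200, 687708),
]

ratios = {size: improvement(rocm, kernel) for size, rocm, kernel in rows}
for size, ratio in ratios.items():
    verdict = "faster" if ratio > 1 else "slower"
    print(f"{size}: {ratio:.6f} ({verdict} than ROCm)")
```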
Float16
| dtype | size | dim | k | contiguous | direction | rocm_op_avg | kernel_duration | improvement |
|---|---|---|---|---|---|---|---|---|
| float16 | [250 512] | 1 | 12 | TRUE | fwd | 34946 | 31275 | 1.117378098 |
| float16 | [250 1024] | 1 | 999 | TRUE | fwd | 70389 | 37801 | 1.862093595 |
| float16 | [250 2048] | 1 | 89 | TRUE | fwd | 78277 | 51474 | 1.520709484 |
| float16 | [250 4096] | 1 | 15 | TRUE | fwd | 103671 | 84190 | 1.231393277 |
| float16 | [250 8192] | 1 | 8170 | TRUE | fwd | 154970 | 119235 | 1.299702269 |
| float16 | [250 16384] | 1 | 12345 | TRUE | fwd | 267890 | 220263 | 1.216227873 |
| float16 | [250 32768] | 1 | 3 | TRUE | fwd | 434925 | 437503 | 0.994107469 |
| float16 | [500 512] | 1 | 12 | TRUE | fwd | 50355 | 38831 | 1.296773197 |
| float16 | [500 1024] | 1 | 999 | TRUE | fwd | 116920 | 46424 | 2.518524901 |
| float16 | [500 2048] | 1 | 89 | TRUE | fwd | 129737 | 67014 | 1.935968604 |
| float16 | [500 4096] | 1 | 15 | TRUE | fwd | 173164 | 108424 | 1.597100273 |
| float16 | [500 8192] | 1 | 8170 | TRUE | fwd | 259633 | 194160 | 1.337211578 |
| float16 | [500 16384] | 1 | 12345 | TRUE | fwd | 442894 | 366487 | 1.208484885 |
| float16 | [500 32768] | 1 | 3 | TRUE | fwd | 760306 | 759536 | 1.001013777 |
| float16 | [1000 512] | 1 | 12 | TRUE | fwd | 86982 | 66693 | 1.304214835 |
| float16 | [1000 1024] | 1 | 999 | TRUE | fwd | 220927 | 79815 | 2.767988473 |
| float16 | [1000 2048] | 1 | 89 | TRUE | fwd | 247520 | 113278 | 2.185066827 |
| float16 | [1000 4096] | 1 | 15 | TRUE | fwd | 329670 | 185945 | 1.772943612 |
| float16 | [1000 8192] | 1 | 8170 | TRUE | fwd | 502305 | 363249 | 1.38281179 |
| float16 | [1000 16384] | 1 | 12345 | TRUE | fwd | 868745 | 691489 | 1.25633958 |
| float16 | [1000 32768] | 1 | 3 | TRUE | fwd | 1470976 | 1399720 | 1.050907324 |
| float16 | [2000 512] | 1 | 12 | TRUE | fwd | 155258 | 110557 | 1.404325371 |
| float16 | [2000 1024] | 1 | 999 | TRUE | fwd | 424860 | 135271 | 3.140806233 |
| float16 | [2000 2048] | 1 | 89 | TRUE | fwd | 480448 | 200435 | 2.397026467 |
| float16 | [2000 4096] | 1 | 15 | TRUE | fwd | 638842 | 339351 | 1.882540496 |
| float16 | [2000 8192] | 1 | 8170 | TRUE | fwd | 981696 | 698615 | 1.405203152 |
| float16 | [2000 16384] | 1 | 12345 | TRUE | fwd | 1725984 | 1332490 | 1.295307282 |
| float16 | [2000 32768] | 1 | 3 | TRUE | fwd | 2871306 | 2663270 | 1.078112996 |
| float16 | [512 25 5 2] | 0 | 12 | FALSE | fwd | 60130 | 31466 | 1.910951503 |
| float16 | [1024 25 5 2] | 0 | 999 | FALSE | fwd | 125796 | 38026 | 3.308157576 |
| float16 | [2048 25 5 2] | 0 | 89 | FALSE | fwd | 197094 | 51839 | 3.802040934 |
| float16 | [4096 25 5 2] | 0 | 15 | FALSE | fwd | 331034 | 84426 | 3.920995902 |
| float16 | [8192 25 5 2] | 0 | 8170 | FALSE | fwd | 613730 | 119307 | 5.144123983 |
| float16 | [16384 25 5 2] | 0 | 12345 | FALSE | fwd | 1200340 | 220390 | 5.446435864 |
| float16 | [32768 25 5 2] | 0 | 3 | FALSE | fwd | 2112320 | 437421 | 4.829031985 |
| float16 | [512 10 5 10] | 0 | 12 | FALSE | fwd | 95939 | 38755 | 2.475525739 |
| float16 | [1024 10 5 10] | 0 | 999 | FALSE | fwd | 211718 | 46826 | 4.521377013 |
| float16 | [2048 10 5 10] | 0 | 89 | FALSE | fwd | 344139 | 67306 | 5.113050842 |
| float16 | [4096 10 5 10] | 0 | 15 | FALSE | fwd | 608882 | 108302 | 5.622075308 |
| float16 | [8192 10 5 10] | 0 | 8170 | FALSE | fwd | 1165923 | 194311 | 6.000293344 |
| float16 | [16384 10 5 10] | 0 | 12345 | FALSE | fwd | 2467419 | 368212 | 6.701082529 |
| float16 | [32768 10 5 10] | 0 | 3 | FALSE | fwd | 4267409 | 765491 | 5.574734386 |
| float16 | [512 4 10 25] | 0 | 12 | FALSE | fwd | 168309 | 66186 | 2.542969812 |
| float16 | [1024 4 10 25] | 0 | 999 | FALSE | fwd | 397452 | 79786 | 4.981475447 |
| float16 | [2048 4 10 25] | 0 | 89 | FALSE | fwd | 653876 | 112213 | 5.827096682 |
| float16 | [4096 4 10 25] | 0 | 15 | FALSE | fwd | 1165075 | 186062 | 6.261756834 |
| float16 | [8192 4 10 25] | 0 | 8170 | FALSE | fwd | 2581518 | 364141 | 7.089336274 |
| float16 | [16384 4 10 25] | 0 | 12345 | FALSE | fwd | 5332721 | 693385 | 7.690851403 |
| float16 | [32768 4 10 25] | 0 | 3 | FALSE | fwd | 9638243 | 1409470 | 6.838203722 |
| float16 | [512 5 4 100] | 0 | 12 | FALSE | fwd | 304617 | 110422 | 2.758662223 |
| float16 | [1024 5 4 100] | 0 | 999 | FALSE | fwd | 751799 | 134641 | 5.583730067 |
| float16 | [2048 5 4 100] | 0 | 89 | FALSE | fwd | 1253494 | 200503 | 6.251746857 |
| float16 | [4096 5 4 100] | 0 | 15 | FALSE | fwd | 2747699 | 339268 | 8.098904111 |
| float16 | [8192 5 4 100] | 0 | 8170 | FALSE | fwd | 5220109 | 699678 | 7.460730507 |
| float16 | [16384 5 4 100] | 0 | 12345 | FALSE | fwd | 12102028 | 1357190 | 8.916974042 |
| float16 | [32768 5 4 100] | 0 | 3 | FALSE | fwd | 23055923 | 2701390 | 8.53483688 |
BFloat16
| size | dim | k | contiguous | direction | kernel_duration |
|---|---|---|---|---|---|
| [250 512] | 1 | 12 | TRUE | fwd | 30880 |
| [250 1024] | 1 | 999 | TRUE | fwd | 36870 |
| [250 2048] | 1 | 89 | TRUE | fwd | 50453 |
| [250 4096] | 1 | 15 | TRUE | fwd | 82257 |
| [250 8192] | 1 | 8170 | TRUE | fwd | 118808 |
| [250 16384] | 1 | 12345 | TRUE | fwd | 225492 |
| [250 32768] | 1 | 3 | TRUE | fwd | 426415 |
| [500 512] | 1 | 12 | TRUE | fwd | 38897 |
| [500 1024] | 1 | 999 | TRUE | fwd | 45475 |
| [500 2048] | 1 | 89 | TRUE | fwd | 66061 |
| [500 4096] | 1 | 15 | TRUE | fwd | 107910 |
| [500 8192] | 1 | 8170 | TRUE | fwd | 194737 |
| [500 16384] | 1 | 12345 | TRUE | fwd | 375429 |
| [500 32768] | 1 | 3 | TRUE | fwd | 742537 |
| [1000 512] | 1 | 12 | TRUE | fwd | 68301 |
| [1000 1024] | 1 | 999 | TRUE | fwd | 76888 |
| [1000 2048] | 1 | 89 | TRUE | fwd | 110524 |
| [1000 4096] | 1 | 15 | TRUE | fwd | 185226 |
| [1000 8192] | 1 | 8170 | TRUE | fwd | 360193 |
| [1000 16384] | 1 | 12345 | TRUE | fwd | 706929 |
| [1000 32768] | 1 | 3 | TRUE | fwd | 1360130 |
| [2000 512] | 1 | 12 | TRUE | fwd | 114506 |
| [2000 1024] | 1 | 999 | TRUE | fwd | 133528 |
| [2000 2048] | 1 | 89 | TRUE | fwd | 200621 |
| [2000 4096] | 1 | 15 | TRUE | fwd | 341171 |
| [2000 8192] | 1 | 8170 | TRUE | fwd | 685400 |
| [2000 16384] | 1 | 12345 | TRUE | fwd | 1365290 |
| [2000 32768] | 1 | 3 | TRUE | fwd | 2585960 |
| [512 25 5 2] | 0 | 12 | FALSE | fwd | 30862 |
| [1024 25 5 2] | 0 | 999 | FALSE | fwd | 36871 |
| [2048 25 5 2] | 0 | 89 | FALSE | fwd | 50507 |
| [4096 25 5 2] | 0 | 15 | FALSE | fwd | 82347 |
| [8192 25 5 2] | 0 | 8170 | FALSE | fwd | 119255 |
| [16384 25 5 2] | 0 | 12345 | FALSE | fwd | 225301 |
| [32768 25 5 2] | 0 | 3 | FALSE | fwd | 426921 |
| [512 10 5 10] | 0 | 12 | FALSE | fwd | 39164 |
| [1024 10 5 10] | 0 | 999 | FALSE | fwd | 45795 |
| [2048 10 5 10] | 0 | 89 | FALSE | fwd | 66400 |
| [4096 10 5 10] | 0 | 15 | FALSE | fwd | 107664 |
| [8192 10 5 10] | 0 | 8170 | FALSE | fwd | 193638 |
| [16384 10 5 10] | 0 | 12345 | FALSE | fwd | 376858 |
| [32768 10 5 10] | 0 | 3 | FALSE | fwd | 747405 |
| [512 4 10 25] | 0 | 12 | FALSE | fwd | 68089 |
| [1024 4 10 25] | 0 | 999 | FALSE | fwd | 77014 |
| [2048 4 10 25] | 0 | 89 | FALSE | fwd | 110597 |
| [4096 4 10 25] | 0 | 15 | FALSE | fwd | 185211 |
| [8192 4 10 25] | 0 | 8170 | FALSE | fwd | 1993090 |
| [16384 4 10 25] | 0 | 12345 | FALSE | fwd | 713662 |
| [32768 4 10 25] | 0 | 3 | FALSE | fwd | 1373720 |
| [512 5 4 100] | 0 | 12 | FALSE | fwd | 114330 |
| [1024 5 4 100] | 0 | 999 | FALSE | fwd | 133850 |
| [2048 5 4 100] | 0 | 89 | FALSE | fwd | 200358 |
| [4096 5 4 100] | 0 | 15 | FALSE | fwd | 340840 |
| [8192 5 4 100] | 0 | 8170 | FALSE | fwd | 687581 |
| [16384 5 4 100] | 0 | 12345 | FALSE | fwd | 1392710 |
| [32768 5 4 100] | 0 | 3 | FALSE | fwd | 2635730 |
The conflicts should be resolved before this PR is merged.
Conflicts are resolved. Please re-review my PR.
May I ask which version of PyTorch and which graphics card were used for this test?
Same as #3152: I used an MI250 GPU and PyTorch from rocm/pytorch (image tag: rocm6.1.2_ubuntu22.04_py3.10_pytorch_release-2.1.2).