[FEA] Support for half-float mixed precision in brute-force
- distance supports half-float mixed precision
- prefiltered_brute_force supports half
- migrate the ann brute force test cases and support half
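
For readers new to the idea, here is a minimal sketch of what "half-float mixed precision" means for a brute-force search: inputs are stored in half precision, while the dot product is accumulated in a wider type. This is an illustration only, not the cuVS implementation; `to_half`, `inner_product_mixed`, and `brute_force_topk` are hypothetical names, and the half rounding is emulated on the CPU with the standard `struct` module.

```python
import heapq
import struct


def to_half(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary16 (half) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def inner_product_mixed(query, vector):
    """Inner product with half-precision inputs and a wider accumulator.

    Python floats are float64; on a GPU the accumulator would typically
    be float32. The point is only that it is wider than the inputs.
    """
    acc = 0.0
    for q, v in zip(query, vector):
        acc += to_half(q) * to_half(v)
    return acc


def brute_force_topk(queries, dataset, k):
    """Exhaustively score every dataset vector and keep the k largest scores."""
    results = []
    for q in queries:
        scored = ((inner_product_mixed(q, v), i) for i, v in enumerate(dataset))
        results.append([i for _, i in heapq.nlargest(k, scored)])
    return results


if __name__ == "__main__":
    dataset = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    queries = [[1.0, 0.2]]
    print(brute_force_topk(queries, dataset, k=2))
```

Halving the element size halves the bytes read per vector, which is where a bandwidth-bound search can be expected to gain.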
Hey @cjnolet @benfred, here are the performance results (code link). Each float row is followed by the corresponding half row. Performance improves significantly because the I/O workload is reduced:
A100 with 80GB PCIE x 1 @computlab
Type Queries Vectors Dim K Metric Layout Build Time (ms) Search Time (ms) Total Time (ms) Throughput (q/s)
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
float 10 1000000 32 128 InnerProduct row 0.004 1.684 1.688 5924.023
half 10 1000000 32 128 InnerProduct row 0.003 1.563 1.566 6385.688
float 10 1000000 32 128 InnerProduct col 0.392 1.673 2.065 4843.640
half 10 1000000 32 128 InnerProduct col 1.098 2.362 3.459 2890.637
float 10 1000000 32 128 L2SqrtExpanded row 1.021 2.555 3.576 2796.354
half 10 1000000 32 128 L2SqrtExpanded row 0.771 1.613 2.384 4194.696
float 10 1000000 32 128 L2SqrtExpanded col 0.696 1.953 2.649 3775.518
half 10 1000000 32 128 L2SqrtExpanded col 2.082 3.600 5.682 1759.904
float 10 1000000 32 1024 InnerProduct row 0.003 3.867 3.870 2584.163
half 10 1000000 32 1024 InnerProduct row 0.003 3.772 3.774 2649.537
float 10 1000000 32 1024 InnerProduct col 0.380 3.903 4.283 2334.875
half 10 1000000 32 1024 InnerProduct col 1.096 4.575 5.671 1763.367
float 10 1000000 32 1024 L2SqrtExpanded row 0.222 5.674 5.895 1696.235
half 10 1000000 32 1024 L2SqrtExpanded row 0.192 5.585 5.777 1731.000
float 10 1000000 32 1024 L2SqrtExpanded col 0.579 5.914 6.493 1540.118
half 10 1000000 32 1024 L2SqrtExpanded col 0.469 5.789 6.258 1597.850
float 10 1000000 256 128 InnerProduct row 0.003 2.064 2.067 4838.767
half 10 1000000 256 128 InnerProduct row 0.003 2.208 2.211 4522.343
float 10 1000000 256 128 InnerProduct col 2.480 2.561 5.041 1983.605
half 10 1000000 256 128 InnerProduct col 1.708 1.573 3.280 3048.330
float 10 1000000 256 128 L2SqrtExpanded row 1.578 2.404 3.982 2511.399
half 10 1000000 256 128 L2SqrtExpanded row 1.253 2.022 3.275 3053.356
float 10 1000000 256 128 L2SqrtExpanded col 3.470 4.402 7.872 1270.285
half 10 1000000 256 128 L2SqrtExpanded col 2.715 3.080 5.796 1725.468
float 10 1000000 256 1024 InnerProduct row 0.003 6.421 6.424 1556.748
half 10 1000000 256 1024 InnerProduct row 0.004 4.355 4.358 2294.375
float 10 1000000 256 1024 InnerProduct col 2.483 6.152 8.635 1158.097
half 10 1000000 256 1024 InnerProduct col 1.706 5.625 7.331 1364.046
float 10 1000000 256 1024 L2SqrtExpanded row 1.582 9.255 10.838 922.716
half 10 1000000 256 1024 L2SqrtExpanded row 1.253 6.718 7.971 1254.565
float 10 1000000 256 1024 L2SqrtExpanded col 3.467 10.577 14.044 712.051
half 10 1000000 256 1024 L2SqrtExpanded col 2.731 10.204 12.935 773.093
float 10 1000000 1024 128 InnerProduct row 0.003 5.929 5.932 1685.793
half 10 1000000 1024 128 InnerProduct row 0.003 3.225 3.228 3097.965
float 10 1000000 1024 128 InnerProduct col 6.648 5.696 12.344 810.101
half 10 1000000 1024 128 InnerProduct col 4.121 3.334 7.455 1341.401
float 10 1000000 1024 128 L2SqrtExpanded row 3.539 5.533 9.072 1102.320
half 10 1000000 1024 128 L2SqrtExpanded row 2.331 3.022 5.354 1867.913
float 10 1000000 1024 128 L2SqrtExpanded col 9.609 9.480 19.088 523.883
half 10 1000000 1024 128 L2SqrtExpanded col 8.514 5.665 14.179 705.248
float 10 1000000 1024 1024 InnerProduct row 0.003 7.793 7.796 1282.688
half 10 1000000 1024 1024 InnerProduct row 0.003 6.463 6.466 1546.549
float 10 1000000 1024 1024 InnerProduct col 6.635 7.572 14.207 703.864
half 10 1000000 1024 1024 InnerProduct col 4.138 6.572 10.710 933.723
float 10 1000000 1024 1024 L2SqrtExpanded row 3.546 10.093 13.639 733.206
half 10 1000000 1024 1024 L2SqrtExpanded row 2.325 9.668 11.993 833.810
float 10 1000000 1024 1024 L2SqrtExpanded col 9.578 14.069 23.647 422.885
half 10 1000000 1024 1024 L2SqrtExpanded col 8.411 12.422 20.833 479.998
float 100 1000000 32 128 InnerProduct row 0.003 9.865 9.869 10133.094
half 100 1000000 32 128 InnerProduct row 0.003 9.882 9.885 10116.119
float 100 1000000 32 128 InnerProduct col 1.262 9.685 10.947 9134.877
half 100 1000000 32 128 InnerProduct col 1.095 9.843 10.938 9142.326
float 100 1000000 32 128 L2SqrtExpanded row 1.013 9.841 10.854 9213.248
half 100 1000000 32 128 L2SqrtExpanded row 0.985 9.875 10.860 9207.921
float 100 1000000 32 128 L2SqrtExpanded col 2.233 10.793 13.026 7677.038
half 100 1000000 32 128 L2SqrtExpanded col 2.075 10.968 13.044 7666.443
float 100 1000000 32 1024 InnerProduct row 0.003 31.016 31.019 3223.793
half 100 1000000 32 1024 InnerProduct row 0.003 30.987 30.990 3226.864
float 100 1000000 32 1024 InnerProduct col 1.257 30.811 32.068 3118.342
half 100 1000000 32 1024 InnerProduct col 1.094 30.995 32.089 3116.356
float 100 1000000 32 1024 L2SqrtExpanded row 1.013 53.023 54.035 1850.639
half 100 1000000 32 1024 L2SqrtExpanded row 0.985 53.037 54.022 1851.109
float 100 1000000 32 1024 L2SqrtExpanded col 2.251 48.709 50.961 1962.302
half 100 1000000 32 1024 L2SqrtExpanded col 2.086 48.884 50.971 1961.911
float 100 1000000 256 128 InnerProduct row 0.003 6.092 6.094 16408.592
half 100 1000000 256 128 InnerProduct row 0.003 3.126 3.129 31960.277
float 100 1000000 256 128 InnerProduct col 2.507 6.781 9.289 10765.918
half 100 1000000 256 128 InnerProduct col 1.710 2.488 4.198 23822.428
float 100 1000000 256 128 L2SqrtExpanded row 1.578 5.967 7.544 13255.126
half 100 1000000 256 128 L2SqrtExpanded row 1.250 3.342 4.592 21778.493
float 100 1000000 256 128 L2SqrtExpanded col 3.474 7.944 11.419 8757.537
half 100 1000000 256 128 L2SqrtExpanded col 2.743 4.379 7.122 14041.061
float 100 1000000 256 1024 InnerProduct row 0.003 18.316 18.319 5458.889
half 100 1000000 256 1024 InnerProduct row 0.003 19.368 19.371 5162.445
float 100 1000000 256 1024 InnerProduct col 2.492 17.946 20.437 4892.974
half 100 1000000 256 1024 InnerProduct col 1.710 17.683 19.393 5156.620
float 100 1000000 256 1024 L2SqrtExpanded row 1.585 46.057 47.642 2098.978
half 100 1000000 256 1024 L2SqrtExpanded row 1.259 45.341 46.600 2145.905
float 100 1000000 256 1024 L2SqrtExpanded col 3.480 47.391 50.872 1965.728
half 100 1000000 256 1024 L2SqrtExpanded col 2.736 48.139 50.875 1965.621
float 100 1000000 1024 128 InnerProduct row 0.003 16.497 16.500 6060.686
half 100 1000000 1024 128 InnerProduct row 0.003 4.195 4.198 23822.428
float 100 1000000 1024 128 InnerProduct col 6.655 17.215 23.870 4189.381
half 100 1000000 1024 128 InnerProduct col 4.165 4.262 8.427 11866.377
float 100 1000000 1024 128 L2SqrtExpanded row 3.542 16.509 20.051 4987.345
half 100 1000000 1024 128 L2SqrtExpanded row 2.321 4.400 6.721 14878.718
float 100 1000000 1024 128 L2SqrtExpanded col 9.558 20.924 30.482 3280.646
half 100 1000000 1024 128 L2SqrtExpanded col 8.448 7.159 15.608 6407.093
float 100 1000000 1024 1024 InnerProduct row 0.003 24.618 24.622 4061.449
half 100 1000000 1024 1024 InnerProduct row 0.003 18.342 18.345 5451.117
float 100 1000000 1024 1024 InnerProduct col 6.650 28.540 35.191 2841.661
half 100 1000000 1024 1024 InnerProduct col 4.134 19.459 23.593 4238.568
float 100 1000000 1024 1024 L2SqrtExpanded row 3.538 52.542 56.080 1783.158
half 100 1000000 1024 1024 L2SqrtExpanded row 2.325 48.467 50.791 1968.835
float 100 1000000 1024 1024 L2SqrtExpanded col 9.607 59.168 68.775 1454.023
half 100 1000000 1024 1024 L2SqrtExpanded col 151.098 49.229 200.327 499.184
float 1024 1000000 32 128 InnerProduct row 0.003 20.491 20.494 49966.165
half 1024 1000000 32 128 InnerProduct row 0.003 20.509 20.513 49919.884
float 1024 1000000 32 128 InnerProduct col 1.258 20.298 21.556 47503.756
half 1024 1000000 32 128 InnerProduct col 1.094 20.477 21.571 47471.076
float 1024 1000000 32 128 L2SqrtExpanded row 1.010 23.636 24.646 41548.334
half 1024 1000000 32 128 L2SqrtExpanded row 0.982 20.531 21.513 47598.883
float 1024 1000000 32 128 L2SqrtExpanded col 2.241 23.541 25.782 39718.195
half 1024 1000000 32 128 L2SqrtExpanded col 2.081 21.588 23.669 43263.647
float 1024 1000000 32 1024 InnerProduct row 0.003 79.198 79.201 12929.171
half 1024 1000000 32 1024 InnerProduct row 0.004 78.151 78.155 13102.164
float 1024 1000000 32 1024 InnerProduct col 1.263 77.934 79.198 12929.661
half 1024 1000000 32 1024 InnerProduct col 1.094 80.218 81.312 12593.487
float 1024 1000000 32 1024 L2SqrtExpanded row 1.013 243.951 244.965 4180.195
half 1024 1000000 32 1024 L2SqrtExpanded row 0.974 242.955 243.929 4197.946
float 1024 1000000 32 1024 L2SqrtExpanded col 2.240 247.013 249.253 4108.275
half 1024 1000000 32 1024 L2SqrtExpanded col 2.081 226.590 228.670 4478.060
float 1024 1000000 256 128 InnerProduct row 0.003 38.875 38.879 26338.442
half 1024 1000000 256 128 InnerProduct row 0.003 15.260 15.263 67090.257
float 1024 1000000 256 128 InnerProduct col 2.481 39.226 41.707 24552.388
half 1024 1000000 256 128 InnerProduct col 1.717 19.817 21.534 47553.262
float 1024 1000000 256 128 L2SqrtExpanded row 1.479 45.291 46.771 21893.991
half 1024 1000000 256 128 L2SqrtExpanded row 1.248 20.523 21.771 47035.155
float 1024 1000000 256 128 L2SqrtExpanded col 3.468 45.591 49.059 20872.764
half 1024 1000000 256 128 L2SqrtExpanded col 2.763 22.984 25.747 39772.126
float 1024 1000000 256 1024 InnerProduct row 0.003 73.301 73.304 13969.220
half 1024 1000000 256 1024 InnerProduct row 0.003 53.471 53.474 19149.481
float 1024 1000000 256 1024 InnerProduct col 2.463 72.974 75.437 13574.186
half 1024 1000000 256 1024 InnerProduct col 1.701 51.873 53.574 19113.680
float 1024 1000000 256 1024 L2SqrtExpanded row 1.575 317.503 319.078 3209.245
half 1024 1000000 256 1024 L2SqrtExpanded row 1.253 297.144 298.398 3431.664
float 1024 1000000 256 1024 L2SqrtExpanded col 3.449 320.774 324.223 3158.318
half 1024 1000000 256 1024 L2SqrtExpanded col 2.728 297.009 299.737 3416.330
float 1024 1000000 1024 128 InnerProduct row 0.003 125.358 125.361 8168.406
half 1024 1000000 1024 128 InnerProduct row 0.003 20.999 21.003 48755.522
float 1024 1000000 1024 128 InnerProduct col 6.607 126.106 132.713 7715.899
half 1024 1000000 1024 128 InnerProduct col 4.128 22.579 26.707 38342.069
float 1024 1000000 1024 128 L2SqrtExpanded row 3.537 131.411 134.947 7588.147
half 1024 1000000 1024 128 L2SqrtExpanded row 2.315 26.711 29.026 35278.668
float 1024 1000000 1024 128 L2SqrtExpanded col 9.612 133.920 143.532 7134.308
half 1024 1000000 1024 128 L2SqrtExpanded col 8.525 28.427 36.952 27711.348
float 1024 1000000 1024 1024 InnerProduct row 0.004 171.560 171.563 5968.639
half 1024 1000000 1024 1024 InnerProduct row 0.004 76.682 76.685 13353.282
float 1024 1000000 1024 1024 InnerProduct col 6.641 157.351 163.992 6244.192
half 1024 1000000 1024 1024 InnerProduct col 4.160 82.248 86.408 11850.815
float 1024 1000000 1024 1024 L2SqrtExpanded row 3.541 415.505 419.046 2443.645
half 1024 1000000 1024 1024 L2SqrtExpanded row 2.324 316.662 318.986 3210.173
float 1024 1000000 1024 1024 L2SqrtExpanded col 9.578 417.324 426.902 2398.680
half 1024 1000000 1024 1024 L2SqrtExpanded col 8.509 317.856 326.365 3137.591
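
As a reading aid for the table, the columns appear to be related by Total = Build + Search and Throughput = Queries / Total. The sketch below checks that assumption on one float/half pair copied from the table (1024 queries, dim 1024, K 128, InnerProduct, row) and derives the speedup. The helper names are hypothetical, and small differences from the reported throughput come from the timings being rounded to three decimals.

```python
def throughput_qps(num_queries: int, total_ms: float) -> float:
    """Queries per second given the total (build + search) time in milliseconds."""
    return num_queries / (total_ms / 1000.0)


def speedup(float_total_ms: float, half_total_ms: float) -> float:
    """How many times faster the half run is than the float run."""
    return float_total_ms / half_total_ms


# Values copied from the table above.
float_row = {"queries": 1024, "build_ms": 0.003, "search_ms": 125.358, "total_ms": 125.361}
half_row = {"queries": 1024, "build_ms": 0.003, "search_ms": 20.999, "total_ms": 21.003}

if __name__ == "__main__":
    for name, row in (("float", float_row), ("half", half_row)):
        qps = throughput_qps(row["queries"], row["total_ms"])
        print(f"{name}: {qps:.1f} q/s")
    ratio = speedup(float_row["total_ms"], half_row["total_ms"])
    print(f"half is about {ratio:.1f}x faster for this configuration")
```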
Linking https://github.com/rapidsai/cuvs/issues/110
Thanks for providing the benchmarks above, @rhdong. For smaller numbers of queries (e.g. 10), it looks like half precision is significantly slower than single precision. It's very common for these algos to be used in online scenarios where only 1 query is issued at a time. Any idea why we are seeing this perf degradation and how to fix it?
It looks like it only happens on Col_Major; let me take a look.
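
One way to narrow this down is to scan the float/half pairs and flag the configurations where half has the larger total time. The sketch below does that for the Queries=10, K=128 rows, with the totals copied from the table above; `find_regressions` and its `tolerance` parameter are hypothetical, and the 10% noise margin is an arbitrary assumption rather than a measured run-to-run variance.

```python
# (dim, metric, layout, float_total_ms, half_total_ms) for Queries=10, K=128,
# copied from the "Total Time (ms)" column of the benchmark table.
ROWS = [
    (32, "InnerProduct", "row", 1.688, 1.566),
    (32, "InnerProduct", "col", 2.065, 3.459),
    (32, "L2SqrtExpanded", "row", 3.576, 2.384),
    (32, "L2SqrtExpanded", "col", 2.649, 5.682),
    (256, "InnerProduct", "row", 2.067, 2.211),
    (256, "InnerProduct", "col", 5.041, 3.280),
    (256, "L2SqrtExpanded", "row", 3.982, 3.275),
    (256, "L2SqrtExpanded", "col", 7.872, 5.796),
    (1024, "InnerProduct", "row", 5.932, 3.228),
    (1024, "InnerProduct", "col", 12.344, 7.455),
    (1024, "L2SqrtExpanded", "row", 9.072, 5.354),
    (1024, "L2SqrtExpanded", "col", 19.088, 14.179),
]


def find_regressions(rows, tolerance=0.0):
    """Return (dim, metric, layout, slowdown) where half is slower than float.

    `tolerance` is a fractional noise margin: 0.1 ignores slowdowns below 10%.
    """
    regressions = []
    for dim, metric, layout, float_ms, half_ms in rows:
        slowdown = half_ms / float_ms
        if slowdown > 1.0 + tolerance:
            regressions.append((dim, metric, layout, slowdown))
    return regressions


if __name__ == "__main__":
    for dim, metric, layout, slowdown in find_regressions(ROWS, tolerance=0.1):
        print(f"dim={dim} {metric} {layout}: half is {slowdown:.2f}x slower")
```

With the 10% margin, only the two dim=32 col-major rows in this subset are flagged. With no margin, the dim=256 InnerProduct row-major case also shows up, but its gap is small enough that it may be noise.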
/merge