Add 1xtfloat capability to pairwise_matrix distance computations
This PR adds the possibility to use 1xtfloat in the pairwise matrix computations of raft::distance.
When 1xtfloat is enabled, throughput is roughly 2.6x to 3.4x that of 3xtfloat in the benchmarks below.
Benchmarks below were taken on H100 (unlocked clocks, SXM). The distance computed was the squared L2 expanded distance, so one core_op corresponds to one fused multiply-add.
| Time | Iterations | 1xtfloat | BW | core_ops/s | k | m | n |
|---|---|---|---|---|---|---|---|
| 0.050 ms | 13697 | ✅ | 249.475G/s | 21.2885T/s | 1024 | 1024 | 1024 |
| 0.087 ms | 8071 | ✅ | 242.288G/s | 24.8103T/s | 2.048k | 1024 | 1024 |
| 0.160 ms | 4392 | ✅ | 235.911G/s | 26.8414T/s | 4.096k | 1024 | 1024 |
| 0.297 ms | 2361 | ✅ | 240.406G/s | 28.9619T/s | 8.192k | 1024 | 1024 |
| 0.265 ms | 2633 | ✅ | 523.272G/s | 64.9491T/s | 1024 | 1024 | 16.384k |
| 0.490 ms | 1430 | ✅ | 428.423G/s | 70.1928T/s | 2.048k | 1024 | 16.384k |
| 0.945 ms | 741 | ✅ | 372.783G/s | 72.7105T/s | 4.096k | 1024 | 16.384k |
| 1.85 ms | 378 | ✅ | 344.673G/s | 74.3042T/s | 8.192k | 1024 | 16.384k |
| 0.132 ms | 5304 | | 95.2914G/s | 8.13154T/s | 1024 | 1024 | 1024 |
| 0.249 ms | 2808 | | 84.1437G/s | 8.61631T/s | 2.048k | 1024 | 1024 |
| 0.484 ms | 1445 | | 77.9194G/s | 8.8655T/s | 4.096k | 1024 | 1024 |
| 0.955 ms | 733 | | 74.636G/s | 8.99144T/s | 8.192k | 1024 | 1024 |
| 0.816 ms | 857 | | 169.614G/s | 21.0526T/s | 1024 | 1024 | 16.384k |
| 1.58 ms | 442 | | 132.367G/s | 21.687T/s | 2.048k | 1024 | 16.384k |
| 3.13 ms | 224 | | 112.65G/s | 21.9721T/s | 4.096k | 1024 | 16.384k |
| 6.21 ms | 113 | | 102.705G/s | 22.141T/s | 8.192k | 1024 | 16.384k |
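
For readers who want to sanity-check the core_ops/s column, here is a minimal sketch of the arithmetic. It assumes that the benchmark counts one core_op per (row, column, feature) triple, i.e. m·n·k fused multiply-adds per call. That assumption matches the table to within the rounding of the printed times. The function names are illustrative and are not part of raft or of the benchmark code.

```cpp
#include <cassert>
#include <cstdint>

// Assumed operation count: one fused multiply-add per (row, column, feature)
// triple of an m x n distance matrix over k-dimensional vectors.
constexpr std::uint64_t core_ops(std::uint64_t m, std::uint64_t n, std::uint64_t k)
{
  return m * n * k;
}

// Throughput in core_ops per second, given the measured time in milliseconds.
inline double core_ops_per_second(std::uint64_t m,
                                  std::uint64_t n,
                                  std::uint64_t k,
                                  double time_ms)
{
  assert(time_ms > 0.0);  // guard against division by zero
  return static_cast<double>(core_ops(m, n, k)) / (time_ms * 1e-3);
}
```

For m = n = k = 1024 and the printed 0.050 ms, this gives about 21.5T core_ops/s, close to the reported 21.2885T/s. The gap is consistent with the time being rounded to three decimals.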
@cjnolet : I have implemented the 1xtfloat distance, but it is not yet exposed in the public API. The distance API is getting a bit unwieldy. I see the following options to expose 1xtfloat in the API:
1. Add another overload of `raft::distance::distance` that takes an `{L2, cosine, etc..}_options` struct.
2. Option 1 and also remove many overloads of `raft::distance::distance`.
3. Interrogate the `NVIDIA_TF32_OVERRIDE` environment variable and/or add a flag to `raft::resources` to enable 1xtfloat (as discussed in https://github.com/rapidsai/raft/issues/1393)
Do you have any thoughts on this? Which option do you prefer?
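
To make the environment-variable option concrete, here is a rough sketch of what such a check could look like. It is not code from this PR. The helper names and the policy (an explicit request is honored unless the variable is set to `0`) are assumptions for illustration. As far as I know, NVIDIA libraries such as cuBLAS treat `NVIDIA_TF32_OVERRIDE=0` as "do not use TF32", which is the behavior mirrored here.

```cpp
#include <cstdlib>
#include <string>

// Hypothetical helper (not part of raft): true when the environment
// explicitly disables TF32 via NVIDIA_TF32_OVERRIDE=0.
inline bool tf32_disabled_by_env()
{
  const char* value = std::getenv("NVIDIA_TF32_OVERRIDE");
  return value != nullptr && std::string(value) == "0";
}

// Hypothetical policy: use 1xtfloat only when the caller asked for it
// and the environment does not forbid TF32.
inline bool use_1xtfloat(bool requested)
{
  return requested && !tf32_disabled_by_env();
}
```

A flag on `raft::resources` would take the place of the `requested` argument in this sketch.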
@benfred : Related to https://github.com/rapidsai/raft/issues/852, I have drafted a type to describe the L2 distance options. It describes:
- Whether to compute the squared or true L2 distance
- How to compute the L2 distance (expanded or unexpanded, with 3xtfloat, with 1xtfloat, or depending on environment variables)
The docstrings in the code explain how each option should work. Please let me know:
- If this is how you envisioned the distance types. If so, I can extend this to other distances as well.
- If you have any comments on the current design.
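
For anyone skimming the thread without opening the diff, the shape of such an options type could be sketched roughly as below. This is an illustration only, not the type drafted in the PR. Every name, the single-enum layout, and the default values are assumptions.

```cpp
// All names here are hypothetical and do not reflect the actual draft.

// Whether the result is the squared L2 distance or its square root.
enum class l2_result { squared, sqrt };

// How the distance is computed, mirroring the alternatives listed above.
enum class l2_method {
  unexpanded,         // accumulate (x - y)^2 directly
  expanded_3xtfloat,  // expanded form, 3xtfloat products
  expanded_1xtfloat,  // expanded form, 1xtfloat products
  expanded_from_env   // expanded form, precision chosen from the environment
};

struct l2_options {
  l2_result result = l2_result::squared;            // assumed default
  l2_method method = l2_method::expanded_3xtfloat;  // assumed default
};

// True when a final square root has to be applied to the squared distance.
constexpr bool needs_sqrt(const l2_options& options)
{
  return options.result == l2_result::sqrt;
}
```

An overload of `raft::distance::distance` taking such a struct would then replace the separate template and boolean parameters with a single argument.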
@ahendriksen @tfeher are we still planning to make progress on this feature? I'm doing a little housekeeping on the PRs and just want to make sure the PRs we are keeping open are still valid.
The main question here is what the best mechanism is for the user to opt in or out of 1xTF32 computation. @vinaydes is working on the same question in relation to #1892. Let's wait until that is resolved, and then return to this PR.
Since @ahendriksen is busy with other tasks, we need someone else to continue this. Assigning this to myself for now, we will revisit availability once #1892 is solved.