Add 1xtfloat capability to pairwise_matrix distance computations

Open ahendriksen opened this issue 2 years ago • 4 comments

This PR adds the possibility to use 1xtfloat in the pairwise matrix computations of raft::distance.

When 1xtfloat is enabled, the throughput more than triples compared to using 3xtfloat.
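For context, here is a minimal CPU sketch of the trade-off, assuming that 1xtfloat and 3xtfloat refer to the 1xTF32 and 3xTF32 schemes known from CUTLASS. In that reading, 1xTF32 reduces both operands to TF32 and issues one multiply, while 3xTF32 splits each operand into a big and a small TF32 part and issues three multiplies. All helper names are hypothetical, and the TF32 conversion is emulated by truncating the float mantissa to 10 bits, which is cruder than the rounding done in hardware.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Emulated TF32 conversion (assumption: plain truncation).
// Keeps sign, exponent and the top 10 mantissa bits; clears the low 13 bits.
inline float to_tf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits &= 0xFFFFE000u;
  std::memcpy(&x, &bits, sizeof bits);
  return x;
}

// 1xTF32: a single multiply on the truncated operands.
inline float mul_1xtf32(float a, float b) { return to_tf32(a) * to_tf32(b); }

// 3xTF32: three multiplies; the small*small term is dropped.
// Small terms are added first, then the dominant big*big term.
inline float mul_3xtf32(float a, float b) {
  const float a_big = to_tf32(a), a_small = to_tf32(a - a_big);
  const float b_big = to_tf32(b), b_small = to_tf32(b - b_big);
  return a_small * b_big + a_big * b_small + a_big * b_big;
}

// Absolute error of `approx` against the product computed in double precision.
inline double product_error(float a, float b, float approx) {
  return std::fabs(static_cast<double>(approx) -
                   static_cast<double>(a) * static_cast<double>(b));
}
```

For a = 1.2345678f and b = 7.654321f, the error of `mul_3xtf32` is more than 100 times smaller than that of `mul_1xtf32`, at the cost of three multiplies instead of one.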

The benchmarks below were taken on an H100 (SXM, unlocked clocks). The distance computed was the squared L2 distance in its expanded form, so one core_op corresponds to one fused multiply-add.

Time      Iterations  1xtfloat  BW          core_ops/s  k       m     n
0.050 ms  13697                 249.475G/s  21.2885T/s  1024    1024  1024
0.087 ms  8071                  242.288G/s  24.8103T/s  2.048k  1024  1024
0.160 ms  4392                  235.911G/s  26.8414T/s  4.096k  1024  1024
0.297 ms  2361                  240.406G/s  28.9619T/s  8.192k  1024  1024
0.265 ms  2633                  523.272G/s  64.9491T/s  1024    1024  16.384k
0.490 ms  1430                  428.423G/s  70.1928T/s  2.048k  1024  16.384k
0.945 ms  741                   372.783G/s  72.7105T/s  4.096k  1024  16.384k
1.85 ms   378                   344.673G/s  74.3042T/s  8.192k  1024  16.384k
0.132 ms  5304                  95.2914G/s  8.13154T/s  1024    1024  1024
0.249 ms  2808                  84.1437G/s  8.61631T/s  2.048k  1024  1024
0.484 ms  1445                  77.9194G/s  8.8655T/s   4.096k  1024  1024
0.955 ms  733                   74.636G/s   8.99144T/s  8.192k  1024  1024
0.816 ms  857                   169.614G/s  21.0526T/s  1024    1024  16.384k
1.58 ms   442                   132.367G/s  21.687T/s   2.048k  1024  16.384k
3.13 ms   224                   112.65G/s   21.9721T/s  4.096k  1024  16.384k
6.21 ms   113                   102.705G/s  22.141T/s   8.192k  1024  16.384k
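The derived columns can be roughly reproduced from m, n, k and the measured time. The sketch below assumes that core_ops counts one fused multiply-add per (row, column, feature) triple, and that BW counts each float input matrix (m x k and n x k) and the float output matrix (m x n) exactly once. The BW formula is an inference from the numbers, not something stated in the PR, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Fused multiply-adds per second: m * n * k core_ops in total.
inline double core_ops_per_second(std::int64_t m, std::int64_t n, std::int64_t k,
                                  double seconds) {
  assert(seconds > 0.0);
  return static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(k) / seconds;
}

// Bytes per second, assuming both inputs and the output are touched once as float.
inline double bytes_per_second(std::int64_t m, std::int64_t n, std::int64_t k,
                               double seconds) {
  assert(seconds > 0.0);
  const double elements = static_cast<double>(m) * static_cast<double>(k) +
                          static_cast<double>(n) * static_cast<double>(k) +
                          static_cast<double>(m) * static_cast<double>(n);
  return elements * sizeof(float) / seconds;
}
```

For the first row (m = n = k = 1024, 0.050 ms) this gives about 21.5T core_ops/s and about 252 GB/s, close to the reported 21.2885T/s and 249.475G/s; the small gap is consistent with the printed time being rounded.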

ahendriksen avatar May 08 '23 08:05 ahendriksen

@cjnolet : I have implemented the 1xtfloat distance, but it is not yet exposed in the public API. The distance API is getting a bit unwieldy. I see the following options to expose the 1xtfloat in the API:

  1. Add another overload of raft::distance::distance that takes an {L2, cosine, etc..}_options struct.
  2. Option 1 and also remove many overloads of raft::distance::distance.
  3. Interrogate the NVIDIA_TF32_OVERRIDE environment variable and/or add a flag to raft::resources to enable 1xtfloat (as discussed in https://github.com/rapidsai/raft/issues/1393)

Do you have any thoughts on this? Which option do you prefer?
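To make option 3 concrete, here is a minimal sketch of how the environment variable could be interrogated. The helper names are hypothetical and not part of RAFT. The sketch assumes that only the value "0" forbids TF32 and that an unset variable or any other value leaves the decision to the caller.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical helper: interpret a raw NVIDIA_TF32_OVERRIDE value.
// Assumption: only "0" disables TF32; nullptr (unset) or anything else allows it.
inline bool tf32_allowed(const char* override_value) {
  return override_value == nullptr || std::strcmp(override_value, "0") != 0;
}

// Reads the process environment and applies the rule above.
inline bool tf32_allowed_by_environment() {
  return tf32_allowed(std::getenv("NVIDIA_TF32_OVERRIDE"));
}
```

Splitting the parsing from the `std::getenv` call keeps the rule testable without modifying the process environment.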

ahendriksen avatar May 08 '23 08:05 ahendriksen

@benfred : Related to https://github.com/rapidsai/raft/issues/852, I have drafted a type to describe the L2 distance options. It describes:

  • Whether to compute the squared or true L2 distance
  • How to compute the L2 distance (expanded or unexpanded; 3xtfloat, 1xtfloat, or depending on environment variables)

The docstrings in the code explain how each option should work. Please let me know:

  • If this is how you envisioned the distance types => if so, I can expand to other distances as well.
  • If you have any comments on the current design.
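As an illustration only, an options type along the lines described in this comment might look like the sketch below. The names `l2_precision`, `l2_options` and `l2_reference` are hypothetical and are not the types drafted in this PR, and the defaults are arbitrary. The CPU reference only honors the squared/true and expanded/unexpanded choices; it ignores the precision field, which is there to show the shape of the options.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical precision selector for the expanded form.
enum class l2_precision { tfloat_3x, tfloat_1x, from_environment };

// Hypothetical options struct; defaults are arbitrary placeholders.
struct l2_options {
  bool take_sqrt = false;  // false: squared L2, true: true L2
  bool expanded = true;    // true: |x|^2 + |y|^2 - 2 x.y, false: sum (x_i - y_i)^2
  l2_precision precision = l2_precision::tfloat_3x;  // ignored by the CPU reference
};

// CPU reference for a single pair of vectors of equal length.
inline float l2_reference(const std::vector<float>& x, const std::vector<float>& y,
                          const l2_options& opts) {
  assert(x.size() == y.size());
  float result = 0.0f;
  if (opts.expanded) {
    float xx = 0.0f, yy = 0.0f, xy = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i) {
      xx += x[i] * x[i];
      yy += y[i] * y[i];
      xy += x[i] * y[i];
    }
    result = xx + yy - 2.0f * xy;
  } else {
    for (std::size_t i = 0; i < x.size(); ++i) {
      const float d = x[i] - y[i];
      result += d * d;
    }
  }
  // The expanded form can go slightly negative through rounding; clamp before sqrt.
  return opts.take_sqrt ? std::sqrt(std::max(result, 0.0f)) : result;
}
```

Bundling the choices in one struct is what would let a single overload of the distance function replace several, which is the motivation behind options 1 and 2 earlier in this thread.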

ahendriksen avatar May 08 '23 08:05 ahendriksen

@ahendriksen @tfeher are we still planning to make progress on this feature? I'm doing a little housekeeping on the PRs and just want to make sure the PRs we are keeping open are still valid.

cjnolet avatar Jan 10 '24 14:01 cjnolet

The main question here is what the best mechanism is for the user to opt in to or out of 1xTF32 computation. @vinaydes is working on the same question in relation to #1892. Let's wait until that is resolved and then return to this PR.

Since @ahendriksen is busy with other tasks, we need someone else to continue this. I am assigning this to myself for now; we will revisit availability once #1892 is solved.

tfeher avatar Jan 10 '24 19:01 tfeher