cuCollections

[FEA] Add a `find_if_exists` function to the static map implementation

Open jrhemstad opened this issue 4 years ago • 3 comments

Is your feature request related to a problem? Please describe.

Given a set of probe keys, I would like to gather all of the matching <key,value> pairs from a map.

@nsakharnykh @Nicolas-Iskos @allisonvacanti and I spent some time discussing how to handle the fact that the size of the output cannot be known a priori in the general case. We decided to allow a user to specify an output buffer that may be too small to contain all of the matching (key,value) pairs. If such an "overflow" is detected (more on how to detect this below), we do not throw an exception. Instead, we exploit the fact that we can tell the user exactly how many matches there are, so they can allocate a properly sized output buffer.

Instead of simply returning a std::size_t with the actual number of matches, @allisonvacanti suggested that we use something like std::expected (which hasn't made its way into the standard yet) that indicates both whether an overflow occurred and the total number of matches. This is potentially less of a foot-gun because it informs the user more explicitly that an overflow occurred (otherwise they would have to remember to compare the returned size_t against distance(output_first, output_last) to detect the overflow).
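To make the foot-gun concrete, here is a minimal C++ sketch of an "expected-like" return value. The class name `match_count` and its members are hypothetical, not the proposed `expected_proxy` or any cuCollections API; it simply bundles the total match count with the output capacity so that overflow has to be checked explicitly:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Hypothetical "expected-like" result type (not the proposed expected_proxy):
// it always carries the total number of matches and makes overflow explicit.
class match_count {
 public:
  match_count(std::size_t num_matches, std::size_t capacity)
      : num_matches_{num_matches}, capacity_{capacity} {}

  // True if the output range was too small to hold every match.
  bool overflowed() const noexcept { return num_matches_ > capacity_; }

  // Contextual conversion: `if (result)` reads as "everything fit".
  explicit operator bool() const noexcept { return !overflowed(); }

  // Total number of probe keys found in the map, even if not all were written.
  std::size_t num_matches() const noexcept { return num_matches_; }

  // Number of pairs actually written to the output range.
  std::size_t num_written() const noexcept { return std::min(num_matches_, capacity_); }

 private:
  std::size_t num_matches_;
  std::size_t capacity_;
};
```

With a bare `std::size_t`, nothing stops the caller from ignoring the overflow; with a type like this, `if (!result)` makes the check part of the call site.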

The API would look something like this:

 /**
   * @brief Finds all (key,value) pairs whose keys are equal to a set of probe keys.
   *
   * For each key `*(probe_first + i)`, if it exists in the map, copies the corresponding
   * (key,value) pair to the output. The order of the key,value pairs in the output is
   * non-deterministic. If there are repeated values in `[probe_first, probe_last)`, there will be
   * repeated pairs in the output.
   *
   * In general, the number of matches `n` is not known ahead of time. In the worst case, the number
   * of elements in the output is equal to `std::distance(probe_first, probe_last)`. This presents
   * the potential for a large memory overhead as, in practice, the number of matches may be far
   * less than the number of probe keys. Therefore, it is well-defined to pass in an output
   * iterator range that is potentially smaller than the actual number of matches.
   *
   * If `std::distance(output_first, output_last) < n`, i.e., the output iterator range is not large
   * enough to fit all matches, then `[output_first, output_last)` will be filled with a valid subset
   * of the matching (key,value) pairs, but will not contain _all_ matches. Whether or not such an
   * "overflow" occurs, `find_if_exists` returns an `expected_proxy<std::size_t, std::size_t>` object
   * that indicates: 1) whether or not an overflow occurred, and 2) the total number of keys in
   * `[probe_first, probe_last)` that were present in the map.
   *
   * In the event the original output iterator range was not large enough, the `expected_proxy`'s
   * reported number of matches can be used to create a new output iterator range that is exactly
   * large enough to contain all of the matches in a future call to `find_if_exists`.
   *
   * TODO: Describe detailed semantics of the "expected_proxy" object
   *
   * @tparam InputIt Input iterator whose `value_type` is convertible to `Key`
   * @tparam OutputIt Output iterator whose `value_type` is convertible from `pair<Key,Value>`
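   * @tparam Hash Unary callable type used to hash keys
   * @tparam KeyEqual Binary callable type used to compare two keys for equality
   * @param probe_first Beginning of the probe key range
   * @param probe_last End of the probe key range
   * @param output_first Beginning of the output pair range
   * @param output_last End of the output pair range
   * @param hash The unary callable used to hash keys
   * @param key_equal The binary callable used to compare keys for equality
   * @return An `expected_proxy` indicating whether the output range overflowed, along with the
   *         number of keys in `[probe_first, probe_last)` that exist in the map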
   */
  template <typename InputIt, typename OutputIt, typename Hash, typename KeyEqual>
  expected_proxy<std::size_t, std::size_t> find_if_exists(InputIt probe_first,
                                                          InputIt probe_last,
                                                          OutputIt output_first,
                                                          OutputIt output_last,
                                                          Hash hash,
                                                          KeyEqual key_equal);
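To pin down the documented semantics, here is a sequential host-side reference sketch. It assumes `std::unordered_map` as the map and plain vectors for the ranges; `find_result` and `find_if_exists_reference` are hypothetical names, not the proposed device API. It writes matches in probe order (one valid choice under a non-deterministic-order contract), fills at most `distance(output_first, output_last)` slots, and always reports the full match count:

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical result: total matches plus whether the output range overflowed.
struct find_result {
  std::size_t num_matches;
  bool overflowed;
};

// Sequential reference for the find_if_exists semantics on the host.
template <typename Key, typename Value, typename InputIt, typename OutputIt>
find_result find_if_exists_reference(std::unordered_map<Key, Value> const& map,
                                     InputIt probe_first, InputIt probe_last,
                                     OutputIt output_first, OutputIt output_last) {
  auto const capacity = static_cast<std::size_t>(std::distance(output_first, output_last));
  std::size_t num_matches = 0;
  for (auto it = probe_first; it != probe_last; ++it) {
    auto const found = map.find(*it);
    if (found == map.end()) { continue; }  // missing keys produce no output
    // Only write while there is room, but keep counting so the caller learns n.
    if (num_matches < capacity) { *output_first++ = *found; }
    ++num_matches;
  }
  return {num_matches, num_matches > capacity};
}
```

A caller that sees `overflowed == true` can resize the output to `num_matches` and call again, which is the two-call pattern described above.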

Questions/Points of discussion:

  • There are prototype implementations of std::expected in the wild, but they are more complicated than what we need for this use case. I propose we make our own simple "expected-like" object tailored to this use case. We can eventually replace that with thrust::expected down the road.

  • I documented that the order of the output is non-deterministic. This avoids tying our hands in how we implement it, but it does preclude an interesting use case that @allisonvacanti mentioned:

I'm starting to like the soft overflow more. It would also let someone just reuse a smaller buffer to stream values out of the hash map. If it overflows, just handle what you have, update your query, and call find_if_exists again.

That would only work if the returned (key,value) pairs are in the same order as their corresponding probe_keys.

  • The naive implementation of find_if_exists can be done by just doing a normal "bulk find" that returns a value for each probe key, and for probe keys that don't exist, it just returns the empty value sentinel. Then, you can do a stream compaction/copy_if of only the non-empty values.
    • However, I think there's an opportunity to do better by fusing the find + stream compaction operations. The naive fused approach would be to keep a single global memory offset counter: when a thread finds a match, it atomically increments the counter and uses its return value as the location to write the found match. This suffers from atomic contention and non-coalesced writes.
      • I have two ideas that should perform better:
        1. We could extend the above idea and keep a shared memory buffer and atomic counter per block. When a thread finds a match, it writes it to the shared memory buffer. When that buffer fills up, the block flushes it from shared to global memory using coalesced writes. To detect overflow, when a block goes to flush its shared memory buffer, it checks whether the flush would overflow the output. If so, it atomically sets a flag indicating that an overflow occurred and does a partial flush. All other blocks not yet retired will then skip their flush from shared to global memory. The output order is non-deterministic.
        2. We run a kernel to compute the number of matches per thread block, then compute a scan of the counts to get the write offset for each block. Threads within a block coordinate using a shared memory atomic to write into their block's window in global memory. This has the advantage that we can detect "overflow" outside of a kernel, as the last element of the scan tells us the total number of matches.
          • There's an opportunity here to make the output order deterministic by doing a per-block scan to determine the output location for each thread instead of using an atomic. But this is likely going to be slower than the above approach.
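The naive "bulk find, then stream compaction" implementation can be sketched on the host with standard algorithms. This is an illustration under assumptions: `std::unordered_map` stands in for the device map, `naive_find_if_exists` is a hypothetical name, and the `empty_value_sentinel` argument assumes no stored value equals the sentinel. On the device, the two passes would roughly correspond to a bulk find kernel followed by `thrust::copy_if`:

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical two-pass find_if_exists: bulk find, then stream compaction.
// Assumes no value stored in the map equals empty_value_sentinel.
template <typename Key, typename Value>
std::vector<std::pair<Key, Value>> naive_find_if_exists(std::unordered_map<Key, Value> const& map,
                                                        std::vector<Key> const& probes,
                                                        Value empty_value_sentinel) {
  // Pass 1: "bulk find" produces one (key, value) per probe key; missing keys
  // get the empty-value sentinel.
  std::vector<std::pair<Key, Value>> found(probes.size());
  std::transform(probes.begin(), probes.end(), found.begin(), [&](Key const& k) {
    auto const it = map.find(k);
    return std::pair<Key, Value>{k, it == map.end() ? empty_value_sentinel : it->second};
  });

  // Pass 2: stream compaction keeps only the pairs that were actually found.
  std::vector<std::pair<Key, Value>> out;
  std::copy_if(found.begin(), found.end(), std::back_inserter(out),
               [&](auto const& kv) { return kv.second != empty_value_sentinel; });
  return out;
}
```

The intermediate `found` array (one entry per probe key) and the second pass over it are what the fused approaches above try to avoid. Unlike the proposed API, this sketch sizes its output exactly, so it never overflows.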

jrhemstad, Jul 07 '20 17:07

Building off my last idea about a deterministic output order, I have an idea that requires O(n + num_blocks) temporary memory. My primary goal was to do only one find per probe key, as that is likely the main bottleneck.

1.) First, for each key in [probe_first, probe_last), compute an array of int32_t indices where indices[i] is the index of the slot that contains *(probe_first + i). If the key doesn't exist, the index is -1.

  • Simultaneously, compute the number of matching keys owned by each block, so block_counts[j] indicates the number of matching keys owned by block j.

2.) Compute an exclusive scan of block_counts in place (allocating block_counts with num_blocks + 1 entries, the last initialized to 0) such that block_offsets[j] indicates the starting position of block j in the output

  • block_offsets[num_blocks] indicates the total number of matches (to report in the returned expected_proxy object)

3.) A final kernel computes each thread's location in the output and copies the matching pair there:

  • Block j performs a block-wide exclusive sum scan of the predicate (indices[tid] >= 0)
  • output_index = block_scan[threadIdx.x] + block_offsets[j] is the location in the output
  • If (indices[tid] >= 0) and (output_index < thrust::distance(output_first, output_last)), copy the (key,value) pair stored in slot indices[tid] to the output
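The three steps can be checked with a sequential host simulation. Everything here is an assumption for illustration: a hypothetical linear-probing `slots` array with `empty_key` as the empty sentinel, hypothetical helpers `insert_slot`, `find_slot` and `ordered_find_if_exists`, and "blocks" modeled as chunks of `block_size` consecutive probe keys. On a GPU, the inner loop of step 3 would be a block-wide exclusive scan (for example `cub::BlockScan::ExclusiveSum`):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

using Pair = std::pair<int, int>;
constexpr int empty_key = -1;  // hypothetical sentinel marking an empty slot

// Linear-probing insert; assumes at least one free slot remains.
void insert_slot(std::vector<Pair>& slots, int key, int value) {
  assert(!slots.empty());
  auto i = std::hash<int>{}(key) % slots.size();
  while (slots[i].first != empty_key && slots[i].first != key) { i = (i + 1) % slots.size(); }
  slots[i] = Pair{key, value};
}

// Linear-probing lookup: index of the slot holding `key`, or -1 if absent.
std::int32_t find_slot(std::vector<Pair> const& slots, int key) {
  assert(!slots.empty());
  auto const n = slots.size();
  auto i = std::hash<int>{}(key) % n;
  for (std::size_t step = 0; step < n; ++step, i = (i + 1) % n) {
    if (slots[i].first == key) { return static_cast<std::int32_t>(i); }
    if (slots[i].first == empty_key) { return -1; }
  }
  return -1;
}

// Writes at most output.size() matches in probe-key order; returns the total
// number of matches so the caller can detect overflow.
std::size_t ordered_find_if_exists(std::vector<Pair> const& slots, std::vector<int> const& probes,
                                   std::vector<Pair>& output, std::size_t block_size) {
  assert(block_size > 0);
  auto const n = probes.size();
  auto const num_blocks = (n + block_size - 1) / block_size;

  // Step 1: one find per probe key, plus the number of matches per block.
  std::vector<std::int32_t> indices(n);
  std::vector<std::size_t> block_offsets(num_blocks + 1, 0);  // holds counts first
  for (std::size_t tid = 0; tid < n; ++tid) {
    indices[tid] = find_slot(slots, probes[tid]);
    if (indices[tid] >= 0) { ++block_offsets[tid / block_size]; }
  }

  // Step 2: in-place exclusive scan; block_offsets[num_blocks] is the total.
  std::size_t running = 0;
  for (auto& x : block_offsets) {
    auto const count = x;
    x = running;
    running += count;
  }

  // Step 3: the exclusive scan of the "found" predicate within block j gives
  // each thread's position in the block's output window; write only if it fits.
  for (std::size_t j = 0; j < num_blocks; ++j) {
    std::size_t block_scan = 0;
    auto const end = std::min(n, (j + 1) * block_size);
    for (std::size_t tid = j * block_size; tid < end; ++tid) {
      if (indices[tid] < 0) { continue; }
      auto const output_index = block_offsets[j] + block_scan++;
      if (output_index < output.size()) { output[output_index] = slots[indices[tid]]; }
    }
  }
  return block_offsets[num_blocks];
}
```

Because each block writes to a contiguous window starting at `block_offsets[j]`, and positions within the window follow thread order, the output preserves probe key order, and any truncation keeps a prefix of the full result.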

The same algorithm can be done with O(num_blocks) temporary memory, but it requires two find operations per probe key: one in step 1 when computing block_counts, and another in step 3 when performing the block scan.

jrhemstad, Jul 07 '20 20:07

@nsakharnykh @harrism @kkraus14 I just realized that if we make find_if_exists preserve the input probe key order (as I outlined above), we can trivially use it to implement an order-preserving join algorithm.

Edit: The above algorithm is for a map. A join operation requires a multimap. I think we can extend the same idea to a multimap, but it would require more thought, so it won't apply to join as trivially as I thought.

jrhemstad, Jul 07 '20 20:07

In thinking about how to make this same algorithm work for a multimap, I think something like @allisonvacanti's suggestion to return an iterator to the last probe key that was successfully gathered from the multimap would be a good way to enable a streaming multimap use case.
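A sequential host sketch of that streaming idea, assuming `std::unordered_multimap` in place of a device multimap and a hypothetical `find_all_streaming` function: a probe key is gathered only if all of its matches fit, and the function returns an iterator one past the last fully gathered probe key (the usual half-open convention) so the caller can resume from there:

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical streaming gather over a multimap. A probe key is gathered only
// if *all* of its matches fit in the remaining buffer space. Returns
// {iterator one past the last fully gathered probe key, pairs written}.
// Note: if one key alone has more matches than the buffer holds, the returned
// iterator equals probe_first, so the caller must detect the lack of progress.
template <typename Key, typename Value, typename InputIt>
std::pair<InputIt, std::size_t> find_all_streaming(
    std::unordered_multimap<Key, Value> const& map, InputIt probe_first, InputIt probe_last,
    std::vector<std::pair<Key, Value>>& buffer) {
  std::size_t written = 0;
  for (auto it = probe_first; it != probe_last; ++it) {
    auto const range = map.equal_range(*it);
    auto const count = static_cast<std::size_t>(std::distance(range.first, range.second));
    if (written + count > buffer.size()) { return {it, written}; }  // would overflow: stop here
    for (auto m = range.first; m != range.second; ++m) { buffer[written++] = *m; }
  }
  return {probe_last, written};  // every probe key was gathered
}
```

A caller loops, draining the buffer after each call, until the returned iterator reaches `probe_last`.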

jrhemstad, Jul 07 '20 21:07