
Allow iterators in cub::DeviceRadixSort

Open zasdfgbnm opened this issue 3 years ago • 6 comments

Currently, cub::DeviceRadixSort only supports operating on pointers:

template <typename KeyT, typename ValueT>
static CUB_RUNTIME_FUNCTION cudaError_t
SortPairs(void* d_temp_storage, size_t& temp_storage_bytes,
          const KeyT* d_keys_in, KeyT* d_keys_out,
          const ValueT* d_values_in, ValueT* d_values_out,
          int num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8,
          cudaStream_t stream = 0, bool debug_synchronous = false)

It would be good if d_values_in could be an iterator.

One use case is https://github.com/pytorch/pytorch/pull/53841. In this PR, we are working on a sorting problem where the input keys are random numbers and the input values are 0, 1, 2, 3, ..., N. Currently, we have to allocate a memory buffer just to store 0, 1, 2, ..., N, which is not optimal. It would be nice if we could do something like:

cub::CountingInputIterator<int> iter(0);
cub::DeviceRadixSort::SortPairs(..., /*d_values_in=*/iter, /*d_values_out=*/buffer, ...);
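
For reference, here is a minimal sketch of the current workaround (the function name, key/value types, and the use of thrust::sequence are illustrative, not taken from the PR): the 0..N-1 sequence has to be materialised into a device buffer before SortPairs can consume it.

#include <cub/cub.cuh>
#include <thrust/device_vector.h>
#include <thrust/sequence.h>

// Sketch of the workaround described above: SortPairs currently expects a
// pointer for d_values_in, so the index sequence must live in real memory.
void sort_with_index_values(const float* d_keys_in, float* d_keys_out,
                            int* d_values_out, int num_items)
{
    // Extra allocation and N writes that an iterator input would avoid.
    thrust::device_vector<int> d_values_in(num_items);
    thrust::sequence(d_values_in.begin(), d_values_in.end()); // 0, 1, ..., N-1

    // First call with a null temp pointer only queries the required size.
    void*  d_temp_storage     = nullptr;
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs(
        d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out,
        thrust::raw_pointer_cast(d_values_in.data()), d_values_out, num_items);
    cudaMalloc(&d_temp_storage, temp_storage_bytes);

    cub::DeviceRadixSort::SortPairs(
        d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out,
        thrust::raw_pointer_cast(d_values_in.data()), d_values_out, num_items);
    cudaFree(d_temp_storage);
}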

zasdfgbnm avatar Mar 17 '21 04:03 zasdfgbnm

Have you tried the thrust::sort functions? They allow you to do what you are asking for. For example: https://thrust.github.io/doc/group__sorting_gaec4e3610a36062ee3e3d16607ce5ad80.html Just curious whether you have any timing comparisons between cub and thrust for your use case; I guess thrust is probably going to call cub anyway. If you are worried about temporary memory allocation when using thrust, you can use the execution policy argument, as here: https://github.com/NVIDIA/thrust/blob/main/examples/cuda/custom_temporary_allocation.cu
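
For comparison, a minimal sketch of the thrust route suggested above (names are illustrative). Note that the index buffer still has to be materialised here as well, since thrust::sort_by_key permutes its values range in place:

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>

void thrust_argsort(thrust::device_vector<float>& keys)
{
    // The values must be writable memory -- sort_by_key permutes them in
    // place, so a counting iterator cannot be passed directly here either.
    thrust::device_vector<int> indices(keys.size());
    thrust::sequence(indices.begin(), indices.end()); // 0, 1, 2, ...

    // The execution policy argument is where a custom temporary allocator
    // would be plugged in (see the custom_temporary_allocation example).
    thrust::sort_by_key(thrust::device, keys.begin(), keys.end(),
                        indices.begin());
}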

RaulPPelaez avatar Mar 17 '21 09:03 RaulPPelaez

@RaulPPelaez Yes, I am aware of thrust::sort. I was actually using thrust::sort and want to migrate to cub. I don't have performance numbers, but thrust::sort performs some device-host synchronizations, which is what motivated this migration.

zasdfgbnm avatar Mar 17 '21 15:03 zasdfgbnm

I agree that this should be done if possible; I can look into prioritizing it after I finish updating our benchmarking infrastructure.

Pinging @dumerrill and @canonizer since they know the radix sorting code best -- are y'all aware of any reasons that this wouldn't work? Other device algorithms support iterators, so I'm curious if there's a reason DeviceRadixSort is different.

alliepiper avatar Mar 17 '21 16:03 alliepiper

> Pinging @dumerrill and @canonizer since they know the radix sorting code best -- are y'all aware of any reasons that this wouldn't work? Other device algorithms support iterators, so I'm curious if there's a reason DeviceRadixSort is different.

The main difference is that DeviceRadixSort is a multi-pass algorithm. Accepting an arbitrary iterator is possible, but the implementation would require special treatment. Currently, DeviceRadixSort internally uses a DoubleBuffer with two pointer members, swapping the two buffers with each sorting pass. To accommodate an iterator, the first sorting pass would take the iterator as input and write the materialised results to memory; from then on, the usual DoubleBuffer logic can continue.

Two options to implement this:

  • Materialise the input iterator prior to the sort (IIRC, this is what thrust::sort does): less efficient, but easy to implement (a rough sketch of this option follows the list).
  • Special-case the first sorting pass: more implementation effort and longer compilation times (since the first pass needs a different template specialisation), but more efficient (it saves N memory writes).
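
A rough sketch of the first option, with a hypothetical wrapper name (SortPairsIt) and thrust::copy_n standing in for whatever copy mechanism a real implementation would use:

#include <cub/cub.cuh>
#include <thrust/copy.h>
#include <thrust/execution_policy.h>

// Hypothetical wrapper for option 1: materialise the value iterator into the
// current value buffer, then fall through to the existing pointer-based
// DoubleBuffer path.
template <typename KeyT, typename ValueT, typename ValueItT>
cudaError_t SortPairsIt(void* d_temp_storage, size_t& temp_storage_bytes,
                        cub::DoubleBuffer<KeyT>& d_keys,
                        ValueItT d_values_in,
                        cub::DoubleBuffer<ValueT>& d_values,
                        int num_items)
{
    if (d_temp_storage != nullptr) // skip the copy on the size-query call
    {
        // Pay the extra N writes up front; option 2 would instead fold this
        // copy into the first sorting pass.
        thrust::copy_n(thrust::device, d_values_in, num_items,
                       d_values.Current());
    }
    return cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                           d_keys, d_values, num_items);
}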

elstehle avatar Mar 18 '21 07:03 elstehle

Thanks -- that's a good point.

The double buffer approach has some other issues, too -- the current implementation casts away the const-ness of the input and then writes to that (supposedly const) input memory as part of the double buffering. We should fix that when addressing this issue, since the fix would reuse the same code path.
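
For context, a sketch of the DoubleBuffer usage in question (the buffer arguments are illustrative): after the call, the sorted results live wherever Current() points, and both buffers may have been written to, which is why a genuinely const input does not fit the current scheme.

#include <cub/cub.cuh>

// Both DoubleBuffer halves are treated as scratch across the sorting passes,
// so neither can safely alias read-only memory.
void double_buffer_sketch(float* d_key_buf, float* d_key_alt_buf,
                          int* d_value_buf, int* d_value_alt_buf,
                          void* d_temp_storage, size_t temp_storage_bytes,
                          int num_items)
{
    cub::DoubleBuffer<float> d_keys(d_key_buf, d_key_alt_buf);
    cub::DoubleBuffer<int>   d_values(d_value_buf, d_value_alt_buf);

    cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                    d_keys, d_values, num_items);

    // The sorted output is in whichever buffer Current() selects after the
    // call; the contents of the alternate buffer are unspecified.
    float* d_sorted_keys   = d_keys.Current();
    int*   d_sorted_values = d_values.Current();
    (void)d_sorted_keys;
    (void)d_sorted_values;
}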

alliepiper avatar Mar 18 '21 19:03 alliepiper

I am working on this in https://github.com/NVIDIA/cub/pull/374

zasdfgbnm avatar Sep 14 '21 00:09 zasdfgbnm