Add support for argsort with no shared memory usage

Open EricLBuehler opened this issue 1 year ago • 2 comments

As discussed in #2361, our current argsort implementation does not work on CUDA for large vectors because we use a bitonic sort, which requires shared memory. For an n x m matrix (really, anything x m), the shared memory usage of our bitonic sort scales with m: it is calculated as sizeof(u32) * m. However, CUDA imposes a limit on shared memory per block (the exact size depends on the architecture), and that limit is clearly the constraining factor here.
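To make the scaling concrete, here is a small hedged sketch of the shared-memory bookkeeping described above. The 48 KiB constant is an illustrative default (the classic static per-block shared-memory limit on many NVIDIA architectures); the function names are hypothetical, not candle's actual API.

```rust
// Illustrative sketch: the bitonic argsort's shared-memory demand is one
// u32 index per element of the row being sorted, i.e. sizeof(u32) * m.
// The limit below is an assumption; real devices vary by architecture.
const SHARED_MEM_LIMIT_BYTES: usize = 48 * 1024;

fn bitonic_shared_mem_bytes(m: usize) -> usize {
    std::mem::size_of::<u32>() * m
}

fn fits_in_shared_mem(m: usize) -> bool {
    bitonic_shared_mem_bytes(m) <= SHARED_MEM_LIMIT_BYTES
}

fn main() {
    // 4096 elements need 16 KiB and fit; 32768 elements need 128 KiB and do not.
    println!("{} {}", fits_in_shared_mem(4096), fits_in_shared_mem(32768));
}
```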

This PR sketches an argsort which runs only in global memory. However, I haven't found a way to make it more parallel given the no-shared-memory constraint. As a result, we use a simple bubble sort, which gives horrible performance, especially when argsort is used in sampling, where m is large.
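A CPU analogue of the approach described above might look like the following sketch, where each loop over `r` plays the role of one CUDA thread bubble-sorting the index array of its row. The names and layout are illustrative, not the PR's actual kernel code.

```rust
// Hypothetical CPU analogue of the global-memory kernel: one "thread" per
// row, each running an O(m^2) bubble sort over an index array. Comparisons
// read the row values indirectly through the indices, mirroring a kernel
// that touches only global memory.
fn argsort_rows_bubble(data: &[f32], rows: usize, cols: usize) -> Vec<u32> {
    let mut indices: Vec<u32> =
        (0..rows * cols).map(|i| (i % cols) as u32).collect();
    for r in 0..rows {
        let row = &data[r * cols..(r + 1) * cols];
        let idx = &mut indices[r * cols..(r + 1) * cols];
        for i in 0..cols {
            for j in 0..cols - 1 - i {
                if row[idx[j] as usize] > row[idx[j + 1] as usize] {
                    idx.swap(j, j + 1);
                }
            }
        }
    }
    indices
}

fn main() {
    let data = [3.0, 1.0, 2.0, 0.5, 9.0, 4.0];
    // Per-row indices in ascending order of the row values.
    println!("{:?}", argsort_rows_bubble(&data, 2, 3));
}
```

The O(m^2) inner loop is exactly why performance collapses for large m, such as vocabulary-sized rows during sampling.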

@gabrielmbmb, I was wondering if you could review, and perhaps you know of a faster way to do this? I considered implementing a merge sort so that each thread would merge-sort one row, but that still gives $O(n \log n)$ steps per thread vs. the bitonic sort's $O(\log^2 n)$ parallel stages.
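For illustration, the merge-sort alternative mentioned above could be sketched on the CPU as below. Rust's stable `sort_by` is merge-sort based, so this analogue carries the same O(m log m) per-row cost the comment cites; the function name is hypothetical.

```rust
// Sketch of the alternative: one thread merge-sorts the index array of its
// row. Rust's stable sort is merge-sort based, giving O(m log m) work on a
// single "thread" rather than bitonic sort's O(log^2 m) parallel stages.
fn argsort_row_merge(row: &[f32]) -> Vec<u32> {
    let mut idx: Vec<u32> = (0..row.len() as u32).collect();
    idx.sort_by(|&a, &b| {
        row[a as usize]
            .partial_cmp(&row[b as usize])
            .expect("NaN-free input assumed in this sketch")
    });
    idx
}

fn main() {
    println!("{:?}", argsort_row_merge(&[3.0, 1.0, 2.0]));
}
```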

The PR can be tested with:

cargo test --features cuda --test tensor_tests -- asort_very_big

There is still some cleanup to do before merge (for example, I modified the settings.json for testing and some other things).

EricLBuehler avatar Aug 18 '24 01:08 EricLBuehler

I'm not sure we want to replace the implementation. The goal of the current bitonic implementation is to support things like mixture-of-experts, where the number of elements to sort is small and we want to focus on speed (it's actually the same as in llama.cpp). We should certainly document the current limitation, though.

LaurentMazare avatar Aug 18 '24 05:08 LaurentMazare

Yeah, we probably should not replace the bitonic implementation; this PR adds a separate kernel that is invoked only when the shared memory requirement would exceed some threshold.

But given that this fallback is so slow, perhaps for now it would be better to just document the limitation, or return an error when the shared memory allocation would be excessive?
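A minimal sketch of the error-returning guard suggested above, assuming a 48 KiB per-block limit and a hypothetical error type (neither is candle's actual API):

```rust
// Hedged sketch: refuse rows whose bitonic shared-memory demand exceeds
// the device limit, instead of falling back to a slow global-memory sort.
// `ArgsortError` and the 48 KiB limit are illustrative assumptions.
#[derive(Debug)]
enum ArgsortError {
    SharedMemExceeded { required: usize, limit: usize },
}

const SHARED_MEM_LIMIT_BYTES: usize = 48 * 1024;

fn check_argsort_row_len(m: usize) -> Result<(), ArgsortError> {
    let required = std::mem::size_of::<u32>() * m;
    if required > SHARED_MEM_LIMIT_BYTES {
        Err(ArgsortError::SharedMemExceeded {
            required,
            limit: SHARED_MEM_LIMIT_BYTES,
        })
    } else {
        Ok(())
    }
}

fn main() {
    println!("{:?}", check_argsort_row_len(1 << 20));
}
```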

EricLBuehler avatar Aug 18 '24 17:08 EricLBuehler

Closing per the discussion above and to avoid accumulating stagnant PRs.

EricLBuehler avatar Sep 25 '24 01:09 EricLBuehler