TensorComprehensions

Tuner timeout

Status: Open. math-fehr opened this issue 7 years ago • 3 comments

Implements a timeout for the CUDA backend using a mapping option. Also adds a flag to change the default value of the timeout mapping option.

The aim of this PR is to allow timeouts in the autotuner, where kernels taking 1 s sometimes appear even though 5 ms is achievable. For now, the timeout flag sets a timeout for all kernels produced by the autotuner.

Closes #394 Tag #381

math-fehr, Jun 07 '18 09:06

The general idea is important, and if it already works, that is great. However, it comes with too much technical debt for my taste. Let's try to:

  1. use the clock instruction instead of inserting new variables in memory and relying on a feature that NVIDIA explicitly discourages;
  2. not change the proto and the executor APIs for this: since we JIT-compile the code, let's hardcode the value directly in the source string. There is no value in passing it as an extra parameter through RTC and significantly complicating the compilation flow;
  3. just insert the test above the thread mapping nodes and call it a day, plus add a sanity check that a block does not execute more than a certain number of iterations (so we can prune early).
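A minimal sketch of points 1 to 3, assuming the CUDA source is assembled as a string on the host: the timeout, already converted to SM clock cycles, is baked into the source as a literal, and the check sits at the top of a block-level loop, above any thread mapping. The helper names (`timeout_guard`, `kernel_source`, `__tc_start`) and the kernel shape are hypothetical and do not reflect TC's actual code generator. `clock64()` and `__syncthreads_or()` are standard CUDA device functions; the latter gives every thread of the block the same answer, but since it is a barrier it is only safe where all threads of the block reach it.

```python
# Hypothetical sketch: names below are illustrative, not TC's code generator.


def timeout_guard(timeout_cycles: int, indent: str = "    ") -> str:
    """Return a CUDA statement that ends the block once it has run too long.

    The limit is baked into the source as a literal, so the kernel needs no
    extra parameter. __syncthreads_or() makes the exit decision uniform across
    the block, but it is a barrier: emit it only in block-uniform control flow
    (above thread mapping), where every thread of the block reaches it.
    """
    if timeout_cycles <= 0:
        raise ValueError("timeout_cycles must be positive")
    return (f"{indent}if (__syncthreads_or(clock64() - __tc_start > "
            f"{timeout_cycles}LL)) return;\n")


def kernel_source(name: str, params: str, block_loop: str, thread_body: str,
                  timeout_cycles: int) -> str:
    """Assemble a kernel whose block-level loop checks a hard-coded timeout."""
    return (
        f"__global__ void {name}({params}) {{\n"
        "  const long long __tc_start = clock64();  // per-SM cycle counter\n"
        f"  {block_loop} {{\n"
        f"{timeout_guard(timeout_cycles)}"
        f"{thread_body}"
        "  }\n"
        "}\n"
    )


src = kernel_source(
    name="row_sums",
    params="const float* A, float* out, int N, int M",
    # The loop bound depends only on blockIdx, so all threads iterate together.
    block_loop="for (int i = blockIdx.x; i < N; i += gridDim.x)",
    thread_body="    // thread-mapped work over j (threadIdx.x ... M) goes here\n",
    timeout_cycles=2_000_000_000,  # e.g. roughly 1 s at a 2 GHz SM clock
)
print(src)
```

Because each candidate is compiled anyway, changing the timeout only changes one literal in the generated string, which is the point of hardcoding it instead of threading it through the RTC and executor APIs.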

I would much rather have a global variable (as much as I hate them) than increase the complexity of the APIs. But with the clock instruction and the heuristics above, we may not even need to turn it off during tuning.

nicolasvasilache, Jun 10 '18 19:06

  1. I tried the clock function, but it has one main issue: it is multiprocessor-dependent, which means you cannot "synchronize" the blocks to know when the kernel started. clock gives a timeout per block, while we may want a timeout for the whole kernel, and the two can differ by an order of magnitude. If the first blocks time out, the next blocks will likely time out too, but if they are not yet scheduled on the device, they have no way of knowing that they should time out earlier. The timestamp function, on the other hand, is synchronized across the device, so each block knows at its start when the kernel started. But yes, one of the big problems with the timestamp is that we do not currently know whether it behaves the same way on all GPUs.

  2. While I agree that this code adds unwanted complexity, I think that hardcoding the value in the kernel might pose other problems. The idea is that it would be nice to use the timeout in the tuner, and change the timeout over time in a generation to reduce the tuning time for that generation (which can be really long for the first generation). To do that, it would be nice to be able to change the value when executing the code, instead of when compiling it. I implemented it so that the code can easily be modified to change the timeout when launching the kernel.

  3. Yes, I will insert the test above the thread mapping nodes instead of after them. I think it is really difficult to know when there are "too many" iterations in a kernel. For instance, in the most extreme cases, there can be one or two orders of magnitude of difference between using uncoalesced global memory and using registers, for the same number of iterations. Moreover, the kernel execution time depends heavily on how many blocks are actually scheduled concurrently on the device, and on whether the code is well pipelined. To know when to prune early, we would need to have a sort of performance model, which would require too much work I believe.
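The scheduling argument in the first point can be made concrete with a toy model. This is an illustration under our own assumptions, not a measurement: blocks run in waves of one block per SM, every block would take `block_time` without a timeout, and a timed-out block stops exactly at its deadline. With a per-block clock, each wave may consume the full timeout; with a device-wide timestamp, blocks that start after the deadline exit immediately. The function name is hypothetical.

```python
def simulated_kernel_time(num_blocks: int, num_sms: int, block_time: float,
                          timeout: float, global_timer: bool) -> float:
    """Wall time of a toy kernel whose blocks run in waves of num_sms blocks."""
    if num_blocks <= 0 or num_sms <= 0:
        raise ValueError("num_blocks and num_sms must be positive")
    elapsed = 0.0
    remaining = num_blocks
    while remaining > 0:
        if global_timer:
            # Deadline measured from kernel launch: late waves get what is left.
            wave_time = min(block_time, max(0.0, timeout - elapsed))
        else:
            # Deadline measured from each block's own start (per-SM clock).
            wave_time = min(block_time, timeout)
        elapsed += wave_time
        remaining -= min(num_sms, remaining)
    return elapsed


# 640 blocks on 80 SMs -> 8 waves; each block would need 10 ms, timeout is 5 ms.
print(simulated_kernel_time(640, 80, 10.0, 5.0, global_timer=False))  # 40.0
print(simulated_kernel_time(640, 80, 10.0, 5.0, global_timer=True))   # 5.0
```

The 8x gap between the two results in this example is the kind of order-of-magnitude difference described in the first point; it grows with the number of waves.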

math-fehr, Jun 11 '18 09:06

> I tried the clock function, but it has one main issue: it is multiprocessor-dependent

It should be pretty easy to transform your kernel timeout into a per-block timeout: we always know the number of SMs and the number of blocks statically. So something like ceil((timeout * #sms) / #blocks) should do it. OK, it won't be a perfect approximation, but I wouldn't hesitate for one second to accept that in exchange for killing off the complexity. And for the level of granularity we want to catch with this timeout, this seems more than sufficient to me.
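A sketch of the per-block conversion suggested above, assuming at most one resident block per SM (so blocks execute in roughly #blocks / #sms waves), timeouts in microseconds, and a known SM clock rate in kHz. The function names are hypothetical. The cap at the kernel timeout is our own addition: as written, the formula gives a budget larger than the kernel timeout when there are fewer blocks than SMs. With several resident blocks per SM, the formula overestimates the number of waves and therefore gives a tighter budget than necessary.

```python
def per_block_timeout(kernel_timeout_us: int, num_sms: int, num_blocks: int) -> int:
    """Per-block budget ceil(timeout * #sms / #blocks), capped at the kernel timeout."""
    if num_sms <= 0 or num_blocks <= 0:
        raise ValueError("num_sms and num_blocks must be positive")
    # Integer ceiling division avoids floating-point rounding.
    budget = (kernel_timeout_us * num_sms + num_blocks - 1) // num_blocks
    # Cap (our addition): with fewer blocks than SMs all blocks run in a single
    # wave, so no block should get more than the whole kernel timeout.
    return min(budget, kernel_timeout_us)


def us_to_cycles(timeout_us: int, sm_clock_khz: int) -> int:
    """Convert microseconds to SM clock cycles (kHz = cycles per millisecond)."""
    return timeout_us * sm_clock_khz // 1000


# 1 ms kernel timeout, 80 SMs, 240 blocks, 1.5 GHz SM clock.
budget = per_block_timeout(1000, 80, 240)
print(budget, us_to_cycles(budget, 1_500_000))  # 334 501000
```

The resulting cycle count is what would be hardcoded into the generated kernel and compared against `clock64()` deltas, since the clock counter is only meaningful within one SM.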

> The idea is that it would be nice to use the timeout in the tuner, and change the timeout over time in a generation to reduce the tuning time for that generation

We compile a new version for each set of options at each generation, so I don't see the issue: just implement a dynamic timeout in the tuner and hardcode it. One could object that this would prevent memoizing the generated CUDA/PTX during tuning, but that is premature optimization IMO. Keeping a simple and clean abstraction is orders of magnitude more important.
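A minimal sketch of "implement a dynamic timeout in the tuner and hardcode it": each candidate gets a timeout derived from the best time seen so far in the generation, and that value would be baked into the source right before compiling, since every candidate is compiled anyway. Code generation, compilation, and execution are simulated by a hypothetical `run_with_baked_timeout`; the 10x factor and the initial timeout are illustrative assumptions, not values from this thread.

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple


@dataclass
class Measurement:
    time_ms: float
    timed_out: bool


def run_with_baked_timeout(true_time_ms: float, timeout_ms: float) -> Measurement:
    """Stand-in for: emit CUDA with timeout_ms hard-coded, compile, run once."""
    if true_time_ms > timeout_ms:
        return Measurement(timeout_ms, True)  # kernel exits at its deadline
    return Measurement(true_time_ms, False)


def tune_generation(candidate_times_ms: Iterable[float], initial_timeout_ms: float,
                    factor: float = 10.0) -> Tuple[Optional[float], float]:
    """Return (best time, total GPU time spent) for one tuner generation."""
    best: Optional[float] = None
    spent = 0.0
    for true_time in candidate_times_ms:
        # No best kernel yet: fall back to the fixed initial timeout.
        if best is None:
            timeout = initial_timeout_ms
        else:
            timeout = min(initial_timeout_ms, factor * best)
        m = run_with_baked_timeout(true_time, timeout)
        spent += m.time_ms
        if not m.timed_out and (best is None or m.time_ms < best):
            best = m.time_ms
    return best, spent


best, spent = tune_generation([1000.0, 5.0, 800.0, 4.0, 1200.0],
                              initial_timeout_ms=2000.0)
print(best, spent)  # 4.0 1099.0
```

Without any timeout the same five candidates would cost 3009 ms of GPU time; here the slow candidates after the first good kernel are cut off at 10x the current best.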

> I think it is really difficult to know when there are "too many" iterations in a kernel

Agreed. OTOH, I would argue that we don't need something precise, just something that works "well enough". The whole point here is to skip catastrophically bad cases that even our chainsaw pruning didn't catch (at least on the first iteration).

Additionally, @ftynse, it seems that some of the issues this PR wants to catch occur on the first iteration, where there is no "best kernel" to compare against and we end up executing really bad ones multiple times. How about putting the timeout value in the pruning function? This would guarantee that we execute a bad kernel only once, which already catches 90%+ of the useful cases. I still think having a way to interrupt kernels is valuable, but simplicity first, without any hesitation.
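A sketch of putting the timeout into the pruning step, assuming the tuner benchmarks each candidate `num_runs` times: a candidate whose measured run exceeds the limit is dropped immediately, so a catastrophically bad kernel executes only once. When no best kernel exists yet (the first iteration), the fixed timeout acts as the limit. The function name and the 10x factor are hypothetical. Note that this does not interrupt a running kernel; it only avoids running a bad one again.

```python
from typing import Callable, List, Optional, Tuple


def benchmark_with_pruning(run_once: Callable[[], float], num_runs: int,
                           timeout_ms: float, best_ms: Optional[float] = None,
                           factor: float = 10.0) -> Tuple[bool, List[float]]:
    """Return (pruned, measured times). Stops at the first run over the limit."""
    # First iteration: no best kernel to compare to, so use the fixed timeout.
    limit = timeout_ms if best_ms is None else min(timeout_ms, factor * best_ms)
    times: List[float] = []
    for _ in range(num_runs):
        t = run_once()
        times.append(t)
        if t > limit:
            return True, times  # pruned: no further executions of this candidate
    return False, times


print(benchmark_with_pruning(lambda: 1000.0, num_runs=5, timeout_ms=100.0))
print(benchmark_with_pruning(lambda: 5.0, num_runs=5, timeout_ms=100.0))
print(benchmark_with_pruning(lambda: 50.0, num_runs=5, timeout_ms=100.0, best_ms=4.0))
```

The first call is pruned after a single run, the second runs all five times, and the third is pruned because 50 ms exceeds 10x the current best of 4 ms.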

> have a sort of performance model, which would require too much work I believe

We're working on learning some :) We definitely don't want to engineer those by hand.

nicolasvasilache, Jun 11 '18 12:06