Use thread-local context and stream in CUDA runtime.
Workaround for issue #6255
This change makes it possible to set a thread-local CUDA context and stream by exposing two new functions: halide_cuda_set_current_context and halide_cuda_set_current_stream. Executing kernels and buffer copies on a custom stream and context is essential for work on multi-GPU systems.
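A hypothetical usage sketch of the new hooks (the real signatures are in this PR's diff; they are assumed here to take the raw driver handles, and the per-device context/stream are created by the application itself):

```cpp
#include <cuda.h>

// Assumed signatures of the functions this PR exposes - check the diff for
// the real ones.
extern "C" int halide_cuda_set_current_context(CUcontext ctx);
extern "C" int halide_cuda_set_current_stream(CUstream stream);

// One worker thread per GPU: bind that thread's context and stream before
// invoking the AOT-compiled pipeline, so its kernels and buffer copies stay
// on the intended device.
void run_on_device(CUcontext ctx, CUstream stream) {
    halide_cuda_set_current_context(ctx);
    halide_cuda_set_current_stream(stream);
    // pipeline(input, output);  // AOT-compiled Halide call goes here
}
```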
TODO: Facilitate calling these CUDA runtime functions in JIT mode. Currently they are readily accessible only in AOT mode; with JIT it is apparently necessary to iterate over the JIT runtime modules and look up the symbols.
NOTE: This does not work when there's a parallel schedule outside a GPU schedule - threads from Halide's internal thread pool will still see a global context, possibly on a different device.
I don't think we could merge something like this, because the thread-local assumption is pretty bogus outside of very specific schedules. I think we're going to have to strip all global state from the runtime and make it possible to pass it in as a context object. You'd manipulate this context object directly to set the cuda stream and cuda context.
Given that the PR is a draft I guess the intent is not to merge it, but rather to illustrate what you had to do as a workaround for your use-case?
Yes, this PR is merely illustrating what works for what I'd call a typical GPU schedule, but I'm very much aware of its lack of generality, as I've stated in the discussion in the related issue. I have more workaround ideas, which I'll share in the original issue/thread.
Is this PR still active? Should it be closed?
I'd like to keep it open (as a draft) until all functionality that can be achieved with this PR is available. #6313 is a starting point, but it's still missing some features.
What features are still missing? It should now be possible to plumb through a custom user_context everywhere, and hence to use a custom cuda context and stream everywhere.
Hmm... there's at least one more thing: host_supports_target_device. It creates a CUDA context as a side effect. I wonder if there are other places that would cause the default context to be created (even if it's transient and the fact that something happens on that context is not visible to the user).
Creating this context is unacceptable, because the mere fact that the context exists takes a heavy toll on GPU memory - I think it depends on the total size of the kernels used via the CUDA runtime API. For example, when I was running ResNet50 training in PyTorch, creating a context on a device took an additional 300 MB of device memory. Now imagine a system with 16 GPUs where 16 processes each create an unnecessary context on device 0 - that's 4.8 GB wasted.
That happens if you use a target string like host-cuda, without specifying a compute capability. Halide has to take a guess at which device you want to use so that it can query the capabilities of the device. Do you already have some other way to determine the host cuda capability? If so, try using the target string host-cuda-cuda_capability_?? and see if we still generate a needless context.
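For illustration, a fully specified target (6.1 is an arbitrary placeholder for the actual capability), which - if the sniffing really is only for the capability - should parse without touching the driver:

```cpp
#include "Halide.h"
using namespace Halide;

int main() {
    // Equivalent ways of naming the capability up front, so target parsing
    // has no reason to create a CUDA context to sniff it.
    Target t1("host-cuda-cuda_capability_61");
    Target t2 = get_host_target().with_feature(Target::CUDA)
                                 .with_feature(Target::CUDACapability61);
    return 0;
}
```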
It's not a matter of compute capability - the systems in question have multiple GPUs of exactly the same kind. These are either large compute and graphics workstations or compute servers (NVIDIA DGX series, for example).
host_supports_target_device itself is probably not that problematic - I haven't found any calls to it from within Halide, so I assume it's the user's job to call it if necessary (and if the user already has a CUDA context, calling it is likely unnecessary). It's rather a question of whether there are any other places that might create a default CUDA context as a side effect. I've found this one, and the way a default context can slip in is easy to overlook, so I'm a bit concerned that there might be more places like this - and perhaps some of them are harder to avoid.
I think the last remaining place to worry about really is in parsing "host-cuda" specifically, but it sounds like what you want is for any attempt by the Halide runtime to create a cuda context to throw an error instead, so that you don't have to rely on being strict about never using "host-cuda" and always passing a user_context.
All calls to create a context bottleneck through this one function: https://github.com/halide/Halide/blob/master/src/JITModule.cpp#L590

The problematic call into the runtime is here: https://github.com/halide/Halide/blob/master/src/Target.cpp#L238

Note that the first arg is hard-coded to nullptr. That means the second case in the handler in JITModule.cpp always triggers. We can override the second case in the handler with a function that throws an error (returns a non-zero value) using JITHandlers JITSharedRuntime::set_default_handlers(const JITHandlers &handlers). This must be done first, before any use of the "host-cuda" target string.
If you set the default acquire_context handler to throw an error using the method above, it should be impossible for Halide to create its own cuda context.
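On the AOT side, the analogous move is to provide strong definitions of the weakly-linked hooks declared in HalideRuntimeCuda.h. A minimal sketch, assuming the application owns the context and stream (my_ctx and my_stream are hypothetical names):

```cpp
#include <cuda.h>
#include "HalideRuntimeCuda.h"

// Application-owned handles (hypothetical): created once per device with the
// driver API before any Halide pipeline runs.
static CUcontext my_ctx = nullptr;
static CUstream my_stream = nullptr;

// A strong definition overrides the weak one in the Halide runtime.
extern "C" int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create) {
    if (my_ctx == nullptr) {
        // Refuse to let Halide create a context of its own as a side effect.
        return halide_error_code_generic_error;
    }
    *ctx = my_ctx;
    return halide_error_code_success;
}

extern "C" int halide_cuda_release_context(void *user_context) {
    // Nothing to do: the application owns the context's lifetime.
    return halide_error_code_success;
}

extern "C" int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
    *stream = my_stream;  // run kernels and buffer copies on the app's stream
    return halide_error_code_success;
}
```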
I now sorta think that having some Target strings quietly instantiate GPU contexts to do sniffing is a bug and not a feature -- IMHO it's not at all obvious that host-cuda does this, nor is it the case that this is always what you want.
Thinking out loud, maybe we want some additional "feature" string that indicates we specifically want capability detection to be done (e.g. host-cuda-detect_gpu_capability). This is wonky, but not entirely unprecedented; we already support parsing trace_all as a feature, which really expands into the union of all the other trace_ features.
I'd argue that additional string is "host". "host" means please sniff my system. E.g. call cpuid.
Anything additional on top of that has a discoverability problem: We'd revert to the days of people just believing Halide is slow for cuda because they didn't know to turn on the hidden magic flag.
Needing to instantiate a context just to resolve the capability is indeed a PITA though. Maybe we can delay resolution of the host cuda capability, e.g. by unpacking host-cuda into x86-64-linux-avx2-cuda-cuda_capability_host in the target parser. The issue is that we need to know it at lowering time, which is pretty early... would delaying resolution until lowering be an advantage?
That issue is orthogonal though, so let's discuss it in: https://github.com/halide/Halide/issues/6448
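As an aside on the sniffing cost itself: the CUDA driver API can report the compute capability without instantiating a context, so capability detection - whenever it ends up happening - needn't pay the per-context memory toll. A minimal sketch:

```cpp
#include <cuda.h>
#include <cstdio>

int main() {
    // cuInit and the device queries below do not create a CUDA context,
    // so none of the per-context device memory discussed above is spent.
    if (cuInit(0) != CUDA_SUCCESS) return 1;
    CUdevice dev;
    if (cuDeviceGet(&dev, 0) != CUDA_SUCCESS) return 1;
    int major = 0, minor = 0;
    cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev);
    cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, dev);
    printf("cuda_capability_%d%d\n", major, minor);
    return 0;
}
```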