Check PyTorch CUDA context is valid across GPUs
Describe the bug
We have had multiple breakages where the CUDA context ends up bound only to GPU 0 in a Dask + PyTorch environment. This can happen when a library creates a CUDA context with PyTorch before the cluster is started.
What ends up happening is that PyTorch models are all deployed on GPU 0, and that issue is hard to debug.
I think a better fix is to ensure we don't fork an already-present CUDA context when starting the local CUDA cluster, e.g. by checking for an initialized context before spinning it up. A minimal repro and proposed check:
import cupy as cp

# Simulate a library initializing a CUDA context before the cluster starts.
cp.cuda.runtime.getDeviceCount()
# import torch
# t = torch.as_tensor([1, 2, 3])

from dask_cuda import LocalCUDACluster
from distributed import Client
from distributed.diagnostics.nvml import has_cuda_context


def check_cuda_context():
    _warning_suffix = (
        "This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
        "Dask-CUDA can spawn worker processes. Please make sure any such function calls don't happen "
        "at import time or in the global scope of a program."
    )
    if has_cuda_context().has_context:
        # A CUDA context already exists in this process, so worker processes
        # would inherit it instead of creating their own per-GPU contexts.
        raise RuntimeError(
            f"CUDA context is initialized before the dask-cuda cluster was spun up. {_warning_suffix}"
        )


if __name__ == "__main__":
    check_cuda_context()
    cluster = LocalCUDACluster(rmm_async=True, rmm_pool_size="2GiB")
    client = Client(cluster)
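For completeness, here is a minimal sketch (the visible_devices helper is hypothetical, not part of the repro above) of one way to confirm that each dask-cuda worker is pinned to a different physical GPU once the cluster comes up cleanly: dask-cuda gives each worker its own CUDA_VISIBLE_DEVICES ordering, so the first entry should differ per worker.

import os

from dask_cuda import LocalCUDACluster
from distributed import Client


def visible_devices():
    # First entry is the GPU this worker's CUDA context should be created on.
    return os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")[0]


if __name__ == "__main__":
    cluster = LocalCUDACluster(rmm_async=True, rmm_pool_size="2GiB")
    client = Client(cluster)
    # Maps worker address -> first visible device; the values should all differ.
    print(client.run(visible_devices))

This only checks the per-worker device assignment; the check_cuda_context guard above is still needed to catch a context that was created before the cluster starts.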
CC: @ayushdg
Moving to next release as not a high priority
Need to move this ticket to 25.01 (January).