dask-cuda
Program crashes when run repeatedly
I am currently benchmarking Dask-CUDA. For this, I create larger-than-memory matrices and rotate them by 90 degrees.
After running my program for approximately 30 seconds, GPU utilization and memory usage drop to an idle state, ~3% for sm and 0% for mem (these metrics are obtained using nvidia-smi). Furthermore, the program crashes after some time with the following error message:
distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS. Unmanaged memory: 20.2 GB -- Worker memory limit: 31.27 GB
After this error occurs, the folder dask-worker-space/storage/ contains thousands of %28%27asarray-a3d47595a9f7e8e1604b7b171ed85714%27%2C%200%2C%201%2C%200%29 encoded files, making it impossible to re-run the program without deleting the folder manually.
Currently, I am using the CuPy benchmark function to correctly open CUDA streams and time asynchronous events, since using Python's time module will not correctly time GPU-executed code. The following code reproduces the problem:
import cupy as cp
import numpy as np
import dask.array as da
from dask_cuda import LocalCUDACluster
from dask.distributed import Client, wait
import rmm
import sys
import cucim.core.operations.spatial as spt
from cupyx.profiler import benchmark
from helpers import *
import time

def generate_matrix(size, work_units):
    rs = da.random.RandomState(RandomState=cp.random.RandomState)
    rs = rs.randint(low=0, high=100_000, size=(size, size, work_units), chunks='auto')
    rs = rs.map_blocks(cp.asarray)
    return rs

def rotate(mat):
    y = spt.image_rotate_90(cp.asarray(mat), k=1, spatial_axis=(1, 2))  # We have to map the memory view into an actual array, otherwise rotation will be O(1) time
    y.persist()
    wait(y)

if __name__ == '__main__':
    cluster = LocalCUDACluster('0', rmm_managed_memory=True)  # Use GPU device ID 0 and the RAPIDS memory manager
    client = Client(cluster)  # Create a local cluster
    client.run(cp.cuda.set_allocator, rmm.rmm_cupy_allocator)  # Combine cupy and rmm to allocate memory on the GPU
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm.rmm_cupy_allocator)
    sz = int(sys.argv[1])
    wk = int(sys.argv[2])
    mat = generate_matrix(sz, wk)
    y = benchmark(rotate, (mat, ), n_repeat=20, n_warmup=5)
    print(cp.average(y.gpu_times))
To run: python3 rotation.py 4096 1000
After this error occurs, the folder dask-worker-space/storage/ contains thousands of %28%27asarray-a3d47595a9f7e8e1604b7b171ed85714%27%2C%200%2C%201%2C%200%29 encoded files, making it impossible to re-run the program without deleting the folder manually.
This problem, at least, I think was fixed by the merge of #1023, so if you can install a nightly dask-cuda build it should go away.
Can you try replacing your generate_matrix with:
def generate_matrix(size, work_units):
    rs = da.random.RandomState(RandomState=cp.random.RandomState)
    rs = rs.randint(low=0, high=100_000, size=(size, size, work_units), chunks='auto')
    rs = rs.map_blocks(cp.asarray)
    rs = rs.persist()
    wait(rs)
    return rs
(I am not sure this will help ...)
Awesome -- the nightly dask-cuda resolved the issue of flooding the folder.
However, when adding your suggested change, stdout is flooded with:
This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
2022-11-11 12:03:15,572 - distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
2022-11-11 12:03:15,575 - distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
2022-11-11 12:03:16,119 - distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
2022-11-11 12:03:16,122 - distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
2022-11-11 12:03:16,464 - distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
2022-11-11 12:03:16,464 - distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
2022-11-11 12:03:16,537 - distributed.worker_memory - WARNING - Unmanaged memory use is high. This may indicate a memory leak or the memory may not be released to the OS; see https://distributed.dask.org/en/latest/worker-memory.html#memory-not-released-back-to-the-os for more information. -- Unmanaged memory: 21.89 GiB -- Worker memory limit: 31.27 GiB
Is there a way to call a method which requires either an np.ndarray or a cp.ndarray on a Dask Array? Calling map_blocks(cp.asarray) does not make these functions callable. I guess matrix rotation is difficult, as we are dealing with chunks...
Is there a way to call a method which requires either an np.ndarray or a cp.ndarray on a Dask Array? Calling map_blocks(cp.asarray) does not make these functions callable. I guess matrix rotation is difficult, as we are dealing with chunks...
Sorry, I think I am not following. When you create your initial matrix:
def generate_matrix(size, work_units):
    rs = da.random.RandomState(RandomState=cp.random.RandomState)
    rs = rs.randint(low=0, high=100_000, size=(size, size, work_units), chunks='auto')
    rs = rs.map_blocks(cp.asarray)
    return rs
You shouldn't need the rs.map_blocks call. Since you provided a cupy RandomState object when you made the dask array representing the matrix, each individual chunk is already stored as a cupy array.
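For reference, one quick way to confirm which library backs the chunks is to inspect the array's _meta attribute (an internal but commonly inspected detail of dask arrays, so treat this as a sketch rather than a supported API):

import cupy as cp
import dask.array as da

# Build a small dask array backed by cupy chunks via the cupy RandomState.
rs = da.random.RandomState(RandomState=cp.random.RandomState)
x = rs.randint(low=0, high=100_000, size=(1024, 1024), chunks='auto')

# _meta holds an empty array of the same type as the chunks; for a
# cupy-backed dask array this should print <class 'cupy.ndarray'>.
print(type(x._meta))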
Your out of memory problem is coming from this bit of code:
def rotate(mat):
    y = spt.image_rotate_90(cp.asarray(mat), k=1, spatial_axis=(1, 2))  # We have to map the memory view into an actual array, otherwise rotation will be O(1) time
    y.persist()
    wait(y)
The first problem is the call to cp.asarray(mat). This takes the large matrix that is distributed across multiple GPUs and tries to materialise it as a single cupy array on a single GPU (which is the cause of your out-of-memory error). Were this to succeed, I think you would subsequently get errors that y.persist() is an undefined method, because cucim doesn't know anything about dask arrays and so just hands back a cupy array.
You are right that things are more difficult in this setting because of the distributed nature of the array. In this particular case, image_rotate_90 is just a wrapper around cupy.rot90, I think. Fortunately, rot90 is an array method that dask-array implements. So in this case I think you can achieve what you want with:
def generate_matrix(size, work_units):
    rs = da.random.RandomState(RandomState=cp.random.RandomState)
    rs = rs.randint(low=0, high=100_000, size=(size, size, work_units), chunks='auto')
    return rs

def rotate(mat):
    y = da.rot90(mat, ...)  # not sure exactly of the arguments
    return y.persist()

def bench(mat):
    r = rotate(mat)
    wait(r)
    return r
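For completeness, a sketch of how these might slot into the original script's __main__ block (reusing the imports and cluster setup from the reproducer above; the exact rot90 arguments still need to be filled in):

if __name__ == '__main__':
    cluster = LocalCUDACluster('0', rmm_managed_memory=True)
    client = Client(cluster)
    client.run(cp.cuda.set_allocator, rmm.rmm_cupy_allocator)

    # Materialise the random input once, outside the timed region.
    mat = generate_matrix(int(sys.argv[1]), int(sys.argv[2])).persist()
    wait(mat)

    # bench() persists the rotated array and waits for it, so the timing
    # covers the distributed rotation itself.
    res = benchmark(bench, (mat, ), n_repeat=20, n_warmup=5)
    print(np.average(res.gpu_times))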
Hi @wence-, thank you very much for your swift response. Sorry for my unclear question.
The question is why I cannot call spt.image_rotate_90(mat, k = 1, spatial_axis=(1,2)). The matrix is generated, as you state, using a CuPy RandomState object, which is why each chunk is a CuPy array. However, this raises a type error:
raise TypeError("img must be a cupy.ndarray or numpy.ndarray")
TypeError: img must be a cupy.ndarray or numpy.ndarray
This happens when passing the matrix generated by generate_matrix() to rotate(). Therefore, I was wondering how to correctly call methods which require a CuPy ndarray or NumPy ndarray on Dask collections. I hope this is clearer now.
Furthermore, the reason I wish to steer clear of da.rot90() is that the documentation is somewhat wrong (https://github.com/dask/dask/issues/9576): a view is returned, meaning that during benchmarking the operation is O(1). I wish to actually return the rotated array and force the execution of the rotation.
Furthermore, the following program:
import cupy as cp
import numpy as np
import dask.array as da
from dask_cuda import LocalCUDACluster
from dask.distributed import Client, wait
import rmm
import sys
import time
from helpers import *
from cupyx.profiler import benchmark

def generate_array(size, work_units):
    rs = da.random.RandomState(RandomState=cp.random.RandomState)
    rs = rs.randint(low=0, high=100_000, size=(size, work_units), chunks='auto')
    return rs

def multiplication(arr):
    array_mult = da.multiply(arr, 42)
    return array_mult.persist()

def bench(arr):
    r = multiplication(arr)
    wait(r)
    return r

if __name__ == '__main__':
    cluster = LocalCUDACluster('0', rmm_managed_memory=True)
    client = Client(cluster)
    client.run(cp.cuda.set_allocator, rmm.rmm_cupy_allocator)
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm.rmm_cupy_allocator)
    size = int(sys.argv[1])
    wk = int(sys.argv[2])
    arr = generate_array(size, wk)
    y = benchmark(multiplication, (arr, ), n_repeat=20)
    print(np.average(y.gpu_times))
Floods the terminal with:
2022-11-15 10:59:11,217 - distributed.preloading - INFO - Creating preload: dask_cuda.initialize
2022-11-15 10:59:11,217 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize
2022-11-15 10:59:45,007 - distributed.utils_perf - WARNING - full garbage collections took 28% CPU time recently (threshold: 10%)
2022-11-15 10:59:48,040 - distributed.utils_perf - WARNING - full garbage collections took 29% CPU time recently (threshold: 10%)
2022-11-15 10:59:51,317 - distributed.utils_perf - WARNING - full garbage collections took 29% CPU time recently (threshold: 10%)
2022-11-15 10:59:55,048 - distributed.utils_perf - WARNING - full garbage collections took 29% CPU time recently (threshold: 10%)
2022-11-15 10:59:59,419 - distributed.utils_perf - WARNING - full garbage collections took 29% CPU time recently (threshold: 10%)
1.4989467407226562
2022-11-15 11:00:04,347 - distributed.utils_perf - WARNING - full garbage collections took 29% CPU time recently (threshold: 10%)
2022-11-15 11:00:04,405 - distributed.batched - ERROR - Error in batched write
But it still executes (on the GPU side) in approximately 1.8 milliseconds. When benchmarking Dask-CUDA, should one include CPU-side execution as well as GPU-side execution?
Hi @wence-, thank you very much for your swift response. Sorry for my unclear question.
The question is why I cannot call spt.image_rotate_90(mat, k = 1, spatial_axis=(1,2)). The matrix is generated, as you state, using a CuPy RandomState object, which is why each chunk is a CuPy array. However, this raises a type error: TypeError: img must be a cupy.ndarray or numpy.ndarray
This happens when passing the matrix generated by generate_matrix() to rotate(). Therefore, I was wondering how to correctly call methods which require a CuPy ndarray or NumPy ndarray on Dask collections. I hope this is clearer now.
The answer to this is (in general) "you can't". To understand why, let's step back from the specifics of the operations you're doing here and think about what is happening when you make a dask array. Suppose that we build a 2D array using numpy as the base type with two chunks: the dask array is a meta-object that references the two concrete numpy arrays and advertises a number of properties that make it look like a single numpy array (but the two pieces may even live on two different computers, depending on the cluster setup). There's a picture of this in the dask array docs.
What does this mean for the implementation of most algorithms? Well, any algorithm that is written purely calling methods of numpy arrays (slicing, in-place binary operations, and so forth) has a chance of working since the dask array object provides implementations of those that do the right thing.
On the other hand, an algorithm that calls a numpy function (e.g. numpy.rot90) has a much lower chance of working (since it probably expects to actually receive a concrete numpy array that it will hand on to the numpy C implementation layer).
Consequently, if we find ourselves wanting to call a numpy function on a dask array, we need to think a little bit. Perhaps it is already implemented in the dask.array namespace, in which case we should call that function instead. If it is not, then if we are lucky, the operation we want to do operates "pointwise", so we can just apply our numpy function to each chunk (with dask.array.map_blocks). In the general case, there will be some communication required between the chunks, and then we have to think harder about how to implement our algorithm in a distributed-memory setting.
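As a small illustration of those two easy cases (using cupy-backed chunks; cp.sin here merely stands in for any elementwise function you might want to apply per chunk):

import cupy as cp
import dask.array as da

rs = da.random.RandomState(RandomState=cp.random.RandomState)
x = rs.randint(low=0, high=100, size=(4096, 4096), chunks=(1024, 1024))

# Case 1: the operation already exists in the dask.array namespace,
# so call the dask version rather than the cupy/numpy one.
y = da.rot90(x)

# Case 2: a purely pointwise function can be applied chunk-by-chunk
# with map_blocks; each block received here is a concrete cupy array.
z = x.map_blocks(cp.sin, dtype='float64')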
Furthermore, the reason I wish to steer clear of da.rot90() is that the documentation is somewhat wrong (dask/dask#9576): a view is returned, meaning that during benchmarking the operation is O(1). I wish to actually return the rotated array and force the execution of the rotation.
Probably in that case you would want to .copy() the returned array? (numpy.rot90 also returns a view; if you want to materialise it, then you would need to copy into contiguous storage.)
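A minimal sketch of that, adapted to the earlier rotate/bench suggestion (the .copy() is the part intended to force the rotated data to be written out rather than benchmarking a view-like graph):

def rotate(mat):
    # da.rot90 on its own is cheap to build; copying and persisting forces
    # the rotated chunks to actually be computed on the cluster.
    y = da.rot90(mat).copy()
    return y.persist()

def bench(mat):
    r = rotate(mat)
    wait(r)  # block until all rotated chunks are materialised
    return r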
But it still executes (on the GPU side) in approximately 1.8 milliseconds. When benchmarking Dask-CUDA, should one include CPU-side execution as well as GPU-side execution?
I think you have an error in your benchmarking script (you should benchmark bench, not multiplication). The latter just calls persist, which says "please promise that at some point in the future the result of this computation is materialised on the cluster"; whereas the former waits on the persisted object.
If I make that change, then running python run-benchmark.py 10000 20000 reports an average runtime of about 780 ms, which seems more reasonable.
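Concretely, the fix is just to swap the benchmarked callable in the script's __main__ block:

# Benchmark the function that waits on the persisted result, not the one
# that merely schedules it.
y = benchmark(bench, (arr, ), n_repeat=20)
print(np.average(y.gpu_times))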
In terms of the more philosophical question of which time to include, it depends what you're trying to achieve. In terms of time-to-science, the only really important number is the wallclock time "how long does it take before I have a result I can do something with?". If you're trying to figure out which parts of an analysis are slow, then it is worthwhile recording things in more detail to try and understand where the time is lost.
@wence- Again, thank you very much for your thorough explanations. I am currently figuring out whether I should take the time to implement a distributed rotation algorithm, or just materialize the view, as you said. The latter, with your metric of time-to-science, is probably the way to go.
The reason for my many questions is that I am doing a research project, and I by no means wish to incorrectly time and benchmark your functions; I wish to stay as close as possible to the philosophy behind Dask and Dask-CUDA. I will not be using wallclock time, as it does not capture how long the GPU execution specifically takes. I wish to provide a very precise pointer as to when, and if, a workload should be offloaded to the GPU and when it should reside on the CPU. Naturally, this does not make much sense when most GPUs are memory-limited to, e.g., 6 GB, which is why I use Dask for these larger-than-memory arrays. Therefore, in order to provide such pointers, fine granularity of timing is required, i.e. how much time is spent transferring memory between the host and device, and how much time is spent purely on the device executing the instructions, such as a multiplication instruction.
In order to obtain the GPU-side execution time I use the CuPy benchmarking tool. I see no reason why this should not work for Dask-CUDA, as it simply creates a CUDA stream, records events on that stream, and closes it.
As for the bug -- you are entirely correct -- that is a typo on my side, and I now also obtain more reasonable results when benchmarking. Thank you for your help and your feedback.
With respect to implementing a distributed rotation algorithm or materializing the view, I too would go with the latter. It's entirely possible that what you would end up implementing does the same set of communications under the hood anyway.
The reason for my many questions is that I am doing a research project, and I by no means wish to incorrectly time and benchmark your functions; I wish to stay as close as possible to the philosophy behind Dask and Dask-CUDA. I will not be using wallclock time, as it does not capture how long the GPU execution specifically takes. I wish to provide a very precise pointer as to when, and if, a workload should be offloaded to the GPU and when it should reside on the CPU. Naturally, this does not make much sense when most GPUs are memory-limited to, e.g., 6 GB, which is why I use Dask for these larger-than-memory arrays. Therefore, in order to provide such pointers, fine granularity of timing is required, i.e. how much time is spent transferring memory between the host and device, and how much time is spent purely on the device executing the instructions, such as a multiplication instruction.
For these kinds of fine-grained questions, I would turn to a more fully-fledged profiler like Nsight Systems.
Dask also has some facility for generating performance reports; the documentation is worth reading.
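For example (a sketch: performance_report is importable from dask.distributed, and the filename here is arbitrary):

from dask.distributed import performance_report

# Wrap the timed region; Dask writes an HTML report containing the task
# stream, worker profiles and transfer statistics for everything run inside.
with performance_report(filename="dask-report.html"):
    y = benchmark(bench, (arr, ), n_repeat=20)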
In order to obtain the GPU-side execution time I use the CuPy benchmarking tool. I see no reason why this should not work for Dask-CUDA, as it simply creates a CUDA stream, records events on that stream, and closes it.
This should work, though be wary of the case where you are using more than one GPU in a cluster; make sure to indicate all the devices you intend to run on to the benchmark function.
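If I remember the cupyx API correctly, benchmark takes a devices argument for this; something along these lines (the device IDs are just an example):

# Record and synchronise GPU events on devices 0 and 1, matching the GPUs
# the LocalCUDACluster workers actually use.
y = benchmark(bench, (arr, ), n_repeat=20, devices=(0, 1))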
"This should work, though be wary of the case when you are using more than one GPU in a cluster, make sure to indicate all the devices you intend to run on to the benchmark function."
In the initial run of this project, I will keep to a single GPU. Although I have more GPUs available, I think it is also worthwhile giving pointers for when only a single GPU is available.
As for the fine-grained profiling, Nsight is indeed a good tool. Again, in this first iteration, I am only using the timing functions in order to compute throughput. This is, however, not the main investigation of my project: I am also querying the GPUs to gain insight into their wattage, to see the trade-off between throughput and energy efficiency, to highlight the energy cost of using GPUs for small workloads, and to identify for which workloads the energy-throughput efficiency is worthwhile.
In complete honesty, the only reason I am very interested in the fine-grained timings is to compute the wattage, as W = J / s, which is why the timing aspect is especially important :-)
A last question -- if I altogether avoid Dask-CUDA and implement it using only CuPy:
import cupy as cp
import numpy as np  # needed for np.average below
import sys
from cupyx.profiler import benchmark
from helpers import *
from timeit import default_timer as timer
import time
import cupyx
import math

def generate_array(size, work_units):
    rand = cp.random.default_rng()  # This is the fast way of creating large arrays with cp
    arr = rand.integers(0, 100_000, (size, work_units))  # Create array
    return arr

def multiplication(arr):
    y = cp.multiply(arr, 42)  # Multiply by 42, a randomly chosen number
    return y

def bench(arr):
    r = multiplication(arr)
    return r

if __name__ == '__main__':
    sz = int(sys.argv[1])
    wk = int(sys.argv[2])
    arr = generate_array(sz, wk)
    y = benchmark(bench, (arr, ), n_repeat=20)
    print(np.average(y.gpu_times))
This executes in 0.011635497713088989 ms, whilst the other executes in 0.8 ms. How come there is such a large difference in the GPU times?