dask-cuda
Add a CLI option to make PyTorch use the same memory pool as RAPIDS.
Currently we need to run the code below to make PyTorch use the RMM pool on a dask-cuda cluster. We should be able to do this via a CLI option.
# Making PyTorch use the same memory pool as RAPIDS.
def _set_torch_to_use_rmm():
    """
    This function sets up the PyTorch memory pool to be the same as the RAPIDS memory pool.
    This helps avoid OOM errors when using both PyTorch and RAPIDS on the same GPU.
    See article:
    https://medium.com/rapids-ai/pytorch-rapids-rmm-maximize-the-memory-efficiency-of-your-workflows-f475107ba4d4
    """
    import torch
    from rmm.allocators.torch import rmm_torch_allocator

    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

# Apply in the local (client) process first, then on every worker.
# `client` is an existing dask.distributed Client connected to the cluster.
_set_torch_to_use_rmm()
client.run(_set_torch_to_use_rmm)
@quasiben, wondering if we have an opinion on this? Happy to open a PR here to make life easier for folks like me.
CC: @alexbarghi-nv, @jnke2016 who have seen customer problems around a similar setup.
I have no objections to this. My only suggestion would be to make it a generic, extensible option that lets us specify which libraries should use RMM as their memory manager, something like this:
--set-rmm-allocator=torch,another_future_library,...
Do you think that makes sense? @VibhuJawa, if you want to get started on a PR for this, I'm happy to help address any issues you may find along the way.
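To make the comma-separated option suggested above more concrete, here is a minimal sketch of how its value could be parsed and dispatched. Everything except the PyTorch/RMM calls taken from the snippet in this issue is an assumption for illustration: the registry `RMM_ALLOCATOR_SETTERS` and the helpers `parse_rmm_allocator_option` and `set_rmm_allocators` are hypothetical names, not existing dask-cuda API.

```python
from typing import Callable, Dict, List


def _use_rmm_for_torch() -> None:
    # Same calls as in the snippet from this issue. Imported lazily so that
    # parsing the option does not require torch or rmm to be installed.
    import torch
    from rmm.allocators.torch import rmm_torch_allocator

    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


# Hypothetical registry: library name -> function that points it at RMM.
# Supporting another library would mean adding one entry here.
RMM_ALLOCATOR_SETTERS: Dict[str, Callable[[], None]] = {
    "torch": _use_rmm_for_torch,
}


def parse_rmm_allocator_option(value: str) -> List[str]:
    """Split a comma-separated option value and validate each library name."""
    names = [name.strip().lower() for name in value.split(",") if name.strip()]
    unknown = [name for name in names if name not in RMM_ALLOCATOR_SETTERS]
    if unknown:
        raise ValueError(
            "Unsupported libraries for the RMM allocator option: "
            + ", ".join(unknown)
        )
    return names


def set_rmm_allocators(value: str) -> None:
    """Run the registered setter for every library named in the option value."""
    for name in parse_rmm_allocator_option(value):
        RMM_ALLOCATOR_SETTERS[name]()
```

In this sketch each worker would call `set_rmm_allocators("torch")` once at startup, before PyTorch makes any CUDA allocation. Rejecting unknown names up front means a typo in the option fails loudly instead of silently leaving a library on its default allocator.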