
Add CLI option to enable PyTorch to use the same memory pool as RAPIDS

Open VibhuJawa opened this issue 1 year ago • 2 comments

Currently we need to run the code below to make PyTorch use the RMM memory pool on a dask-cuda cluster. We should be able to do this via a CLI option.

# Making PyTorch use the same memory pool as RAPIDS.
def _set_torch_to_use_rmm():
    """
    This function sets up the pytorch memory pool to be the same as the RAPIDS memory pool.
    This helps avoid OOM errors when using both pytorch and RAPIDS on the same GPU.
    See article:
    https://medium.com/rapids-ai/pytorch-rapids-rmm-maximize-the-memory-efficiency-of-your-workflows-f475107ba4d4
    """
    import torch
    from rmm.allocators.torch import rmm_torch_allocator
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

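# Apply on the client process, then on every worker in the cluster.
# `client` is an existing dask.distributed Client connected to the dask-cuda cluster.
_set_torch_to_use_rmm()
client.run(_set_torch_to_use_rmm)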

VibhuJawa avatar Nov 17 '23 10:11 VibhuJawa

@quasiben, wondering if we have an opinion on this? Happy to open a PR here to make life easier for folks like me.

CC: @alexbarghi-nv, @jnke2016 who have seen customer problems around a similar setup.

VibhuJawa avatar Feb 28 '24 17:02 VibhuJawa

I have no objections to this. My only suggestion would be to make this a generic, extensible option where we can specify which libraries should use RMM as their memory manager, something like this:

--set-rmm-allocator=torch,another_future_library,...

Do you think that makes sense? @VibhuJawa, if you want to get started on a PR for this, I'm happy to help address any issues you may find along the way.

pentschev avatar Feb 28 '24 18:02 pentschev
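
For reference, here is a minimal sketch of how the value of a comma-separated option like the one suggested above might be parsed and dispatched. It is not dask-cuda's implementation: the registry `_ALLOCATOR_SETTERS` and the helpers `parse_rmm_allocator_option` and `set_rmm_allocators` are hypothetical names chosen for illustration, and only the `torch` entry reuses the calls from the snippet in the original post.

```python
from typing import Callable, Dict, List


def _set_torch_allocator() -> None:
    # Same calls as the snippet in the issue; imported lazily so that torch
    # and rmm are only required when "torch" is actually requested.
    import torch
    from rmm.allocators.torch import rmm_torch_allocator

    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)


# Hypothetical registry: library name -> function that points it at RMM.
# Supporting another library would mean adding another entry here.
_ALLOCATOR_SETTERS: Dict[str, Callable[[], None]] = {
    "torch": _set_torch_allocator,
}


def parse_rmm_allocator_option(value: str) -> List[str]:
    """Split a comma-separated option value and validate each library name."""
    names = [name.strip() for name in value.split(",") if name.strip()]
    unknown = [name for name in names if name not in _ALLOCATOR_SETTERS]
    if unknown:
        raise ValueError(f"Unsupported libraries for RMM allocator: {unknown}")
    return names


def set_rmm_allocators(value: str) -> None:
    """Run the registered setter for every library named in the option value."""
    for name in parse_rmm_allocator_option(value):
        _ALLOCATOR_SETTERS[name]()
```

Until a CLI flag exists, a helper like this could presumably be applied the same way as the original workaround, for example with `client.run(set_rmm_allocators, "torch")`. Validating the names before calling any setter means a typo fails early instead of leaving some libraries configured and others not.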