
A notable difference between tasks and actors when it comes to GPU memory deallocation with PyTorch

erezinman opened this issue · 8 comments

Hi, I came across an annoying bug when using tasks with PyTorch. It seems that when a function is executed within an actor, the allocated GPU memory is deallocated properly, but when the same method is executed as a task, the GPU's memory is not cleared.

While the whole code is given below, the difference between calling it as a task and as an actor comes down to the following code:


if not USE_TASK:
    GPUAllocationTester = ray.remote(num_gpus=1)(GPUAllocationTester)

remaining = []
for i in range(N):
    if USE_TASK:
        run = ray.remote(num_gpus=1)(GPUAllocationTester.run).remote(i)
        remaining.append(run)
    else:
        allocator = GPUAllocationTester.remote()
        remaining.append(allocator.run.remote(i))

ray.wait(remaining, num_returns=len(remaining))

Full code

Consider the following code (place it in a file) and run it.
import sys
import os
import time
import traceback

import pandas as pd
import ray
import torch


def get_exception_string(exception: Exception) -> str:
    if hasattr(exception, '__traceback__'):
        exc_lines = traceback.format_exception(type(exception), exception, exception.__traceback__)
    else:
        exc_lines = traceback.format_exception_only(type(exception), exception)

    return str.join('', exc_lines)


class GPUAllocationTester:

    @staticmethod
    def _print(*args, i, **kwargs):
        print(pd.Timestamp.now(), '|', i, '-', *args, **kwargs)

    @staticmethod
    def _print_gpu_memory_usage(phase, i):
        free, total = torch.cuda.mem_get_info()
        gb = 1000000000
        GPUAllocationTester._print(
            f'{phase} - Free GPU #{os.environ["CUDA_VISIBLE_DEVICES"]} memory: '
            f'{free / gb :.2f}GB/{total / gb :.2f}GB ({free / total * 100:.2f}%)', i=i)

    @staticmethod
    def run(i):
        try:
            GPUAllocationTester._print_gpu_memory_usage('START', i=i)
            GPUAllocationTester._run_internal()
            GPUAllocationTester._print_gpu_memory_usage('END', i=i)
        except BaseException as ex:
            GPUAllocationTester._print(get_exception_string(ex), file=sys.stderr, i=i)
            GPUAllocationTester._print_gpu_memory_usage('ERROR', i=i)
            raise

    @staticmethod
    def _run_internal():
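        # Allocate a ~7 GB int8 tensor on the GPU, touch it, and let it go out of scope when the function returns.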
        empty = torch.empty(7, 1000, 1000, 1000, dtype=torch.int8, device='cuda')
        empty[:] = 0
        time.sleep(1)


if __name__ == '__main__':

    USE_TASK = True
    N = 50

    if not USE_TASK:
        GPUAllocationTester = ray.remote(num_gpus=1)(GPUAllocationTester)

    remaining = []
    for i in range(N):

        if USE_TASK:
            run = ray.remote(num_gpus=1)(GPUAllocationTester.run).remote(i)
            remaining.append(run)
        else:
            allocator = GPUAllocationTester.remote()
            remaining.append(allocator.run.remote(i))

    ray.wait(remaining, num_returns=len(remaining))

To see the difference, set USE_TASK = True (tasks) or USE_TASK = False (actors) under if __name__ == '__main__':. The operation performed is simply allocating a 7 GB tensor on a GPU.

I'm using a machine with 3x NVIDIA GTX 1080 Ti GPUs running Ubuntu 20.04, with the latest ray and torch versions (ray==2.1.0, torch==1.13.0), NVIDIA driver v510.85.02, and CUDA v11.6.

Logs:

USE_TASK = True log:

(killed it after a few errors)

2022-11-17 16:40:56,703	INFO worker.py:1528 -- Started a local Ray instance.
(pid=3176528) 
(pid=3176530) 
(pid=3176539) 
(run pid=3177312) 2022-11-17 16:41:00.331670 | 2 - START - Free GPU #2 memory: 11.57GB/11.72GB (98.67%)
(run pid=3177310) 2022-11-17 16:41:00.375058 | 0 - START - Free GPU #0 memory: 11.04GB/11.71GB (94.24%)
(run pid=3177311) 2022-11-17 16:41:00.344489 | 1 - START - Free GPU #1 memory: 11.57GB/11.72GB (98.67%)
(run pid=3177312) 2022-11-17 16:41:02.492959 | 2 - END - Free GPU #2 memory: 4.18GB/11.72GB (35.68%)
(run pid=3177310) 2022-11-17 16:41:02.520987 | 0 - END - Free GPU #0 memory: 3.66GB/11.71GB (31.25%)
(run pid=3177311) 2022-11-17 16:41:02.449458 | 1 - END - Free GPU #1 memory: 4.18GB/11.72GB (35.68%)
(pid=3176529) 
(pid=3176540) 
(pid=3176536) 
(run pid=3177417) 2022-11-17 16:41:04.250240 | 39 - Traceback (most recent call last):
(run pid=3177417)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 38, in run
(run pid=3177417)     GPUAllocationTester._run_internal()
(run pid=3177417)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 47, in _run_internal
(run pid=3177417)     empty = torch.empty(7, 1000, 1000, 1000, dtype=torch.int8, device='cuda')
(run pid=3177417) torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.52 GiB (GPU 0; 10.92 GiB total capacity; 0 bytes already allocated; 3.76 GiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
(run pid=3177417) 
(run pid=3177417) 2022-11-17 16:41:04.228107 | 39 - START - Free GPU #2 memory: 4.04GB/11.72GB (34.47%)
(run pid=3177417) 2022-11-17 16:41:04.263327 | 39 - ERROR - Free GPU #2 memory: 4.04GB/11.72GB (34.47%)
(run pid=3177416) 2022-11-17 16:41:04.343015 | 8 - Traceback (most recent call last):
(run pid=3177416)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 38, in run
(run pid=3177416)     GPUAllocationTester._run_internal()
(run pid=3177416)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 47, in _run_internal
(run pid=3177416)     empty = torch.empty(7, 1000, 1000, 1000, dtype=torch.int8, device='cuda')
(run pid=3177416) torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.52 GiB (GPU 0; 10.92 GiB total capacity; 0 bytes already allocated; 3.76 GiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
(run pid=3177416) 
(run pid=3177418) 2022-11-17 16:41:04.332483 | 20 - Traceback (most recent call last):
(run pid=3177418)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 38, in run
(run pid=3177418)     GPUAllocationTester._run_internal()
(run pid=3177418)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 47, in _run_internal
(run pid=3177418)     empty = torch.empty(7, 1000, 1000, 1000, dtype=torch.int8, device='cuda')
(run pid=3177418) torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.52 GiB (GPU 0; 10.91 GiB total capacity; 0 bytes already allocated; 3.28 GiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
(run pid=3177418) 
(run pid=3177416) 2022-11-17 16:41:04.331889 | 8 - START - Free GPU #1 memory: 4.04GB/11.72GB (34.47%)
(run pid=3177416) 2022-11-17 16:41:04.343163 | 8 - ERROR - Free GPU #1 memory: 4.04GB/11.72GB (34.47%)
(run pid=3177418) 2022-11-17 16:41:04.323478 | 20 - START - Free GPU #0 memory: 3.52GB/11.71GB (30.04%)
(run pid=3177418) 2022-11-17 16:41:04.332858 | 20 - ERROR - Free GPU #0 memory: 3.52GB/11.71GB (30.04%)
(pid=3176545) 
(pid=3176541) 
(pid=3176526) 
(run pid=3177533) 2022-11-17 16:41:06.093183 | 15 - START - Free GPU #0 memory: 3.38GB/11.71GB (28.83%)
(run pid=3177535) 2022-11-17 16:41:06.096234 | 46 - START - Free GPU #1 memory: 3.90GB/11.72GB (33.26%)
(run pid=3177527) 2022-11-17 16:41:06.098243 | 44 - START - Free GPU #2 memory: 3.90GB/11.72GB (33.26%)
(run pid=3177533) 2022-11-17 16:41:06.104002 | 15 - ERROR - Free GPU #0 memory: 3.38GB/11.71GB (28.83%)
(run pid=3177535) 2022-11-17 16:41:06.105713 | 46 - ERROR - Free GPU #1 memory: 3.90GB/11.72GB (33.26%)
(run pid=3177527) 2022-11-17 16:41:06.107896 | 44 - ERROR - Free GPU #2 memory: 3.90GB/11.72GB (33.26%)
(run pid=3177533) 2022-11-17 16:41:06.103863 | 15 - Traceback (most recent call last):
(run pid=3177533)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 38, in run
(run pid=3177533)     GPUAllocationTester._run_internal()
(run pid=3177533)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 47, in _run_internal
(run pid=3177533)     empty = torch.empty(7, 1000, 1000, 1000, dtype=torch.int8, device='cuda')
(run pid=3177533) torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.52 GiB (GPU 0; 10.91 GiB total capacity; 0 bytes already allocated; 3.15 GiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
(run pid=3177533) 
(run pid=3177535) 2022-11-17 16:41:06.105586 | 46 - Traceback (most recent call last):
(run pid=3177535)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 38, in run
(run pid=3177535)     GPUAllocationTester._run_internal()
(run pid=3177535)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 47, in _run_internal
(run pid=3177535)     empty = torch.empty(7, 1000, 1000, 1000, dtype=torch.int8, device='cuda')
(run pid=3177535) torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.52 GiB (GPU 0; 10.92 GiB total capacity; 0 bytes already allocated; 3.63 GiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
(run pid=3177535) 
(run pid=3177527) 2022-11-17 16:41:06.107761 | 44 - Traceback (most recent call last):
(run pid=3177527)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 38, in run
(run pid=3177527)     GPUAllocationTester._run_internal()
(run pid=3177527)   File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 47, in _run_internal
(run pid=3177527)     empty = torch.empty(7, 1000, 1000, 1000, dtype=torch.int8, device='cuda')
(run pid=3177527) torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.52 GiB (GPU 0; 10.92 GiB total capacity; 0 bytes already allocated; 3.63 GiB free; 0 bytes reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
(run pid=3177527) 

Process finished with exit code 137 (interrupted by signal 9: SIGKILL)

USE_TASK = False log:

(Killed it after a few successful runs)

2022-11-17 16:52:33,912	INFO worker.py:1528 -- Started a local Ray instance.
(pid=3179597) 
(pid=3179617) 
(pid=3179615) 
(GPUAllocationTester pid=3180377) 2022-11-17 16:52:36.981583 | 0 - START - Free GPU #0 memory: 11.01GB/11.71GB (94.03%)
(GPUAllocationTester pid=3180381) 2022-11-17 16:52:37.191638 | 1 - START - Free GPU #1 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180382) 2022-11-17 16:52:37.176774 | 2 - START - Free GPU #2 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180377) 2022-11-17 16:52:39.221093 | 0 - END - Free GPU #0 memory: 3.64GB/11.71GB (31.04%)
(GPUAllocationTester pid=3180381) 2022-11-17 16:52:39.303353 | 1 - END - Free GPU #1 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180382) 2022-11-17 16:52:39.301855 | 2 - END - Free GPU #2 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180487) 2022-11-17 16:52:40.893257 | 3 - START - Free GPU #0 memory: 11.02GB/11.71GB (94.08%)
(GPUAllocationTester pid=3180490) 2022-11-17 16:52:40.921452 | 5 - START - Free GPU #1 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180489) 2022-11-17 16:52:41.067052 | 4 - START - Free GPU #2 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180487) 2022-11-17 16:52:43.125802 | 3 - END - Free GPU #0 memory: 3.63GB/11.71GB (30.99%)
(GPUAllocationTester pid=3180490) 2022-11-17 16:52:43.119907 | 5 - END - Free GPU #1 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180489) 2022-11-17 16:52:43.148156 | 4 - END - Free GPU #2 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180588) 2022-11-17 16:52:44.824279 | 6 - START - Free GPU #1 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180589) 2022-11-17 16:52:44.791533 | 7 - START - Free GPU #0 memory: 11.03GB/11.71GB (94.16%)
(GPUAllocationTester pid=3180590) 2022-11-17 16:52:44.838491 | 8 - START - Free GPU #2 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180589) 2022-11-17 16:52:46.904957 | 7 - END - Free GPU #0 memory: 3.64GB/11.71GB (31.11%)
(GPUAllocationTester pid=3180590) 2022-11-17 16:52:46.884955 | 8 - END - Free GPU #2 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180588) 2022-11-17 16:52:46.920967 | 6 - END - Free GPU #1 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180692) 2022-11-17 16:52:48.566322 | 11 - START - Free GPU #1 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180687) 2022-11-17 16:52:48.590687 | 9 - START - Free GPU #2 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180688) 2022-11-17 16:52:48.590542 | 10 - START - Free GPU #0 memory: 11.02GB/11.71GB (94.10%)
(GPUAllocationTester pid=3180688) 2022-11-17 16:52:50.689763 | 10 - END - Free GPU #0 memory: 3.64GB/11.71GB (31.11%)
(GPUAllocationTester pid=3180692) 2022-11-17 16:52:50.830381 | 11 - END - Free GPU #1 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180687) 2022-11-17 16:52:50.824641 | 9 - END - Free GPU #2 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180792) 2022-11-17 16:52:52.359383 | 12 - START - Free GPU #0 memory: 11.03GB/11.71GB (94.16%)
(GPUAllocationTester pid=3180795) 2022-11-17 16:52:52.510010 | 13 - START - Free GPU #2 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180796) 2022-11-17 16:52:52.520778 | 14 - START - Free GPU #1 memory: 11.57GB/11.72GB (98.67%)
(GPUAllocationTester pid=3180792) 2022-11-17 16:52:54.524954 | 12 - END - Free GPU #0 memory: 3.64GB/11.71GB (31.06%)
(GPUAllocationTester pid=3180795) 2022-11-17 16:52:54.600953 | 13 - END - Free GPU #2 memory: 4.18GB/11.72GB (35.68%)
(GPUAllocationTester pid=3180796) 2022-11-17 16:52:54.576958 | 14 - END - Free GPU #1 memory: 4.18GB/11.72GB (35.68%)
Traceback (most recent call last):
  File "/home/razor/PycharmProjects/ood/testestsetsetse.py", line 70, in <module>
    ray.wait(remaining, num_returns=len(remaining))
  File "/home/razor/anaconda3/envs/dev/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/home/razor/anaconda3/envs/dev/lib/python3.9/site-packages/ray/_private/worker.py", line 2476, in wait
    ready_ids, remaining_ids = worker.core_worker.wait(
  File "python/ray/_raylet.pyx", line 1650, in ray._raylet.CoreWorker.wait
  File "python/ray/_raylet.pyx", line 190, in ray._raylet.check_status
KeyboardInterrupt

Process finished with exit code 137 (interrupted by signal 9: SIGKILL)

EDIT: I now see that a new process is opened for every worker I use, and that's fine with me as long as memory is cleared. However, this is still unwanted behavior.

erezinman · Nov 17 '22

Hey @erezinman, I think this behavior is fixed in Ray nightly, could you verify?

This was the tracking issue: https://github.com/ray-project/ray/issues/29624

After the patch, Ray will destroy the task worker for GPU tasks by default after it runs, preventing such leaks.

ericl · Nov 17 '22

I haven't tried whether the nightly fixes it, but I presume it would. I noticed that even using a "dummy actor" works around it:

@ray.remote(num_gpus=1)
class Runner:
    def run(self, f, *args):
        return f(*args)
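
For illustration, a minimal sketch of how such a dummy actor might be wired into the reproduction above (the call pattern is assumed, not taken from the thread): each run executes inside a short-lived actor whose worker process Ray tears down once the handle is released, taking the GPU memory with it.

remaining = []
for i in range(N):
    runner = Runner.remote()  # fresh actor per iteration, reserving one GPU
    remaining.append(runner.run.remote(GPUAllocationTester.run, i))
ray.wait(remaining, num_returns=len(remaining))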

erezinman · Nov 18 '22

Yup, that's the same issue then. For actors, we always destroy their worker processes after they are killed, so GPU memory leaks aren't an issue. Previously, GPU tasks re-used workers (the default for CPU tasks); with that change, GPU tasks now default to @ray.remote(max_calls=1), so their worker processes are also destroyed after each task finishes.
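
For completeness, a minimal sketch of requesting that behavior explicitly on a Ray version without the new default (the wrapper function name is illustrative, not from the thread):

# max_calls=1 limits the worker to a single task, so its process (and the
# CUDA context it holds) is destroyed after every call.
@ray.remote(num_gpus=1, max_calls=1)
def run_once(i):
    return GPUAllocationTester.run(i)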

ericl · Nov 18 '22

@ericl I am running version 2.2, and when using ray.tune it seems the worker processes on the worker nodes are not killed and memory is not deallocated. That is, when I stop the main script and restart with "resume=True", the resumed run on the worker nodes reuses the same processes, the GPU is fully allocated, and I get memory errors. This did not happen with an older version (1.10, I think).

thoglu · Jan 01 '23

Hmm, that shouldn't happen across runs. Do you have a reproduction script? When I try to run / cancel / resume a job locally, I see trials using new PIDs as expected:

from ray import tune
import ray
import os

NUM_MODELS = 100

def train_model(config):
    print("Training model")
    import time
    time.sleep(100)
    score = 0  # placeholder metric so the example can return a result
    return {"score": score, "other_data": ...}

# Can customize resources per trial, here we set 1 CPU each.
train_model = tune.with_resources(train_model, {"cpu": 1})

tuner = tune.Tuner.restore(
    path="~/ray_results/train_model_2023-01-03_12-38-05"
)
#tuner = tune.Tuner(train_model)
results = tuner.fit()
+-------------------------+----------+-----------------------+
| Trial name              | status   | loc                   |
|-------------------------+----------+-----------------------|
| train_model_8663c_00000 | RUNNING  | 10.103.244.198:148975 |
+-------------------------+----------+-----------------------+
+-------------------------+----------+-----------------------+
| Trial name              | status   | loc                   |
|-------------------------+----------+-----------------------|
| train_model_8663c_00000 | RUNNING  | 10.103.244.198:149167 |
+-------------------------+----------+-----------------------+

Also, make sure you don't have reuse_actors set to True in the Tune config, though this should only affect trials within the same Tune run.
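
For reference, a minimal sketch of making that explicit with the Ray 2.x Tuner API (reusing the train_model trainable from the snippet above; the TuneConfig fields are assumed from that API):

from ray import tune

# Explicitly disable actor reuse so every trial gets a fresh worker process.
tuner = tune.Tuner(
    train_model,
    tune_config=tune.TuneConfig(reuse_actors=False),
)
results = tuner.fit()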

ericl · Jan 03 '23

I have written a post on the forums (https://discuss.ray.io/t/gpu-memory-is-not-freed-on-cluster-after-ctrl-c-can-i-respond-to-specific-errors-from-within-a-client-node/8835) that is related. Maybe the issue is how I set up my cluster? I do not run locally: I start ray on the head node via the CLI (ray start with options) and on the worker nodes also via the CLI (ray start ...). On the head node I then run the python script, call ray.init(), and start the tuner similarly to your script. Then ctrl+x, resume, and in the logs I see the same PID again and again causing memory errors. Using something like torch.cuda.list_gpu_processes, it is always the same ID and memory on the card at the beginning of each resume. I have set reuse_actors=False explicitly, but that does not change it.

thoglu · Jan 04 '23

I see. It's possible the GPU processes are leaked. Do you see those processes hanging around in the cluster after the ctrl+x (c) but before the second run? Do they go away after waiting for a minute?

Btw, it would be helpful to report the output of ray status during the first run, and then after the ctrl+c of the first run.

ericl · Jan 04 '23

I started it again.

ray status

yields

======== Autoscaler status: 2023-01-04 14:18:31.968221 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_90a4e5e792e7a584e78fbdcd53f7f2abe2350e653f0d2b219265c892
 1 node_a9a07f99061855c976a9f37194d3c8fb847c0364bf204e2ac9e07b37
 1 node_6194a7ac8d34bdccf408ddb625ddc997488aeffcf370b8307437d595
 1 node_46794be6eba5371181ed88a3956006129912129033bba23edb1b88d8
 1 node_5d967084754cc90b3bc40f5f11b1d7287f3779b458585244462b24cc
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 16.0/20.0 CPU (16.0 used of 16.0 reserved in placement groups)
 4.0/4.0 GPU (4.0 used of 4.0 reserved in placement groups)
 0.0/1.0 accelerator_type:A100
 0.0/4.0 accelerator_type:G
 0.00/1170.408 GiB memory
 0.00/505.595 GiB object_store_memory

Demands:
 {'GPU': 1.0, 'CPU': 4.0} * 1 (PACK): 16+ pending placement groups

Now I do ctrl+x, and wait a few minutes.

======== Autoscaler status: 2023-01-04 14:36:26.906442 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_90a4e5e792e7a584e78fbdcd53f7f2abe2350e653f0d2b219265c892
 1 node_a9a07f99061855c976a9f37194d3c8fb847c0364bf204e2ac9e07b37
 1 node_6194a7ac8d34bdccf408ddb625ddc997488aeffcf370b8307437d595
 1 node_46794be6eba5371181ed88a3956006129912129033bba23edb1b88d8
 1 node_5d967084754cc90b3bc40f5f11b1d7287f3779b458585244462b24cc
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/20.0 CPU
 0.0/4.0 GPU
 0.0/1.0 accelerator_type:A100
 0.0/4.0 accelerator_type:G
 0.00/1170.408 GiB memory
 0.00/505.595 GiB object_store_memory

Demands:
 (no resource demands)

I restart the job and get CUDA errors. While it is running with errors, ray status again yields:

======== Autoscaler status: 2023-01-04 14:38:07.465586 ========
Node status
---------------------------------------------------------------
Healthy:
 1 node_90a4e5e792e7a584e78fbdcd53f7f2abe2350e653f0d2b219265c892
 1 node_a9a07f99061855c976a9f37194d3c8fb847c0364bf204e2ac9e07b37
 1 node_6194a7ac8d34bdccf408ddb625ddc997488aeffcf370b8307437d595
 1 node_46794be6eba5371181ed88a3956006129912129033bba23edb1b88d8
 1 node_5d967084754cc90b3bc40f5f11b1d7287f3779b458585244462b24cc
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 16.0/20.0 CPU (16.0 used of 16.0 reserved in placement groups)
 4.0/4.0 GPU (4.0 used of 4.0 reserved in placement groups)
 0.0/1.0 accelerator_type:A100
 0.0/4.0 accelerator_type:G
 0.00/1170.408 GiB memory
 0.01/505.595 GiB object_store_memory

Demands:
 {'GPU': 1.0, 'CPU': 4.0} * 1 (PACK): 16+ pending placement groups

Note that there are 5 nodes (the head node and 4 worker nodes). Only the worker nodes have --num_gpus=1, while the head node uses --num_gpus=0; I do not want to run on the head node, even though it has a graphics card.

thoglu · Jan 04 '23

So it looks like Ray is releasing the resources / processes on its side.

The processes might still be getting stuck on shutdown for some other reason, though.

Could you also provide the output of ps aux | grep ray on each node before and after the initial ctrl-c? (This might be easier on a cluster with just a single 1-GPU node.)

ericl · Jan 04 '23

Can you say what you are looking for? The output is pretty long and I don't want to post all of it here.

thoglu · Jan 04 '23

I'm particularly wondering whether those processes using the GPU still remain in the ps aux list after ray status reports that all GPUs have been released, and also what their process state is.

Perhaps you can post just the row for those original actor processes before/after stopping the job?

ericl · Jan 04 '23

Yes, I see a bunch of ImplicitFunc:Train and ImplicitFunc:IDLE processes, and those actually do not get killed by ctrl+x. In particular, I end up with more and more of them remaining after a bunch of ctrl+x and resumes.

230106  9.2  1.4 138192236 5558324 pts/6 SNl 22:47   0:14 ray::ImplicitFunc.train
230162  9.2  1.4 138192248 5562576 pts/6 SNl 22:47   0:14 ray::ImplicitFunc.train
.... etc

Furthermore, a call to nvidia-smi shows there is actually still reserved memory on one graphics card (the upper one in the output), even though no train run is actively running. I think that comes from those train/IDLE ray processes.

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:3B:00.0 Off |                  N/A |
| 29%   37C    P2    71W / 250W |   3845MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:5E:00.0 Off |                  N/A |
| 20%   22C    P8     9W / 250W |      0MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

When I start yet another run with resume, nvidia-smi looks like this:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:3B:00.0 Off |                  N/A |
| 29%   38C    P2    71W / 250W |   4694MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:5E:00.0 Off |                  N/A |
| 20%   22C    P8     9W / 250W |      0MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    231255      C   ray::ImplicitFunc.train           849MiB |
+-----------------------------------------------------------------------------+

Notice how, below under Processes, there is a new ImplicitFunc.train process with 849 MiB, while on the upper graphics card much more is actually reserved (from the other, previous ray processes, I think). After killing all ray processes with ray stop, all of this graphics memory is released, but simply killing the current tune run with ctrl+x does not do it.

thoglu · Jan 04 '23

Ok, so it seems Ray isn't properly terminating those processes (without a manual cluster stop). The remaining challenge is figuring out why.

It would be helpful to put together a minimal repro script; I have not been able to reproduce the same leak in my cluster, unfortunately. Does running something like the following (ctrl-c during the sleep) cause the process leak for you? I'm wondering if some specific thing triggers it.

import ray
import time
import torch

@ray.remote(num_gpus=1)
class Actor:
    def init_gpu(self):
        print("Claiming GPU")
        torch.tensor(1000000).cuda()

a = Actor.remote()
a.init_gpu.remote()
time.sleep(60)
print("exiting normally")

ericl · Jan 04 '23

Btw, it may be helpful to file a new issue so we can bump the priority; this bug is a bit unrelated to the original issue topic.

ericl · Jan 04 '23

Ok, I started #31451. By the way, I tried to run your script above, but I get this error:

2023-01-05 00:14:22,528	INFO worker.py:1529 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8266 
[2023-01-05 00:14:52,747 E 236279 236279] core_worker.cc:179: Failed to register worker 01000000ffffffffffffffffffffffffffffffffffffffffffffffff to Raylet. IOError: [RayletClient] Unable to register worker with raylet. No such file or directory

Why does it even start a new local ray instance? Shouldn't the script connect to the main ray process started with ray start prior to running the script? I have no experience with ray core; I usually just use tune.

thoglu · Jan 04 '23

Thanks!

It should connect automatically to the most recently started Ray cluster (at least in recent Ray versions). You can force the connection by adding ray.init("auto") to the start of the script.
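
A minimal sketch of what that looks like at the top of the test script (assuming a cluster already started via ray start):

import ray

# Connect to the running cluster instead of starting a new local Ray instance.
ray.init(address="auto")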

The script is emulating launching a single actor tune trial that uses a GPU (Tune is just launching actors under the hood in the same way).

ericl · Jan 04 '23

Going to close this issue so we can track in https://github.com/ray-project/ray/issues/31451

ericl · Jan 04 '23