
[DOC] Need Best Practices for BinningMemoryResource

Open zjjott opened this issue 5 months ago • 3 comments

Report needed documentation

Memory resource objects

Steps taken to search for needed documentation: the documentation does not show how to use BinningMemoryResource to reach the best performance described in: https://developer.nvidia.com/blog/fast-flexible-allocation-for-cuda-with-rapids-memory-manager/#rmm_performance

I wrote a benchmark script based on python/rmm/rmm/tests/test_rmm.py, but I can't reach maximum performance and the lowest memory usage. Should I use CudaMemoryResource or ManagedMemoryResource? Why does PoolMemoryResource run out of memory (OOM) when only 70GB is allocated? The script:

import torch
import random
from random import sample
import time
import tqdm
import numpy as np
from collections import Counter
def malloc_tensor(size,device):
    return torch.empty(size,device=device,dtype=torch.float32)
def use_allocator():
    from rmm.allocators.torch import rmm_torch_allocator
    import rmm
    # https://developer.nvidia.com/blog/fast-flexible-allocation-for-cuda-with-rapids-memory-manager/
    # using binning memory resource
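    # NOTE: this first upstream is immediately replaced by the pool below.
    upstream = rmm.mr.CudaMemoryResource()
    upstream = rmm.mr.PoolMemoryResource(
            rmm.mr.ManagedMemoryResource(),
            1<<20,  # initial pool size: 1MiB
            "80GiB",  # maximum pool size; I found 80GB will OOM, 90GB will slow down
            #rmm.mr.c_percent_of_free_device_memory(90)
        )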
    fixed_mr = rmm.mr.FixedSizeMemoryResource(upstream, 1 << 10)


    # Add fixed-size bins 256KiB, 512KiB, 1MiB, 2MiB, 4MiB
    mr = rmm.mr.BinningMemoryResource(upstream, 18, 22)
    mr.add_bin(1 << 10, fixed_mr)  # 1KiB bin
    cuda_mr = rmm.mr.CudaAsyncMemoryResource()
    #mr.add_bin(1 << 23, cuda_mr)  # 8MiB bin
    rmm.mr.set_current_device_resource(mr)
    print(rmm.mr.get_current_device_resource_type())
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

def main():
    """
    Warm up by allocating random sizes (100 MB to 2000 MB) until 70 GB is reached.
    Then repeatedly free down to 60 GB and reallocate (10 MB to 2000 MB) back up to 70 GB.
    """
    use_allocator()
    import rmm
    rmm.statistics.enable_statistics()
    MB = 1024 * 1024
    GB = 1024 * 1024 * 1024
    total_size = 70 * GB
    current_size = 0
    tensors = []
    # warmup to 70GB

    while current_size < total_size:
        # float32: 4B
        size = (random.randint(100 * MB,2000 * MB))
        tensor = malloc_tensor(size//4, torch.cuda.current_device())
        tensors.append(tensor)
        current_size += size
    array = []
    malloc_counter = Counter()
    last_malloc_size = 0
    start_time = time.time()
    for i in tqdm.tqdm(range(500)):
        if i%10==9:
            # Report allocation bandwidth over the window since the last report,
            # not over the cumulative totals.
            delta_t = time.time() - start_time
            malloc_size = malloc_counter["malloc_size"] - last_malloc_size
            bandwidth = malloc_size / delta_t / GB
            print(f"malloc bandwidth: {bandwidth}GB/s")
            array.append(bandwidth)
            last_malloc_size = malloc_counter["malloc_size"]
            start_time = time.time()

        random.shuffle(tensors)
        while(current_size > 60*GB and tensors):
            tensor_to_delete = tensors.pop()
            malloc_counter["free_size"]+=tensor_to_delete.numel()*4
            current_size -= (tensor_to_delete.numel()*4)
            del tensor_to_delete
            malloc_counter["free"]+=1
        while current_size < 70*GB:
            size = (random.randint(10 * MB,2000 * MB))
            tensor = malloc_tensor(size//4, torch.cuda.current_device())
            tensor.mul_(2.0)
            current_size += size
            tensors.append(tensor)
            malloc_counter["malloc"]+=1
            malloc_counter["malloc_size"]+=size
    array = np.asarray(array)
    print(f"mean: {array.mean()}GB/s, std: {array.std()}GB/s max:{array.max()}GB/s")
    report = rmm.statistics.default_profiler_records.report()
    print(report)
if __name__ == "__main__":
    main()

zjjott avatar Jun 13 '25 05:06 zjjott

I have tested some cases on an H800. Without rmm (not calling use_allocator()), PyTorch with its caching allocator gives: mean 55.5GB/s, std 40GB/s, max 325.00GB/s. With rmm, run as: PYTORCH_NO_CUDA_MEMORY_CACHING=1 python test_malloc.py

  1. default: CudaMemoryResource(): mean 100.71GB/s, std 9.54GB/s, max 143.65GB/s
  2. CudaAsyncMemoryResource(): OOM after 7 iterations
  3. BinningMemoryResource(PoolMemoryResource(CudaMemoryResource(), 78GB)): OOM after 4 iterations
  4. BinningMemoryResource(PoolMemoryResource(ManagedMemoryResource(), 100GB))

zjjott avatar Jun 14 '25 04:06 zjjott

@bdice Can you give some advice?

zjjott avatar Jun 18 '25 05:06 zjjott

Typically a BinningMemoryResource is used for making small allocations and large allocations with different MRs. For example, some applications might want to make small allocations in pinned memory that can be quickly moved between host and device, while large allocations should be in pageable memory.
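To make the small-versus-large split concrete, here is a minimal pure-Python sketch of size-based routing. It is a toy model, not RMM's implementation: the helper names make_bins and route are hypothetical, and it assumes a request goes to the smallest bin that fits and otherwise falls through to the upstream resource. The bin sizes mirror the script above (powers of two from 2^18 to 2^22 plus a 1KiB bin).

```python
# Toy model of size-based routing; this is NOT the RMM implementation.
from bisect import bisect_left

KiB = 1 << 10
MiB = 1 << 20


def make_bins(min_exp, max_exp, extra=()):
    """Return sorted bin sizes: powers of two 2**min_exp..2**max_exp plus extras."""
    sizes = {1 << e for e in range(min_exp, max_exp + 1)}
    sizes.update(extra)
    return sorted(sizes)


def route(size, bins):
    """Return the smallest bin size that fits `size`, or "upstream" if none fits."""
    idx = bisect_left(bins, size)
    return bins[idx] if idx < len(bins) else "upstream"


# Bins as in the script above: 256KiB..4MiB plus one 1KiB bin.
bins = make_bins(18, 22, extra=[1 * KiB])
for request in (512, 300 * KiB, 4 * MiB, 100 * MiB, 2000 * MiB):
    print(request, "->", route(request, bins))
```

Under this model, every request of 100 MB or more is larger than the largest bin (4MiB), so it bypasses the bins entirely and lands on the upstream resource.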

It looks like all of your allocations are "large", ranging from 100 MB to 2 GB. I would not recommend a binning memory resource here; you should probably just pick one memory resource to use. Try CudaAsyncMemoryResource() and a pool initialized with 80% of the free memory, like this:

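import rmm

# Query free device memory in bytes (the second value, total memory, is unused).
free_memory, _ = rmm.mr.available_device_memory()
# Take 80% of it, rounded to a multiple of 256 bytes.
free_memory = int(round(float(free_memory) * 0.80 / 256) * 256)

mr = rmm.mr.PoolMemoryResource(rmm.mr.ManagedMemoryResource(), initial_pool_size=free_memory)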
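The size arithmetic in that snippet can be checked without a GPU. The sketch below wraps the same expression in a hypothetical helper, pool_size, and applies it to an assumed free-memory value of 80GiB. On a real system the value would come from rmm.mr.available_device_memory(). The rounding to a multiple of 256 is presumably there because the pool expects sizes aligned to 256 bytes.

```python
def pool_size(free_bytes, fraction=0.80, alignment=256):
    """Scale free_bytes by fraction, rounded to the nearest multiple of alignment."""
    return int(round(float(free_bytes) * fraction / alignment) * alignment)


GiB = 1 << 30
free_bytes = 80 * GiB  # assumed value for illustration only
size = pool_size(free_bytes)
print(size, size % 256 == 0, size / GiB)
```

With 80GiB free, this yields a pool of 64GiB, which leaves headroom for allocations made outside the pool.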

bdice avatar Jun 18 '25 15:06 bdice