
GPU memory degradation despite offloading to CPU (on 50+ LoRAs for Stable Diffusion)

Open iuliaturc opened this issue 10 months ago • 22 comments

System Info

accelerate==0.28.0 diffusers==0.27.0 peft==0.9.0 transformers==4.38.2

Amazon EC2 instance (g5.2xlarge) Deep Learning OSS Nvidia Driver AMI GPU PyTorch 2.0.1 (Amazon Linux 2)

Who can help?

@pacman100 @younesbelkada

Information

  • [ ] The official example scripts
  • [X] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder
  • [ ] My own task or dataset (give details below)

Reproduction

Problem Setup

I am using the peft library to run SDXL + LoRA inference. We aim to support thousands of LoRAs. Here is the high-level algorithm that we are using on each query.

At model setup time, we load an SDXL model:

import torch
from diffusers import StableDiffusionXLPipeline

t2i_pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16").to("cuda")

Then, for each query:

Step 1: Load the adapter onto the GPU, either from disk or from CPU memory.

ensure_lora_is_on_gpu(request.adapter)
t2i_pipe.set_adapters(request.adapter)
assert len(t2i_pipe.get_active_adapters()) == 1

where ensure_lora_is_on_gpu is implemented like so:

def ensure_lora_is_on_gpu(lora_name: str):
    global t2i_pipe

    loaded_loras = t2i_pipe.get_list_adapters().get("unet", [])
    if lora_name in loaded_loras:
        t2i_pipe.set_lora_device([lora_name], "cuda")
    else:
        t2i_pipe.load_lora_weights(os.path.join(LORA_MODELS_PATH, lora_name),
                                   weight_name="lora.safetensors",
                                   local_files_only=True,
                                   adapter_name=lora_name)
        print(f"Successfully loaded LoRA {lora_name}. {len(loaded_loras) + 1} LoRAs are currently in memory.")

Step 2: Run inference

images = t2i_pipe(...)

Step 3: Offload the adapter to CPU

t2i_pipe.set_lora_device([request.adapter], "cpu")
t2i_pipe.disable_lora()
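Steps 1-3 can be wrapped so that the Step 3 offload runs even when inference raises (for example, on an out-of-memory error). This does not fix the memory growth discussed below; it only guarantees the offload is always attempted. The sketch is hypothetical: the helper name `lora_session`, the `FakePipe` stand-in and `fake_loader` are not diffusers API. It calls only the pipeline methods already used above (`set_adapters`, `set_lora_device`, `disable_lora`) and takes the loader as a parameter with the `(pipe, lora_name)` signature used by `ensure_lora_is_on_gpu` in the repro script later in this thread.

```python
from contextlib import contextmanager


@contextmanager
def lora_session(pipe, adapter, ensure_loaded):
    """Hypothetical helper: Steps 1-3 above, with Step 3 guaranteed to run."""
    ensure_loaded(pipe, adapter)                # Step 1: load from disk or move CPU -> GPU
    pipe.set_adapters(adapter)
    try:
        yield pipe                              # Step 2: the caller runs inference here
    finally:
        pipe.set_lora_device([adapter], "cpu")  # Step 3: offload, even if inference failed
        pipe.disable_lora()


class FakePipe:
    """Stand-in that records calls, so the control flow can be checked without a GPU."""

    def __init__(self):
        self.calls = []

    def set_adapters(self, name):
        self.calls.append(("set_adapters", name))

    def set_lora_device(self, names, device):
        self.calls.append(("set_lora_device", tuple(names), device))

    def disable_lora(self):
        self.calls.append(("disable_lora",))


def fake_loader(pipe, name):
    pipe.calls.append(("load", name))


pipe = FakePipe()
try:
    with lora_session(pipe, "lora0", fake_loader):
        raise RuntimeError("simulated failure during inference")
except RuntimeError:
    pass  # the exception still propagates, but only after the offload ran
```

With the real pipeline, the body of the `with` block would be the `images = t2i_pipe(...)` call from Step 2.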

Problem Symptoms

To test our setup, we make ~60 sequential calls to the model, each with a different LoRA. In this log file (cpu_offloading_sagemaker_calls.txt), you can see the latency gradually degrading from ~20s/query to ~43s/query.

Looking at the GPU memory utilization, you can see it gradually getting close to 100% (which must be the cause of degradation in latency): cpu_offloading_gpu_memory_utilization

This is surprising to me, since we are always offloading the LoRA onto CPU after running inference.

Debugging attempt:

Instead of offloading to CPU in Step 3, we unloaded after every inference:

Step 3: Completely unload LoRA

t2i_pipe.unload_lora_weights()

Repeating the same experiment as above (60 sequential requests with different LoRAs), we see that this time latency stays constant across calls: unload_loras_sagemaker_calls.txt

And GPU memory consumption stays constant, as expected: unload_loras_gpu_memory_utilization

Expected behavior

I would expect GPU memory utilization to stay constant when we offload LoRAs to the CPU (similarly to what happens when we unload them completely). It seems like our Step 3 (t2i_pipe.set_lora_device([request.adapter], "cpu")) has virtually no effect on GPU memory consumption.

Hypothesis 1:

We are missing something obvious.

Hypothesis 2 (debunked I think):

Maybe the GPU keeps a copy of the tensor (in addition to the one that is copied to the CPU)?

I tried modifying this line to explicitly delete the GPU tensor, but it did not help.

Before:

unet_module.lora_A[adapter_name].to(device)

After (doesn't help):

gpu_tensor = unet_module.lora_A[adapter_name]
gpu_tensor.to("cpu")
del gpu_tensor
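Why the `del` could not help: in Python, `del gpu_tensor` only removes the local name, and the object stays alive as long as the owning container (`unet_module.lora_A`) still references it. Also, for an `nn.Module`, `.to()` moves the parameters in place, so the original line already relocates the weights rather than producing a second copy. A minimal, framework-free sketch of the reference semantics, where `FakeLoraLayer` is a hypothetical stand-in for the real LoRA sub-module:

```python
import gc
import weakref


class FakeLoraLayer:
    """Hypothetical stand-in for one LoRA sub-module such as lora_A[adapter_name]."""

    def __init__(self, name):
        self.name = name


# Dict playing the role of the ModuleDict `unet_module.lora_A`, which owns the layer.
lora_A = {"lora0": FakeLoraLayer("lora0")}

gpu_tensor = lora_A["lora0"]     # a second name for the same object, as in the attempt above
probe = weakref.ref(gpu_tensor)  # observes the object without keeping it alive
del gpu_tensor                   # removes only the local name
gc.collect()
alive_after_del = probe() is not None  # True: lora_A still owns the object

del lora_A["lora0"]              # drop the owning reference
gc.collect()
alive_after_owner_removed = probe() is not None  # False: nothing references it any more
```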

Your help would be much appreciated!

iuliaturc, Apr 11 '24

Thanks for reporting this. I'll also ping @sayakpaul since this involves diffusers code as well.

With the given information, it's hard for me to tell what the source of the issue could be. Would it be possible to share a full reproducer?

Let's first focus on the GPU usage: What looks strange to me is that in your first graph, it continually increases, but then suddenly there is a jump down. Did you do anything special to cause this? If not, my suspicion is that GPU memory is not immediately freed when you offload the LoRA adapters, and that there is garbage collection when PyTorch (?) detects the need for it. If you're not already doing it, could you check if manually triggering garbage collection helps:

import gc
import torch

gc.collect()
torch.cuda.empty_cache()

Since unload_lora_weights goes through a completely different code path, the garbage collection heuristic could be triggered there automatically, explaining the flat memory curve.
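Two mechanisms can make GPU memory lag behind the offload. This sketch assumes, without confirmation from this thread, that something in the pipeline forms reference cycles. First, objects caught in a reference cycle are only reclaimed by Python's cyclic garbage collector, which `gc.collect()` runs immediately. Second, even freed CUDA tensors stay in PyTorch's caching allocator, and are therefore still counted by `nvidia-smi`, until `torch.cuda.empty_cache()` returns the cached blocks to the driver. The cycle part can be shown without a GPU; `Holder` is a hypothetical class standing in for an object that would own GPU tensors:

```python
import gc
import weakref


class Holder:
    """Hypothetical object that would own GPU tensors in the real pipeline."""


gc.disable()  # keep the automatic collector from running in the middle of the demo
try:
    a, b = Holder(), Holder()
    a.peer, b.peer = b, a        # reference cycle: a -> b -> a
    probe = weakref.ref(a)
    del a, b                     # no names left, but each object still references the other
    alive_before_collect = probe() is not None
    gc.collect()                 # an explicit collection finds and frees the unreachable cycle
    alive_after_collect = probe() is not None
finally:
    gc.enable()
```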

> Looking at the GPU memory utilization, you can see it gradually getting close to 100% (which must be the cause of degradation in latency):

Honestly, I'm not sure if this really follows. If this was the case, I would expect the latency to move in parallel to the GPU utilization, when your data shows instead that it increases monotonically:

[Image: per-query latency plot]

> Instead of offloading to CPU in Step 3, we unloaded after every inference:
>
> Step 3: Completely unload LoRA
>
> t2i_pipe.unload_lora_weights()
>
> Repeating the same experiment as above (60 sequential requests with different LoRAs), we see that this time latency stays constant across calls:

Does this method have any disadvantage or can you just switch over to it?

BenjaminBossan, Apr 11 '24

Thanks @BenjaminBossan for looking into this.

> Let's first focus on the GPU usage: What looks strange to me is that in your first graph, it continually increases, but then suddenly there is a jump down. Did you do anything special to cause this? If not, my suspicion is that GPU memory is not immediately freed when you offload the LoRA adapters, and that there is garbage collection when PyTorch (?) detects the need for it.

I'm not doing anything special to cause those jumps down. I also suspected it's garbage collection kicking in. I tried to address it with the explicit del gpu_tensor hack mentioned above (which didn't have any effect). However, your proposal (to call garbage collection explicitly after every call) did have an effect on GPU memory usage. It's still increasing, but up to 80% instead of 100%. And perhaps now it's a bit more linear and visibly correlated with the increase in latency. [Screenshot 2024-04-11 at 11 26 16 AM]

I still wouldn't expect the memory usage to grow linearly, but rather stay relatively flat (the way it happens when I completely unload the model).

Annoyingly, the latency still increases steadily: with-garbage-collection.txt

> Honestly, I'm not sure if this really follows. If this was the case, I would expect the latency to move in parallel to the GPU utilization, when your data shows instead that it increases monotonically

I agree there's a chance they are independent. The alternative explanation I have is that somehow, by mistake, maybe by misusing the API, I am actually stacking LoRAs (so at the end the last call is basically running 60 LoRAs). I tried to make sure that's not the case via assert len(t2i_pipe.get_active_adapters()) == 1.

> Does this method have any disadvantage or can you just switch over to it?

For now unloading is a workable solution, but loading from disk takes ~3-4 seconds. If swapping between CPU and GPU worked as expected, I could reduce a 20s latency to 16s, which is quite significant.

If you don't have any other suggestions that I could try, I'll work on a minimal script to reproduce these measurements. It'd be a pity to not get to the bottom of this. Something's not quite right -- either the API is a bit confusing and allowing me to do the wrong thing, or there's some memory leak worth fixing.

Thanks so much for your patience to read through this, and generally for building this library.

iuliaturc, Apr 11 '24

Here's a minimal script that reproduces the issue. Note that I'm setting num_inference_steps=1 for the sake of speed, so the absolute numbers will not be comparable to the ones above.

from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
import gc
import os
import shutil
import time
import torch

LOCAL_LORA_MODELS_PATH = "my-loras"
NUM_LORAS = 50


def ensure_lora_is_on_gpu(t2i_pipe, lora_name: str):
    loaded_loras = t2i_pipe.get_list_adapters().get("unet", [])
    if lora_name in loaded_loras:
        t2i_pipe.set_lora_device([lora_name], "cuda")
    else:
        t2i_pipe.load_lora_weights(os.path.join(LOCAL_LORA_MODELS_PATH, lora_name),
                                   weight_name="lora.safetensors",
                                   local_files_only=True,
                                   adapter_name=lora_name)
        print(f"Successfully loaded LoRA {lora_name}. {len(loaded_loras) + 1} LoRAs are currently in memory.")


def infer(t2i_pipe, prompt, adapter):
    start = time.time()

    ensure_lora_is_on_gpu(t2i_pipe, adapter)
    t2i_pipe.set_adapters(adapter)
    assert len(t2i_pipe.get_active_adapters()) == 1
    print("Active adapters:", t2i_pipe.get_active_adapters())

    images = t2i_pipe(
        prompt=prompt,
        num_images_per_prompt=1,
        # Normally this should be set to ~50, but here we don't care about quality.
        num_inference_steps=1).images

    t2i_pipe.set_lora_device([adapter], "cpu")
    t2i_pipe.disable_lora()
    gc.collect()
    torch.cuda.empty_cache()
    print(f"Offloaded LoRA {adapter} from GPU to CPU memory.")

    end = time.time()
    print(f"Inference took {end-start} seconds.")
    return images


if __name__ == "__main__":

    ### Step 0: Download one LoRA model and make copies of it (to simulate the case when we have many different LoRAs).
    if not os.path.exists(LOCAL_LORA_MODELS_PATH):
        hf_hub_download(repo_id="storia/sample-lora", filename="lora.safetensors", local_dir="sample-lora")
        for i in range(NUM_LORAS):
            lora_path = f"{LOCAL_LORA_MODELS_PATH}/lora{i}"
            shutil.copytree("sample-lora", lora_path)
    print(f"Set up {NUM_LORAS} LoRAS.")

    ### Step 1: Load the model
    t2i_pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16").to("cuda")
    print(f"Loaded the base model.")

    ### Step 3: Run inference, once with each LoRA model
    for i in range(NUM_LORAS):
        print(f"Calling inference with LoRA {i}")
        infer(t2i_pipe, prompt="unicorn in a fishbowl", adapter=f"lora{i}")

Steps to repro:

  1. In a separate shell, run nvidia-smi --id=0 --loop=5 --query-gpu=timestamp,memory.used --format=csv > /tmp/gpu-memory.csv (this samples every 5 seconds until interrupted).
  2. Run python repro.py
  3. Once the script is done, import /tmp/gpu-memory.csv into e.g. Google Sheets to plot the memory utilization.
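As an alternative to the spreadsheet in step 3, the CSV can be parsed directly in Python. This is a sketch assuming the default `--format=csv` output: a header row, then rows whose memory value carries a unit suffix such as `7000 MiB`. Adding `nounits` to `--format` removes the suffix, and the parser tolerates both. The function name `read_memory_log` is hypothetical, and the sample text is illustrative, not a real measurement.

```python
import csv
import io


def read_memory_log(text):
    """Parse `nvidia-smi --query-gpu=timestamp,memory.used --format=csv` output.

    Returns a list of (timestamp, used_MiB) pairs.
    """
    rows = csv.reader(io.StringIO(text), skipinitialspace=True)
    next(rows, None)                 # skip the header row
    samples = []
    for row in rows:
        if len(row) < 2:
            continue                 # tolerate blank or truncated lines
        timestamp, used = row[0], row[1]
        samples.append((timestamp, int(used.split()[0])))  # "7000 MiB" -> 7000
    return samples


# Illustrative input in the assumed format (not a real measurement).
sample_csv = (
    "timestamp, memory.used [MiB]\n"
    "2024/04/11 19:23:20.000, 7000 MiB\n"
    "2024/04/11 19:23:25.000, 7100 MiB\n"
)
samples = read_memory_log(sample_csv)
peak = max(mib for _, mib in samples)
```

For the real file: `with open("/tmp/gpu-memory.csv") as f: samples = read_memory_log(f.read())`.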

With this minimal setup, I'm seeing:

  • Latency going up from ~3.4s for the first call to ~15s for the last call.
  • GPU memory utilization going steadily up (with some spikes I don't know how to explain)
[Screenshot 2024-04-11 at 7 23 20 PM]

iuliaturc, Apr 12 '24

I think your benchmark timing code is a little flawed. I modified your code to:

  • Warm up the pipeline before timing, to avoid drastic variance in the measurements.
  • Move the timing calculation so that it covers only the pipeline call. Here's what it looks like now:
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
import gc
import os
import shutil
import time
import torch

NUM_WAMRUP_ITERS = 5
LOCAL_LORA_MODELS_PATH = "my-loras"
NUM_LORAS = 10


def ensure_lora_is_on_gpu(t2i_pipe, lora_name: str):
    loaded_loras = t2i_pipe.get_list_adapters().get("unet", [])
    if lora_name in loaded_loras:
        t2i_pipe.set_lora_device([lora_name], "cuda")
    else:
        t2i_pipe.load_lora_weights(os.path.join(LOCAL_LORA_MODELS_PATH, lora_name),
                                   weight_name="lora.safetensors",
                                   local_files_only=True,
                                   adapter_name=lora_name)
        print(f"Successfully loaded LoRA {lora_name}. {len(loaded_loras) + 1} LoRAs are currently in memory.")


def infer(t2i_pipe, prompt, adapter):
    ensure_lora_is_on_gpu(t2i_pipe, adapter)
    t2i_pipe.set_adapters(adapter)
    assert len(t2i_pipe.get_active_adapters()) == 1
    print("Active adapters:", t2i_pipe.get_active_adapters())
    
    start = time.time()
    images = t2i_pipe(
        prompt=prompt,
        num_images_per_prompt=1,
        # Normally this should be set to ~50, but here we don't care about quality.
        num_inference_steps=1
    ).images
    end = time.time()
    t2i_pipe.set_lora_device([adapter], "cpu")
    t2i_pipe.disable_lora()
    print(f"{t2i_pipe.get_active_adapters()=}")

    gc.collect()
    torch.cuda.empty_cache()
    
    print(f"Offloaded LoRA {adapter} from GPU to CPU memory.")
    print(f"Inference took {end-start} seconds.")
    return images


if __name__ == "__main__":

    ### Step 0: Download one LoRA model and make copies of it (to simulate the case when we have many different LoRAs).
    if not os.path.exists(LOCAL_LORA_MODELS_PATH):
        hf_hub_download(repo_id="storia/sample-lora", filename="lora.safetensors", local_dir="sample-lora")
        for i in range(NUM_LORAS):
            lora_path = f"{LOCAL_LORA_MODELS_PATH}/lora{i}"
            shutil.copytree("sample-lora", lora_path)
    print(f"Set up {NUM_LORAS} LoRAS.")

    ### Step 1: Load the model
    t2i_pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16"
    ).to("cuda")
    print(f"Loaded the base model.")

    print("Running warmup iterations:")
    for _ in range(NUM_WAMRUP_ITERS):
        _ = t2i_pipe("okay", num_inference_steps=1).images[0]

    ### Step 3: Run inference, once with each LoRA model
    for i in range(NUM_LORAS):
        print(f"Calling inference with LoRA {i}")
        infer(t2i_pipe, prompt="unicorn in a fishbowl", adapter=f"lora{i}")
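The two timing changes listed above (warm-up, and timing only the pipeline call) can be factored into a small helper. This is a sketch; the name `benchmark` and its parameters are hypothetical. It uses `time.perf_counter`, which is intended for measuring intervals. The optional `sync` hook is an assumption for GPU code: passing `torch.cuda.synchronize` makes sure queued kernels have finished before the clock is read, which matters most when timing calls that return CUDA tensors without copying them to the CPU.

```python
import statistics
import time


def benchmark(fn, *, warmup=5, repeats=3, sync=None):
    """Run fn() `warmup` times untimed, then return the median of `repeats` timed runs."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        if sync is not None:
            sync()                   # drain pending GPU work before starting the clock
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()                   # wait for GPU work launched by fn() to finish
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Usage with a CPU-only stand-in workload.
calls = []
elapsed = benchmark(lambda: calls.append(sum(range(10_000))), warmup=2, repeats=3)
```

With the pipeline, a call might look like `benchmark(lambda: t2i_pipe(prompt, num_inference_steps=1), warmup=0, repeats=1, sync=torch.cuda.synchronize)` inside the per-LoRA loop, with the global warm-up done once beforehand as in the script.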

Now, I don't see spikes in the latency: log.txt

Here are some snippets from the log file:

...
Calling inference with LoRA 0
Successfully loaded LoRA lora0. 1 LoRAs are currently in memory.
Active adapters: ['lora0']
t2i_pipe.get_active_adapters()=['lora0']
Offloaded LoRA lora0 from GPU to CPU memory.
Inference took 2.2467219829559326 seconds.
Calling inference with LoRA 1
Successfully loaded LoRA lora1. 2 LoRAs are currently in memory.
Active adapters: ['lora1']
t2i_pipe.get_active_adapters()=['lora1']
Offloaded LoRA lora1 from GPU to CPU memory.
Inference took 2.277334213256836 seconds.
Calling inference with LoRA 2
Successfully loaded LoRA lora2. 3 LoRAs are currently in memory.
Active adapters: ['lora2']
t2i_pipe.get_active_adapters()=['lora2']
Offloaded LoRA lora2 from GPU to CPU memory.
Inference took 2.2735793590545654 seconds.
Calling inference with LoRA 3
Successfully loaded LoRA lora3. 4 LoRAs are currently in memory.
Active adapters: ['lora3']
t2i_pipe.get_active_adapters()=['lora3']
Offloaded LoRA lora3 from GPU to CPU memory.
Inference took 2.1650137901306152 seconds.
...

Going to investigate the GPU memory consumption now. If I find anything salient, will post here.

sayakpaul, Apr 12 '24

I am seeing a very different memory consumption trend with my modified script (in the comment above):

[Image: GPU memory consumption over the run]

(Y-Axis denotes memory consumption)

I also don't have any explanation for the sudden spikes. I think we could just log the total GPU memory consumption at strategic points in the script to debug things further.

sayakpaul, Apr 12 '24

Thanks a lot, Sayak, for digging deeper into this.

BenjaminBossan, Apr 12 '24

Hi there, just wanted to check in and see if there are any new developments.

iuliaturc, Apr 17 '24

Oh, by

> I also don't have any explanation for the sudden spikes. I think we could just log the total GPU memory consumption at strategic points in the script to debug things further.

I assumed you were going to look into it further (as I saw much milder spikes in my tests, as reported above). Sorry for not making it clear.

sayakpaul, Apr 18 '24

Sorry, I didn't understand your plot above. Is it suggesting that the memory consumption is actually constant around 700? And what is the unit of the Y axis / memory consumption? MB?

What investigation on our side would be helpful?

iuliaturc, Apr 18 '24

> Sorry, I didn't understand your plot above. Is it suggesting that the memory consumption is actually constant around 700? And what is the unit of the Y axis / memory consumption? MB?

Yes, it's MB. However, I mentioned:

> I am seeing a very different memory consumption trend with my modified script (in the comment above):

> I also don't have any explanation for the sudden spikes.

What I am suggesting is that it's not as drastic as the ones you sent over.

> What investigation on our side would be helpful?

This:

> I also don't have any explanation for the sudden spikes. I think we could just log the total GPU memory consumption at strategic points in the script to debug things further.

Maybe we could log the total GPU memory consumption after loading the LoRAs, after performing inference, etc. (i.e., at the points in the benchmarking script where we think a spike might occur). Let me know if this is still not clear.
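A minimal sketch of that kind of checkpoint logging. The `MemoryLog` class is hypothetical. In real use its readers would be `torch.cuda.memory_allocated` (bytes held by live tensors) and `torch.cuda.memory_reserved` (bytes held by PyTorch's caching allocator, which is closer to what `nvidia-smi` reports). Here they are replaced by a fake counter with made-up numbers so the sketch runs without a GPU.

```python
class MemoryLog:
    """Record memory readings (in MiB) at named checkpoints."""

    def __init__(self, **readers):
        self.readers = readers       # column name -> zero-argument callable returning bytes
        self.rows = []

    def mark(self, label):
        reading = {name: read() / 2**20 for name, read in self.readers.items()}
        self.rows.append((label, reading))

    def report(self):
        for label, reading in self.rows:
            cols = "  ".join(f"{name}={mib:8.1f} MiB" for name, mib in reading.items())
            print(f"{label:<24} {cols}")


# Fake reader standing in for torch.cuda.memory_allocated (made-up numbers).
fake_bytes = [0]
log = MemoryLog(allocated=lambda: fake_bytes[0])
log.mark("before load")
fake_bytes[0] = 512 * 2**20
log.mark("after load")
fake_bytes[0] = 0
log.mark("after offload")
log.report()
```

On the GPU, the log would be created as `MemoryLog(allocated=torch.cuda.memory_allocated, reserved=torch.cuda.memory_reserved)`, with `log.mark(...)` calls after `load_lora_weights`, after the pipeline call, and after `set_lora_device(..., "cpu")`.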

sayakpaul, Apr 19 '24

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot], May 13 '24

I managed to trace where the memory leak is coming from. It's the pipe.load_lora_weights method.

As a reminder, the high-level algorithm here is:

for n in num_loras:
    load lora n
    run inference with lora n
    move lora n from gpu to cpu

The problem is that, when loading lora n onto the GPU (first line of the loop), it also brings lora n-1 back to the GPU. And since we only move lora n to the CPU (third line of the loop), lora n-1 lingers on the GPU and contributes to the memory increase we are seeing. I was able to tell this by printing all the weights that are currently on the GPU before and after the load method.
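A sketch of that check: collect the adapter names that have parameters on the GPU. It assumes PEFT's LoRA parameter naming, where the adapter name is the key that follows `lora_A.` or `lora_B.` (e.g. `...to_q.lora_A.lora0.weight`). The function name `adapters_on_device` is hypothetical, and it works on `(name, device_type)` pairs so it can be tested on plain strings; the parameter names below are illustrative, not taken from a real model dump.

```python
def adapters_on_device(named_devices, device_type="cuda"):
    """Return the set of LoRA adapter names with at least one parameter on `device_type`.

    `named_devices` yields (parameter_name, device_type) pairs, e.g. built from
    ((n, p.device.type) for n, p in pipe.unet.named_parameters()).
    """
    found = set()
    for name, dev in named_devices:
        if dev != device_type:
            continue
        parts = name.split(".")
        for i, part in enumerate(parts[:-1]):
            if part in ("lora_A", "lora_B"):
                found.add(parts[i + 1])  # the ModuleDict key, i.e. the adapter name
                break
    return found


# Illustrative parameter names, not from a real model.
params = [
    ("down_blocks.0.attn.to_q.base_layer.weight", "cuda"),
    ("down_blocks.0.attn.to_q.lora_A.lora0.weight", "cuda"),
    ("down_blocks.0.attn.to_q.lora_B.lora0.weight", "cuda"),
    ("down_blocks.0.attn.to_q.lora_A.lora1.weight", "cpu"),
]
on_gpu = adapters_on_device(params)
```

Calling it before and after `load_lora_weights` and taking the set difference shows which adapters the load placed on the GPU.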

I don't yet know what goes wrong in pipe.load_lora_weights to be able to send a PR, but at this point I am very convinced that there is indeed a memory leak in the peft library.

iuliaturc, May 15 '24

Here's a full repro in a Colab notebook (just added some print statements to the code snippet above): https://colab.research.google.com/drive/1u-DTQFZHGSiR-287CS3ELUYOC6jfhMqU?usp=sharing

You can see that, on each iteration, two LoRAs are loaded instead of one:

LoRAs loaded in this iteration: {'lora0', None}
...
LoRAs loaded in this iteration: {'lora0', 'lora1'}
...
LoRAs loaded in this iteration: {'lora2', 'lora1'}
...
LoRAs loaded in this iteration: {'lora2', 'lora3'}
...
LoRAs loaded in this iteration: {'lora3', 'lora4'}

From what I can tell, this target.update_layer is the culprit, but I don't know exactly why it also loads the weights of the previous LoRA.

iuliaturc, May 15 '24

Thanks very much for your detailed investigation. Will defer to @BenjaminBossan to comment further.

sayakpaul, May 16 '24

Thanks a lot @iuliaturc for investigating further. I can reproduce the results you shared and the culprit is indeed update_layer, most notably these lines:

https://github.com/huggingface/peft/blob/e003ae78506f316fb0e6a97e6132cb4c590a47ab/src/peft/tuners/lora/layer.py#L124-L132

Here we don't take into consideration that the same LoRA layer could have weights on different devices. I'll investigate what we can do about it.
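To see why moving the whole layer can produce the pattern printed by the notebook, here is a toy model (pure Python, not PEFT code) in which each adapter's weights live on a single device. It assumes, following the comment above, that adding an adapter moves the whole LoRA layer, i.e. every adapter it holds, to the base layer's device, and compares that with moving only the adapter being added. The function name `simulate` is hypothetical.

```python
def simulate(num_loras, move_whole_layer):
    """Toy model of the load -> infer -> offload loop from earlier in this thread.

    Returns (adapters newly placed on the GPU in each iteration,
             number of adapters still on the GPU at the end).
    """
    device = {}                          # adapter name -> "cuda" or "cpu"

    def on_gpu():
        return {name for name, dev in device.items() if dev == "cuda"}

    newly_on_gpu = []
    for n in range(num_loras):
        name = f"lora{n}"
        before = on_gpu()
        device[name] = "cpu"             # new adapter weights are created
        if move_whole_layer:
            for other in device:         # layer-wide move: every adapter follows the base layer
                device[other] = "cuda"
        else:
            device[name] = "cuda"        # move only the adapter being added
        newly_on_gpu.append(on_gpu() - before)
        # ... inference with `name` would run here ...
        device[name] = "cpu"             # Step 3 offloads only the adapter just used
    return newly_on_gpu, len(on_gpu())


buggy, buggy_resident = simulate(4, move_whole_layer=True)
fixed, fixed_resident = simulate(4, move_whole_layer=False)
```

In this model the layer-wide move yields `{lora0}`, `{lora0, lora1}`, `{lora1, lora2}`, ... per iteration, matching the notebook output apart from its `None` entry, and every earlier adapter stays resident on the GPU, which would explain memory growing with the number of LoRAs seen.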

BenjaminBossan, May 16 '24

Awesome, thanks a lot for looking into this.

iuliaturc, May 16 '24

I created a PR, #1742, that should solve this issue. If you could test it @iuliaturc that would be awesome. At least with the provided notebook, it solves the issue for me; it would be interesting to know if it also resolves your original problem.

BenjaminBossan, May 17 '24

Thanks so much for the fix and sorry for the delayed response. I will try it out in the next day or two.

iuliaturc, May 21 '24

Thanks again @BenjaminBossan for the PR!

I've rerun the notebook above with a higher number of MAX_LORAS.

The good news is that, indeed, I'm no longer seeing two LoRAs loaded per iteration:

cat /tmp/logs.txt | grep "LoRAs loaded in this iteration"
LoRAs loaded in this iteration: {'lora0', None}
LoRAs loaded in this iteration: {'lora1'}
LoRAs loaded in this iteration: {'lora2'}
LoRAs loaded in this iteration: {'lora3'}
LoRAs loaded in this iteration: {'lora4'}
LoRAs loaded in this iteration: {'lora5'}
LoRAs loaded in this iteration: {'lora6'}
LoRAs loaded in this iteration: {'lora7'}
LoRAs loaded in this iteration: {'lora8'}
LoRAs loaded in this iteration: {'lora9'}
LoRAs loaded in this iteration: {'lora10'}
LoRAs loaded in this iteration: {'lora11'}
LoRAs loaded in this iteration: {'lora12'}
LoRAs loaded in this iteration: {'lora13'}
LoRAs loaded in this iteration: {'lora14'}
LoRAs loaded in this iteration: {'lora15'}
LoRAs loaded in this iteration: {'lora16'}
LoRAs loaded in this iteration: {'lora17'}
LoRAs loaded in this iteration: {'lora18'}
LoRAs loaded in this iteration: {'lora19'}
LoRAs loaded in this iteration: {'lora20'}
LoRAs loaded in this iteration: {'lora21'}
LoRAs loaded in this iteration: {'lora22'}
LoRAs loaded in this iteration: {'lora23'}
LoRAs loaded in this iteration: {'lora24'}

However, I'm still seeing the latency progressively creeping up:

cat /tmp/logs.txt | grep "Inference took"
Inference took 3.2485458850860596 seconds.
Inference took 2.660719156265259 seconds.
Inference took 2.897308826446533 seconds.
Inference took 3.131235122680664 seconds.
Inference took 3.343064546585083 seconds.
Inference took 3.602440357208252 seconds.
Inference took 3.8176300525665283 seconds.
Inference took 4.079342603683472 seconds.
Inference took 4.275134801864624 seconds.
Inference took 4.48743462562561 seconds.
Inference took 4.727515935897827 seconds.
Inference took 4.943339109420776 seconds.
Inference took 5.279767990112305 seconds.
Inference took 5.446081638336182 seconds.
Inference took 5.638545036315918 seconds.
Inference took 5.929015159606934 seconds.
Inference took 6.1448142528533936 seconds.
Inference took 6.351588010787964 seconds.
Inference took 6.608906984329224 seconds.
Inference took 6.842655181884766 seconds.
Inference took 7.075381755828857 seconds.
Inference took 7.31756854057312 seconds.
Inference took 7.547657251358032 seconds.
Inference took 7.814317941665649 seconds.
Inference took 8.001453161239624 seconds.
Inference took 8.338320970535278 seconds.
Inference took 8.44058632850647 seconds.
Inference took 8.678828954696655 seconds.
Inference took 8.94094467163086 seconds.
Inference took 9.13964033126831 seconds.
Inference took 9.445438623428345 seconds.
Inference took 9.641743183135986 seconds.
Inference took 9.83599042892456 seconds.
Inference took 10.089015007019043 seconds.
Inference took 10.279709815979004 seconds.
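One way to quantify the creep is an ordinary least-squares slope of these timings against the call index, which gives the average number of extra seconds each successive call costs. A sketch, with the values rounded from the log above; the closed-form fit is written out (the helper name `ols_slope` is hypothetical) so that it needs no third-party packages.

```python
def ols_slope(ys):
    """Least-squares slope of ys against their call index 0, 1, 2, ..."""
    n = len(ys)
    xs = range(n)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var


# Per-call latencies in seconds, rounded from the "Inference took" lines above.
latencies = [
    3.25, 2.66, 2.90, 3.13, 3.34, 3.60, 3.82, 4.08, 4.28, 4.49,
    4.73, 4.94, 5.28, 5.45, 5.64, 5.93, 6.14, 6.35, 6.61, 6.84,
    7.08, 7.32, 7.55, 7.81, 8.00, 8.34, 8.44, 8.68, 8.94, 9.14,
    9.45, 9.64, 9.84, 10.09, 10.28,
]
slope = ols_slope(latencies)  # average extra seconds per successive call
```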

This doesn't invalidate your fix, but just means I'm back to square one trying to diagnose what is going on...

iuliaturc, May 22 '24

With this fix, I'm seeing that memory isn't linearly going up anymore. There's a spike on each call, and overall memory consumption is uniform across calls. That's great! [Screenshot 2024-05-22 at 5 13 12 PM]

As mentioned before, latency goes steadily up: [Screenshot 2024-05-22 at 5 13 55 PM]

I also noticed that it/s (i.e., the number of denoising steps accomplished in a second) steadily goes down. So the overall latency increase is due to denoising steps taking increasingly longer: [Screenshot 2024-05-22 at 5 14 59 PM]

iuliaturc, May 23 '24

Thanks for looking further into this. Is it possible for you to narrow down which components could be responsible for this?

sayakpaul, May 23 '24

Thanks @iuliaturc for checking again. I merged the PR but will re-open this issue, as it's not fully resolved.

Btw with the latest PEFT version, we added some utilities that should facilitate this type of debugging in the future.

BenjaminBossan, May 23 '24

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot], Jun 16 '24