
Possibility to avoid `from_single_file` loading in fp32 to save RAM

Open asomoza opened this issue 10 months ago • 14 comments

Describe the bug

When loading a model using `from_single_file()`, RAM usage is very high, possibly because the weights are first loaded in FP32 before being converted to the requested dtype.

Reproduction

import threading
import time

import psutil
import torch
from huggingface_hub import hf_hub_download

from diffusers import UNet2DConditionModel


filename = hf_hub_download("stable-diffusion-v1-5/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.safetensors")

stop_monitoring = False


def log_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024**2)  # Convert to MB


def monitor_memory(interval, peak_memory):
    while not stop_monitoring:
        current_memory = log_memory_usage()
        peak_memory[0] = max(peak_memory[0], current_memory)
        time.sleep(interval)


def load_model(filename, dtype):
    global stop_monitoring

    peak_memory = [0]  # Use a list to store peak memory so it can be updated in the thread
    initial_memory = log_memory_usage()
    print(f"Initial memory usage: {initial_memory:.2f} MB")

    monitor_thread = threading.Thread(target=monitor_memory, args=(0.01, peak_memory))
    monitor_thread.start()

    start_time = time.time()
    UNet2DConditionModel.from_single_file(filename, torch_dtype=dtype)
    end_time = time.time()

    stop_monitoring = True
    monitor_thread.join()  # Wait for the monitoring thread to finish

    print(f"Peak memory usage: {peak_memory[0]:.2f} MB")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    final_memory = log_memory_usage()
    print(f"Final memory usage: {final_memory:.2f} MB")


load_model(filename, torch.float8_e4m3fn)

Logs

Initial memory usage: 737.19 MB
Peak memory usage: 4867.43 MB
Time taken: 0.92 seconds
Final memory usage: 1578.99 MB
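
For rough scale: v1-5-pruned-emaonly.safetensors holds about 4.27 GB of fp32 weights, which roughly matches the ~4.1 GB peak increase (the whole checkpoint appears to be materialized in fp32), while the SD 1.5 UNet's roughly 860M parameters at one byte each (~0.86 GB) match the ~0.84 GB final increase once only the fp8 UNet remains in memory.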

System Info

not relevant here

Who can help?

@DN6

asomoza avatar Jan 29 '25 14:01 asomoza

@asomoza can you test under this PR? https://github.com/huggingface/diffusers/pull/10604

yiyixuxu avatar Jan 29 '25 16:01 yiyixuxu

@yiyixuxu I switched to Windows and a mobile 4090, since the original issue was reported on Windows.

Without the PR:

Initial memory usage: 548.29 MB
Peak memory usage: 4667.13 MB
Time taken: 1.49 seconds
Final memory usage: 1387.07 MB

With the PR:

Initial memory usage: 548.34 MB
Peak memory usage: 4668.90 MB
Time taken: 3.08 seconds
Final memory usage: 1388.89 MB

So RAM usage is almost the same and not related to this issue, but that PR doubles the model loading time, which is not good.

ccing @SunMarc just in case

asomoza avatar Jan 29 '25 17:01 asomoza

Adding a bit more context from the original discussion: the overhead is mostly caused by this issue: https://github.com/huggingface/safetensors/issues/542, but there are other problems that make it harder to reduce RAM usage when loading a model. (See the feature request below.)

The conversion itself uses a lot of RAM. Combined with the fact that there is no way (that I know of) to load the model in its original format from the file, there will be an overhead in most situations. See this example:

I'm using this file https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors, because it shows the issue a bit better. Instead of UNet2DConditionModel it needs to be loaded as FluxTransformer2DModel. The file mostly contains weights in bf16 format:

loading in fp8 format from a bf16 file: load_model(filename, torch.float8_e4m3fn)

Initial memory usage: 562.95 MB
Peak memory usage: 41511.62 MB
Time taken: 8.86 seconds
Final memory usage: 11774.39 MB

loading in fp32 format from a bf16 file: load_model(filename, torch.float32)

Initial memory usage: 563.26 MB
Peak memory usage: 75561.89 MB
Time taken: 10.72 seconds
Final memory usage: 45595.88 MB

loading in bf16 format from a bf16 file: load_model(filename, torch.bfloat16). No conversion happens in this case, which is good, and it drastically reduces the peak RAM usage.

Initial memory usage: 562.68 MB
Peak memory usage: 14339.55 MB
Time taken: 2.94 seconds
Final memory usage: 7459.44 MB (not sure if these numbers are correct. It probably didn't load all the tensors because there is no read access. The final usage should be closer to 24GB)

loading with an unspecified dtype from a bf16 file: load_model(filename, None). The model is now loaded in fp32, which triggers a conversion and uses a lot of RAM again. I would expect the model to be loaded with its original bf16 weights.

Initial memory usage: 562.57 MB
Peak memory usage: 75545.73 MB
Time taken: 10.43 seconds
Final memory usage: 45595.22 MB
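
The gap between the converting and non-converting runs is mostly the cost of holding two full copies of the weights at once: `.to(dtype)` allocates the converted tensor while the source is still referenced. A standalone sketch of the two patterns with synthetic weights (nothing here is loaded from disk):

import torch

# ~256 MB of synthetic fp32 "weights"
state_dict = {f"layer_{i}.weight": torch.randn(1024, 1024) for i in range(64)}

# Pattern 1: convert the whole dict at once. The fp32 and bf16 dicts coexist,
# so peak RAM is roughly the sum of both until the fp32 dict is released.
converted = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}

# Pattern 2: convert tensor by tensor, replacing each entry in place. Every fp32
# tensor becomes collectable as soon as it is replaced, so the peak stays close
# to the fp32 size instead of fp32 + converted.
for k in list(state_dict.keys()):
    state_dict[k] = state_dict[k].to(torch.bfloat16)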

Feature request for a new parameter to remove any conversion overhead

Additionally, it would be great to have an option to not load the weights at all. This can be done by avoiding any read access to the tensors: the safetensors library already supports lazy tensor loading out of the box, and only tensors that are actually read are loaded from the file. At the moment the read is triggered by the `.to()` call that converts the weights. Having this option would make it possible to manually convert each tensor to a custom data type without any overhead (apart from the issue linked above).
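
A rough sketch of the lazy access pattern described here, assuming any local .safetensors file (for example the one downloaded in the reproduction script above): listing keys and shapes only parses the header, and tensor data is only read from disk at the point where get_tensor() is called.

from safetensors import safe_open

filename = "v1-5-pruned-emaonly.safetensors"  # any local .safetensors file

with safe_open(filename, framework="pt") as f:
    for key in f.keys():
        shape = f.get_slice(key).get_shape()  # header metadata only, no tensor data is read
        # tensor = f.get_tensor(key)          # this is the call that would actually load the data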

Nerogar avatar Jan 29 '25 22:01 Nerogar

I've tested loading this flux-dev model on my machine, starting with 46 GB of free memory. Internally, `from_single_file()` first loads the state dict into memory in bfloat16 before converting, so my free memory dropped to 32 GB. Then, in the step where it converts the state dict to the diffusers format, my free memory dropped again to 20 GB. Finally, when it takes this and puts it on the meta device, only then does it convert to FP8, so my free memory dropped to less than 4 GB, and then the memory was freed up again.

elismasilva avatar Jan 30 '25 01:01 elismasilva

The RAM used depends on the size of the file being loaded, so you either have to convert the model to a smaller size before loading it, or add a command to free the RAM after it is loaded.

al-swaiti avatar Feb 03 '25 15:02 al-swaiti

So RAM usage is almost the same and not related to this issue, but that PR doubles the model loading time, which is not good.

ccing @SunMarc just in case

It was due to a small mistake on my part! Sorry for that; I fixed it in the latest commit. Also, for now the PR should only speed up loading for diffusion models.

SunMarc avatar Feb 04 '25 17:02 SunMarc

I got the same issue.

from diffusers import FluxTransformer2DModel
import torch
# The single file model is from https://civitai.com/api/download/models/1413133?type=Model&format=SafeTensor&size=full&fp=fp16
# When torch_dtype is bf16, it works; RAM usage is 8.3 GB.
transformer = FluxTransformer2DModel.from_single_file("...", torch_dtype=torch.bfloat16)

# When torch_dtype is fp16, it goes OOM on a machine with 32 GB of RAM.
transformer = FluxTransformer2DModel.from_single_file("...", torch_dtype=torch.float16)

Consistent with what @Nerogar said.

CyberVy avatar Feb 28 '25 08:02 CyberVy

I've found an easy way to solve this issue. We can check the dtype before loading the model with .from_single_file, and then pass it into .from_single_file.

from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import load_state_dict

state_dict = load_state_dict("...")
torch_dtype = state_dict[list(state_dict.keys())[0]].dtype  # dtype of the first tensor in the checkpoint
transformer = FluxTransformer2DModel.from_single_file("...", torch_dtype=torch_dtype)

CyberVy avatar Feb 28 '25 10:02 CyberVy

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Mar 24 '25 15:03 github-actions[bot]

I've found an easy way to solve this issue. We can check the dtype before loading the model with .from_single_file, and then pass it into .from_single_file.

from diffusers.loaders.single_file_utils import load_state_dict

state_dict = load_state_dict("...")
torch_dtype = state_dict[list(state_dict.keys())[0]].dtype
transformer = FluxTransformer2DModel.from_single_file("...", torch_dtype=torch_dtype)

Is this workaround not a solution to this issue that could be implemented in diffusers? When would you want to load the transformer in a precision other than the precision the file has, especially fp32?

[You might want to quantize during loading, but this is more difficult to implement. This workaround seems like an easy improvement to the current behaviour.]

dxqb avatar Apr 12 '25 07:04 dxqb

Models aren't always saved with a single dtype, especially if you use something low-precision like fp8: some of the weights will still be saved in higher precision, like fp32. So you need to be careful when detecting the dtype of the file.
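
One way to check this without loading any tensor data is to read only the safetensors header (the file starts with an 8-byte little-endian header length, followed by a JSON header that lists each tensor's dtype). A small sketch; the filename is just the FLUX example from earlier in the thread:

import json
import struct
from collections import Counter


def stored_dtypes(path):
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_len))
    return Counter(entry["dtype"] for name, entry in header.items() if name != "__metadata__")


# A single-dtype file shows one entry (e.g. all "BF16"); a mixed low-precision
# checkpoint would also show "F32" (or similar) entries for the weights that
# were kept in higher precision.
print(stored_dtypes("flux1-dev.safetensors"))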

Nerogar avatar Apr 12 '25 07:04 Nerogar

Documenting a workaround we found for peak memory, for visibility:

from safetensors import safe_open

# `filename` and `dtype` as in the reproduction script above.
tensors = {}
with safe_open(filename, framework="pt") as f:
    keys = list(f.keys())

# Re-open the file for every key so that the buffer held by safe_open is
# released after each tensor has been read and converted.
for key in keys:
    with safe_open(filename, framework="pt") as f:
        tensors[key] = f.get_tensor(key).to(dtype)

Essentially, safe_open holds on to the memory while inside its context, so it is only released when the context exits. This method is slower, but peak memory usage matches the final memory usage.

Initial memory usage: 685.62 MB
Peak memory usage: 1714.91 MB
Time taken: 4.71 seconds
Final memory usage: 1714.95 MB

Compared to the single-context version (code below):

Initial memory usage: 685.44 MB
Peak memory usage: 5771.81 MB
Time taken: 1.84 seconds
Final memory usage: 1714.28 MB

tensors = {}
with safe_open(filename, framework="pt") as f:
    keys = list(f.keys())
    for key in keys:
        tensors[key] = f.get_tensor(key).to(dtype)

No combination of `del`, cloning then `del`, or even `gc.collect()` after `del` will free the memory while still inside the `safe_open` context.
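
A possible middle ground (untested here, and the helper name is just for illustration) is to re-open the file once per batch of keys instead of once per key, trading a bit of peak memory for fewer open/close cycles:

from safetensors import safe_open

def load_converted(filename, dtype, batch_size=64):
    tensors = {}
    with safe_open(filename, framework="pt") as f:
        keys = list(f.keys())
    # Re-open the file per batch so the buffer held by safe_open is released
    # after each batch instead of only at the very end.
    for start in range(0, len(keys), batch_size):
        with safe_open(filename, framework="pt") as f:
            for key in keys[start:start + batch_size]:
                tensors[key] = f.get_tensor(key).to(dtype)
    return tensors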

hlky avatar Apr 14 '25 06:04 hlky

@hlky did you try this on Windows or only Linux? I remember trying something similar before, but immediately got a blue screen.

Nerogar avatar Apr 14 '25 06:04 Nerogar

@Nerogar On Windows

[image attachment]

hlky avatar Apr 14 '25 06:04 hlky