Possibility to avoid `from_single_file` loading in fp32 to save RAM
Describe the bug
When loading a model using from_single_file(), the RAM usage is really high, possibly because the weights are loaded in FP32 before being converted.
Reproduction
```python
import threading
import time

import psutil
import torch
from huggingface_hub import hf_hub_download
from diffusers import UNet2DConditionModel

filename = hf_hub_download("stable-diffusion-v1-5/stable-diffusion-v1-5", filename="v1-5-pruned-emaonly.safetensors")
stop_monitoring = False


def log_memory_usage():
    process = psutil.Process()
    mem_info = process.memory_info()
    return mem_info.rss / (1024**2)  # Convert bytes to MB (MiB)


def monitor_memory(interval, peak_memory):
    # Sample resident memory until the main thread sets stop_monitoring
    while not stop_monitoring:
        current_memory = log_memory_usage()
        peak_memory[0] = max(peak_memory[0], current_memory)
        time.sleep(interval)


def load_model(filename, dtype):
    global stop_monitoring
    stop_monitoring = False  # Reset so load_model can be called more than once
    peak_memory = [0]  # Use a list to store peak memory so it can be updated in the thread
    initial_memory = log_memory_usage()
    print(f"Initial memory usage: {initial_memory:.2f} MB")
    monitor_thread = threading.Thread(target=monitor_memory, args=(0.01, peak_memory))
    monitor_thread.start()
    start_time = time.time()
    UNet2DConditionModel.from_single_file(filename, torch_dtype=dtype)
    end_time = time.time()
    stop_monitoring = True
    monitor_thread.join()  # Wait for the monitoring thread to finish
    print(f"Peak memory usage: {peak_memory[0]:.2f} MB")
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    final_memory = log_memory_usage()
    print(f"Final memory usage: {final_memory:.2f} MB")


load_model(filename, torch.float8_e4m3fn)
```
Logs
Initial memory usage: 737.19 MB
Peak memory usage: 4867.43 MB
Time taken: 0.92 seconds
Final memory usage: 1578.99 MB
System Info
not relevant here
Who can help?
@DN6
@asomoza can you test with this PR? https://github.com/huggingface/diffusers/pull/10604
@yiyixuxu I switched to Windows and a mobile 4090, since the original issue was on Windows.
Without the PR:
Initial memory usage: 548.29 MB
Peak memory usage: 4667.13 MB
Time taken: 1.49 seconds
Final memory usage: 1387.07 MB
With the PR:
Initial memory usage: 548.34 MB
Peak memory usage: 4668.90 MB
Time taken: 3.08 seconds
Final memory usage: 1388.89 MB
So RAM usage is almost the same (the PR isn't related to this issue), but the PR doubles the model loading time, which is not good.
ccing @SunMarc just in case
Adding a bit more context from the original discussion: the overhead is mostly caused by this issue: https://github.com/huggingface/safetensors/issues/542, but there are other problems that make it harder to reduce RAM usage when loading a model (see the feature request below).
The conversion itself uses a lot of RAM. Combined with the fact that there is no way (that I know of) to load the model in its original format from the file, there will be an overhead in most situations. See this example:
I'm using this file https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors, because it shows the issue a bit better. Instead of UNet2DConditionModel it needs to be loaded as FluxTransformer2DModel. The file mostly contains weights in bf16 format:
loading in fp8 format from a bf16 file: load_model(filename, torch.float8_e4m3fn)
Initial memory usage: 562.95 MB
Peak memory usage: 41511.62 MB
Time taken: 8.86 seconds
Final memory usage: 11774.39 MB
loading in fp32 format from a bf16 file: load_model(filename, torch.float32)
Initial memory usage: 563.26 MB
Peak memory usage: 75561.89 MB
Time taken: 10.72 seconds
Final memory usage: 45595.88 MB
loading in bf16 format from a bf16 file: load_model(filename, torch.bfloat16). No conversion happens in this case, which is good, and it drastically reduces the peak RAM usage.
Initial memory usage: 562.68 MB
Peak memory usage: 14339.55 MB
Time taken: 2.94 seconds
Final memory usage: 7459.44 MB (I'm not sure these numbers are correct: it probably didn't load all the tensors, because nothing reads them. The final usage should be closer to 24GB.)
loading in an unspecified format from a bf16 file: load_model(filename, None). The model is now loaded in fp32 format, which triggers a conversion and again uses a lot of RAM. I would expect the model to be loaded with its original weights in bf16 format.
Initial memory usage: 562.57 MB
Peak memory usage: 75545.73 MB
Time taken: 10.43 seconds
Final memory usage: 45595.22 MB
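As a rough sanity check on these numbers, the weights alone need about (number of parameters) × (bytes per element). FLUX.1-dev is a 12-billion-parameter model, and the script above actually reports MiB (it divides by 1024**2) even though it labels them MB. The sketch below only does that arithmetic; `weight_mib` and `BYTES_PER_ELEMENT` are illustrative names, not diffusers APIs.

```python
# Bytes per element for the dtypes discussed in this thread.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2, "float8_e4m3fn": 1}


def weight_mib(num_params, dtype_name):
    """Approximate RAM for the weights alone, in MiB (the unit the script reports)."""
    return num_params * BYTES_PER_ELEMENT[dtype_name] / 1024**2


# Estimate for a ~12B-parameter transformer in each dtype.
for name in ("float32", "bfloat16", "float8_e4m3fn"):
    print(f"{name:>14}: {weight_mib(12e9, name):>9.0f} MiB")
```

This gives roughly 45,800 MiB for fp32, 22,900 MiB for bf16, and 11,400 MiB for fp8. That is in the same range as the final-minus-initial memory of the fp32 and fp8 runs above (about 45,000 and 11,200 MiB) and consistent with the note that the bf16 run should end closer to 24GB. Anything above these figures at peak is transient overhead, such as the conversion.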
Feature request for a new parameter to remove any conversion overhead
Additionally, it would be great to have an option to not load the weights at all, which can be done by removing any read access to the tensors. The safetensors library supports lazy tensor loading out of the box: only tensors that are actually read are loaded from the file. At the moment, that read is triggered by the `.to()` call that converts the weights. Having this option would make it possible to manually convert each tensor to a custom data type without any overhead (apart from the issue linked above).
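To illustrate why per-tensor lazy loading is possible: in the safetensors format, a file starts with an 8-byte little-endian header length, followed by a JSON header recording each tensor's `dtype`, `shape`, and `data_offsets` (relative to the data section), followed by the raw tensor bytes. The stdlib-only sketch below reads only the header, then seeks to and reads the bytes of a single tensor. A loader built this way could convert tensors one at a time without touching the rest of the file. The helper names are hypothetical and not part of safetensors or diffusers.

```python
import json
import struct


def read_safetensors_header(path):
    """Hypothetical helper: return (header, data_start) without reading any tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header, 8 + header_len


def read_tensor_bytes(path, name):
    """Hypothetical helper: read the raw bytes of one tensor and nothing else."""
    header, data_start = read_safetensors_header(path)
    begin, end = header[name]["data_offsets"]  # offsets relative to the data section
    with open(path, "rb") as f:
        f.seek(data_start + begin)
        return f.read(end - begin)
```

For example, `read_safetensors_header(path)[0]` lists every tensor's dtype and shape without reading any weights.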
I've tested loading this flux_dev model on my machine with 46GB of free memory. When `from_single_file()` internally loads the state_dict, it first loads it into memory in bfloat16 before converting, so my free memory dropped to 32GB. Then, in the step where it converts the state dict to the diffusers format, free memory dropped again to 20GB. Finally, only when it takes this state dict and loads it into the model on the meta device does it convert the weights to FP8, so free memory drops to less than 4GB before it is freed up again.
The RAM used depends on the size of the file being loaded, so you must convert the model to a smaller size before loading it, or add a command to free RAM after it is loaded.
> So RAM usage is almost the same, and not related to this issue, but that PR doubles the time for loading the model which is not good.
> ccing @SunMarc just in case
It was due to a small mistake on my part! Sorry about that; I fixed it in the latest commit. Also, for now the PR should only speed up loading for diffusion models.
I got the same issue.
```python
from diffusers import FluxTransformer2DModel
import torch

# The single-file model is from https://civitai.com/api/download/models/1413133?type=Model&format=SafeTensor&size=full&fp=fp16
# With torch_dtype=torch.bfloat16 it works, and RAM usage is 8.3GB.
transformer = FluxTransformer2DModel.from_single_file("...", torch_dtype=torch.bfloat16)
# With torch_dtype=torch.float16 it runs out of memory with 32GB of RAM.
transformer = FluxTransformer2DModel.from_single_file("...", torch_dtype=torch.float16)
```
Consistent with what @Nerogar said.
I've found an easy way to solve this issue: check the dtype of the checkpoint before loading the model with .from_single_file, and then pass that dtype into .from_single_file.
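```python
from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import load_state_dict

state_dict = load_state_dict("...")
# Use the dtype of the first tensor in the checkpoint as the load dtype
torch_dtype = next(iter(state_dict.values())).dtype
transformer = FluxTransformer2DModel.from_single_file("...", torch_dtype=torch_dtype)
```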
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
> I've found an easy way to solve this issue. We can check the dtype before loading the model with .from_single_file, and then pass it into .from_single_file.
Couldn't this workaround be implemented in diffusers as a solution to this issue? When would you want to load the transformer in a precision other than the precision of the file, especially fp32?
[You might want to quantize during loading, but this is more difficult to implement. This workaround seems like an easy improvement to the current behaviour.]
Models aren't always saved with a single dtype. This is especially true for low-precision formats like fp8, where some of the weights are still saved in a higher precision such as fp32. So you need to be careful when detecting the dtype of the file.
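Following that caveat, one option is to weight each dtype by how many bytes it occupies instead of taking the first tensor's dtype, so that a few tensors kept in fp32 don't decide the load dtype of an otherwise fp8 or bf16 checkpoint. This is a stdlib-only sketch that reads only the safetensors JSON header (8-byte length prefix, then JSON with `dtype` and `data_offsets` per tensor); `dominant_dtype` and `TORCH_NAME` are hypothetical names, and the mapping only covers the dtypes discussed in this thread.

```python
import json
import struct
from collections import Counter


def dominant_dtype(path):
    """Hypothetical helper: return the safetensors dtype string that accounts for the most bytes."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64 header length
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # not a tensor entry
    bytes_per_dtype = Counter()
    for info in header.values():
        begin, end = info["data_offsets"]
        bytes_per_dtype[info["dtype"]] += end - begin  # weight by size, not by tensor count
    return bytes_per_dtype.most_common(1)[0][0]


# Hypothetical mapping from safetensors dtype strings to torch dtype attribute names.
TORCH_NAME = {"F32": "float32", "F16": "float16", "BF16": "bfloat16",
              "F8_E4M3": "float8_e4m3fn", "F8_E5M2": "float8_e5m2"}
```

`torch_dtype = getattr(torch, TORCH_NAME[dominant_dtype(path)])` could then be passed to `from_single_file`. For a genuinely mixed checkpoint this still picks a single dtype, so tensors stored in other dtypes would still be converted.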
For visibility, documenting a workaround we found for peak memory:
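```python
from safetensors import safe_open

# `filename` and `dtype` are as in the reproduction script above
tensors = {}
# Read the tensor names once
with safe_open(filename, framework="pt") as f:
    keys = list(f.keys())
for key in keys:
    # Reopen the file for every tensor so safe_open releases its memory after each one
    with safe_open(filename, framework="pt") as f:
        tensors[key] = f.get_tensor(key).to(dtype)
```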
Essentially, safe_open holds on to the memory for as long as its context is open. This method is slower, but peak memory usage matches the final memory usage.
Initial memory usage: 685.62 MB
Peak memory usage: 1714.91 MB
Time taken: 4.71 seconds
Final memory usage: 1714.95 MB
Compared to keeping a single safe_open context open for all tensors:
Initial memory usage: 685.44 MB
Peak memory usage: 5771.81 MB
Time taken: 1.84 seconds
Final memory usage: 1714.28 MB
```python
from safetensors import safe_open

tensors = {}
with safe_open(filename, framework="pt") as f:
    keys = list(f.keys())
    for key in keys:
        tensors[key] = f.get_tensor(key).to(dtype)
```
No combination of `del`, cloning then `del`, or even `gc.collect()` after `del` frees the memory while inside the safe_open context.
@hlky did you try this on Windows or only Linux? I remember trying something similar before, but immediately got a blue screen.
@Nerogar On Windows