FlashPack
## What does this PR do?
Adds support for FlashPack.
FlashPack could also be used as a weight format only (see `load_flashpack_checkpoint`); keeping only the weight-format code would be a cleaner integration. Model loading is indeed faster[1], but part of the performance difference seems to come down to the complexity of the existing `from_pretrained` code path: for example, I noticed that `_caching_allocator_warmup` slows things down, and `empty_device_cache` is called in `_load_pretrained_model` whereas FlashPack's code doesn't call it, so the empty-cache time is excluded from FlashPack's benchmark results.
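For illustration, a minimal sketch of what the weight-format-only path could look like; the import location and the `(model, path, device)` signature of `load_flashpack_checkpoint` are assumptions made for this sketch, not the PR's final API:

```python
import torch
from diffusers import UNet2DConditionModel

# Assumed import location for the helper mentioned above; may differ in the PR.
from diffusers.models.model_loading_utils import load_flashpack_checkpoint

# Build the module skeleton without materializing any weights.
config = UNet2DConditionModel.load_config(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
with torch.device("meta"):
    unet = UNet2DConditionModel.from_config(config)

# Assumed signature: stream the packed tensors straight into the module.
load_flashpack_checkpoint(unet, "unet/model.flashpack", device="cuda")
```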
- [x] Pipeline
- [x] Cleanup (`is_flashpack_available` check, logging, etc.)
- [x] FlashPack arguments (`silent`, `num_streams`, `use_distributed_loading`, etc.; see the sketch below)
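For the arguments item, a hypothetical sketch of how these options might be forwarded through `from_pretrained`; the `flashpack_kwargs` container is illustrative only, and the option semantics are inferred from their names:

```python
import torch
from diffusers.models import AutoModel

# Illustrative only: the exact kwarg plumbing is up to this PR.
model = AutoModel.from_pretrained(
    "path/to/flashpack_dir",
    torch_dtype=torch.bfloat16,
    use_flashpack=True,
    flashpack_kwargs={  # hypothetical container for FlashPack's own options
        "silent": True,  # suppress FlashPack's progress output (assumed)
        "num_streams": 4,  # CUDA streams used for copying weights (assumed)
        "use_distributed_loading": False,  # per-rank loading (assumed)
    },
)
```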
## Benchmark
Changes from FlashPack's version:
- Use Diffusers
- Remove `transformers`-related code
- Add SD v1.5
```python
import csv
import gc
import os
import shutil
import tempfile
import time

import torch
from diffusers.models import AutoModel as DiffusersAutoModel
from huggingface_hub import snapshot_download


def test_model(
    repo_id: str,
    subfolder: str | None = None,
    accelerate_device: str | torch.device = "cuda",
    flashpack_device: str | torch.device = "cuda",
    dtype: torch.dtype | None = None,
    allow_pattern: str = "*",
) -> tuple[float, float, int]:
    """
    Benchmark a model from a repository: load it once with plain
    from_pretrained and once via FlashPack, returning
    (accelerate_time, flashpack_time, total_bytes).
    """
    allow_patterns = [f"{subfolder}/{allow_pattern}"]
    if allow_pattern != "*":
        allow_patterns = [f"{subfolder}/{allow_pattern}", f"{subfolder}/config.json"]
    repo_dir = snapshot_download(
        repo_id, allow_patterns=None if subfolder is None else allow_patterns
    )
    model_dir = repo_dir if subfolder is None else os.path.join(repo_dir, subfolder)
    saved_flashpack_path = os.path.join(model_dir, "model.flashpack")
    saved_flashpack_config_path = os.path.join(model_dir, "flashpack_config.json")

    with tempfile.TemporaryDirectory() as tmpdir:
        # Make a new model directory with the model in it so it isn't cached
        temp_model_dir = os.path.join(tmpdir, "model")
        flashpack_dir = os.path.join(tmpdir, "flashpack")
        os.makedirs(flashpack_dir, exist_ok=True)
        print("Copying model to temporary directory")
        shutil.copytree(model_dir, temp_model_dir)

        # Load from the temporary model directory
        print("Loading model from temporary directory using from_pretrained")
        start_time = time.time()
        model = DiffusersAutoModel.from_pretrained(
            temp_model_dir,
            torch_dtype=dtype,
            device_map={"": accelerate_device},
            variant="fp16" if allow_pattern != "*" else None,
        )
        end_time = time.time()
        accelerate_time = end_time - start_time
        print(f"Time taken with from_pretrained: {accelerate_time} seconds")

        # Reuse a previously packed checkpoint if one exists, otherwise pack now
        if os.path.exists(saved_flashpack_path) and os.path.exists(
            saved_flashpack_config_path
        ):
            print("Copying flashpack to temporary directory")
            shutil.copy(
                saved_flashpack_path, os.path.join(flashpack_dir, "model.flashpack")
            )
            shutil.copy(
                saved_flashpack_config_path, os.path.join(flashpack_dir, "config.json")
            )
        else:
            print("Packing model to flashpack")
            pack_start_time = time.time()
            model.save_pretrained(
                flashpack_dir,
                use_flashpack=True,
            )
            pack_end_time = time.time()
            print(
                f"Time taken with save_pretrained (use_flashpack=True): {pack_end_time - pack_start_time} seconds"
            )
            # Copy back to the original model directory
            shutil.copy(
                os.path.join(flashpack_dir, "model.flashpack"), saved_flashpack_path
            )
            shutil.copy(
                os.path.join(flashpack_dir, "config.json"), saved_flashpack_config_path
            )

        del model
        sync_and_flush()

        print("Loading model from flashpack directory using from_pretrained")
        flashpack_start_time = time.time()
        flashpack_model = DiffusersAutoModel.from_pretrained(
            flashpack_dir,
            torch_dtype=dtype,
            device_map={"": flashpack_device},
            use_flashpack=True,
        )
        flashpack_end_time = time.time()
        flashpack_time = flashpack_end_time - flashpack_start_time
        print(f"Time taken with from_pretrained (use_flashpack=True): {flashpack_time} seconds")

        # Estimate the checkpoint size from parameter count and dtype width
        total_numel = 0
        for param in flashpack_model.parameters():
            total_numel += param.numel()
        total_bytes = total_numel * dtype.itemsize

        del flashpack_model
        sync_and_flush()
        return accelerate_time, flashpack_time, total_bytes


def test_wan_small_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def test_wan_large_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="Wan-AI/Wan2.1-T2V-14B-Diffusers",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def test_stable_diffusion_v1_5() -> tuple[float, float, int]:
    return test_model(
        repo_id="stable-diffusion-v1-5/stable-diffusion-v1-5",
        subfolder="unet",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.float16,
        allow_pattern="*.fp16.safetensors",
    )


def test_flux_transformer() -> tuple[float, float, int]:
    return test_model(
        repo_id="black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        accelerate_device="cuda:0" if torch.cuda.is_available() else "cpu",
        flashpack_device="cuda:0" if torch.cuda.is_available() else "cpu",
        dtype=torch.bfloat16,
    )


def print_test_result(
    model_name: str,
    accelerate_time: float,
    flashpack_time: float,
    total_bytes: int,
) -> None:
    print(f"{model_name}: Accelerate time: {accelerate_time} seconds")
    print(f"{model_name}: Flashpack time: {flashpack_time} seconds")
    # Throughput in decimal gigabytes (1000**3 bytes) per second
    accelerate_gbps = (total_bytes / 1000**3) / accelerate_time
    flashpack_gbps = (total_bytes / 1000**3) / flashpack_time
    print(f"{model_name}: Accelerate GB/s: {accelerate_gbps} GB/s")
    print(f"{model_name}: Flashpack GB/s: {flashpack_gbps} GB/s")


def sync_and_flush() -> None:
    # Release cached CUDA blocks and Python garbage, then flush the OS page
    # cache (root only) so every run reads the checkpoint cold from disk
    torch.cuda.empty_cache()
    gc.collect()
    os.system("sync")
    if os.geteuid() == 0:
        os.system("echo 3 | tee /proc/sys/vm/drop_caches")


if __name__ == "__main__":
    # Only write the CSV header when the file is first created
    write_header = not os.path.exists("benchmark_results.csv")
    with open("benchmark_results.csv", "a") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["model", "accelerate_time", "flashpack_time", "total_bytes"])
        # Run each benchmark 10 times to smooth out cache and allocator noise
        for _ in range(10):
            for test_model_name, test_func in [
                ("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", test_wan_small_transformer),
                ("stable-diffusion-v1-5/stable-diffusion-v1-5", test_stable_diffusion_v1_5),
                # ("black-forest-labs/FLUX.1-dev", test_flux_transformer),
            ]:
                accelerate_time, flashpack_time, total_bytes = test_func()
                writer.writerow(
                    [test_model_name, accelerate_time, flashpack_time, total_bytes]
                )
                print_test_result(
                    test_model_name, accelerate_time, flashpack_time, total_bytes
                )
```
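For reference, a summary along the lines below can be derived from `benchmark_results.csv` with a small aggregation sketch (assuming the CSV layout written by the script above; speedup is `accelerate_time / flashpack_time`):

```python
import csv
from collections import defaultdict

# Average the repeated runs per model (CSV layout as written by the benchmark).
sums = defaultdict(lambda: [0.0, 0.0, 0])
with open("benchmark_results.csv") as f:
    for row in csv.DictReader(f):
        entry = sums[row["model"]]
        entry[0] += float(row["accelerate_time"])
        entry[1] += float(row["flashpack_time"])
        entry[2] += 1

for model, (accel, flash, n) in sums.items():
    print(
        f"{model}: accel {accel / n:.3f}s, flash {flash / n:.3f}s, "
        f"speedup {accel / flash:.2f}x"
    )
```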
### Summary statistics

| Model | Size (GB) | Accel (s) | Flash (s) | Speedup |
| --- | --- | --- | --- | --- |
| Stable Diffusion v1.5 (fp16) | 1.60 | 0.250 | 0.263 | 0.95x |
| Wan2.1 1.3B DiT | 2.64 | 1.315 | 0.600 | 2.19x |

| Model | Accel GB/s | Flash GB/s |
| --- | --- | --- |
| Stable Diffusion v1.5 (fp16) | 6.41 | 6.13 |
| Wan2.1 1.3B DiT | 2.33 | 4.44 |
[1] For bfloat16; with float16, the existing code appears to be faster.
## Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.