sd-scripts
Collecting image sizes is very slow.
When I train with 4000 images, it takes forever to start training: just collecting the image sizes from the npz filenames takes 20 minutes on the drive where my data is stored.
The code below takes... I didn't measure exactly, but maybe 10 seconds?
import logging
import multiprocessing as mp
import os
import re
from functools import partial
from typing import Dict, List, Optional, Tuple

from tqdm import tqdm

logger = logging.getLogger(__name__)

# Compile the regex pattern once; it matches the trailing size in cache
# filenames such as "image_1024x768_flux.npz" or "image_1024x768.npz".
size_pattern = re.compile(r"_(\d+)x(\d+)(?:_flux\.npz|\.npz)$")

FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"


def parse_size_from_filename(filename: str) -> Tuple[Optional[int], Optional[int]]:
    match = size_pattern.search(filename)
    if match:
        return int(match.group(1)), int(match.group(2))
    logger.warning(f"Failed to parse size from filename: {filename}")
    return None, None


def get_all_cache_files(img_paths: List[str]) -> Dict[str, str]:
    # List each directory once instead of touching every file individually.
    cache_files = {}
    base_dirs = set(os.path.dirname(path) for path in img_paths)
    for base_dir in base_dirs:
        for file in os.listdir(base_dir):
            if file.endswith(FLUX_LATENTS_NPZ_SUFFIX):
                # Remove the size and suffix to create the key
                base_name = re.sub(r"_\d+x\d+_flux\.npz$", "", file)
                cache_files[os.path.join(base_dir, base_name)] = file
    return cache_files


def process_batch(batch: List[str], cache_files: Dict[str, str]) -> List[Tuple[Optional[int], Optional[int]]]:
    results = []
    for img_path in batch:
        base_path = os.path.splitext(img_path)[0]
        if base_path in cache_files:
            results.append(parse_size_from_filename(cache_files[base_path]))
        else:
            # results.append((None, None))
            raise FileNotFoundError(f"Cache file not found for {img_path}")
    return results


def get_image_sizes_from_cache_files(img_paths: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    cache_files = get_all_cache_files(img_paths)
    num_cores = mp.cpu_count()
    batch_size = max(1, len(img_paths) // (num_cores * 4))  # Adjust batch size for better load balancing
    batches = [img_paths[i:i + batch_size] for i in range(0, len(img_paths), batch_size)]
    with mp.Pool(num_cores) as pool:
        process_func = partial(process_batch, cache_files=cache_files)
        results = list(tqdm(
            pool.imap(process_func, batches),
            total=len(batches),
            desc="Processing image batches",
        ))
    # Flatten the per-batch results into one list of (width, height) tuples
    return [size for batch in results for size in batch]
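For anyone who wants to try this standalone, here is a minimal sketch of how I call it. The "train_data" directory and the .png glob are placeholders for wherever your images and *_flux.npz caches live, and the __main__ guard is needed because multiprocessing spawns fresh worker processes on Windows/macOS:

import glob

if __name__ == "__main__":
    # Hypothetical paths; point this at your own dataset directory.
    img_paths = sorted(glob.glob("train_data/**/*.png", recursive=True))
    sizes = get_image_sizes_from_cache_files(img_paths)
    print(f"Collected {len(sizes)} sizes, e.g. {sizes[:3]}")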