sd-scripts icon indicating copy to clipboard operation
sd-scripts copied to clipboard

Collecting image sizes is very slow.

Open markrmiller opened this issue 6 months ago • 7 comments

When I train with 4000 images, it takes forever to start training because just collecting the image sizes from the npz filenames takes 20 minutes from the drive where I have the data.

This replacement code takes — I didn't measure precisely — maybe 10 seconds.

import glob
import logging
import multiprocessing as mp
import os
import re
from functools import partial
from typing import Dict, List, Optional, Tuple

from tqdm import tqdm

# Compiled once at import time so per-filename parsing does no recompilation.
# Matches a trailing "_{W}x{H}.npz" or "_{W}x{H}_flux.npz" and captures W and H.
size_pattern = re.compile(r'_(\d+)x(\d+)(?:_flux\.npz|\.npz)$')
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"  # suffix identifying FLUX latent cache files


def parse_size_from_filename(filename: str) -> Tuple[Optional[int], Optional[int]]:
    """Extract (width, height) from a cache filename such as ``img_512x768_flux.npz``.

    Args:
        filename: bare cache filename (not a full path) carrying an
            embedded ``_WxH`` size before its ``.npz`` / ``_flux.npz`` suffix.

    Returns:
        ``(width, height)`` as ints, or ``(None, None)`` when the filename
        does not match the expected pattern (a warning is logged).
    """
    match = size_pattern.search(filename)
    if match:
        return int(match.group(1)), int(match.group(2))
    # Fix: `logger` was undefined (NameError on this path) and the message
    # printed the literal "(unknown)" instead of the offending filename.
    # Lazy %-args avoid formatting cost when the warning level is disabled.
    logging.getLogger(__name__).warning("Failed to parse size from filename: %s", filename)
    return None, None


# Matches the "_{W}x{H}_flux.npz" tail of a FLUX latent cache filename.
# Hoisted to module level so the per-file loop below does no regex
# (re-)lookup or parsing — this runs once per file across thousands of files.
_size_suffix_pattern = re.compile(r'_\d+x\d+_flux\.npz$')


def get_all_cache_files(img_paths: List[str]) -> Dict[str, str]:
    """Index FLUX latent cache files found next to the given images.

    Args:
        img_paths: image file paths; only their containing directories are
            used, and each directory is listed exactly once even when many
            images share it.

    Returns:
        Mapping from extension-less image base path (e.g. ``/data/img``) to
        the bare cache filename (e.g. ``img_512x768_flux.npz``).
    """
    cache_files: Dict[str, str] = {}
    base_dirs = set(os.path.dirname(path) for path in img_paths)

    for base_dir in base_dirs:
        for file in os.listdir(base_dir):
            if file.endswith(FLUX_LATENTS_NPZ_SUFFIX):
                # Strip the size+suffix so the key matches the image's base path.
                base_name = _size_suffix_pattern.sub('', file)
                cache_files[os.path.join(base_dir, base_name)] = file

    return cache_files


def process_batch(batch: List[str], cache_files: Dict[str, str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """Resolve the cached (width, height) for every image path in *batch*.

    Args:
        batch: image file paths to look up.
        cache_files: mapping from extension-less image base path to the bare
            cache filename, as produced by ``get_all_cache_files``.

    Returns:
        One ``(width, height)`` tuple per input path, in input order.

    Raises:
        FileNotFoundError: when an image has no matching cache file —
            failing loudly here beats silently training without latents.
    """
    results = []
    for img_path in batch:
        # cache_files is keyed by the image path without its extension.
        base_path = os.path.splitext(img_path)[0]
        if base_path not in cache_files:
            raise FileNotFoundError(f"Cache file not found for {img_path}")
        results.append(parse_size_from_filename(cache_files[base_path]))
    return results


def get_image_sizes_from_cache_files(img_paths: List[str]) -> List[Tuple[Optional[int], Optional[int]]]:
    """Look up each image's (width, height) from its latent-cache filename.

    Scans the cache directories once up front, then parses the sizes out of
    the cached filenames in parallel across all CPU cores — avoiding a slow
    per-image stat/open of every ``.npz`` file.

    Args:
        img_paths: image file paths whose cached sizes are wanted.

    Returns:
        One ``(width, height)`` tuple per input path, in input order.
    """
    cache_files = get_all_cache_files(img_paths)

    worker_count = mp.cpu_count()
    # Several batches per worker keeps the pool evenly loaded.
    per_batch = max(1, len(img_paths) // (worker_count * 4))
    batches = [img_paths[start:start + per_batch]
               for start in range(0, len(img_paths), per_batch)]

    sizes: List[Tuple[Optional[int], Optional[int]]] = []
    with mp.Pool(worker_count) as pool:
        worker = partial(process_batch, cache_files=cache_files)
        progress = tqdm(pool.imap(worker, batches),
                        total=len(batches),
                        desc="Processing image batches")
        for batch_sizes in progress:
            sizes.extend(batch_sizes)

    return sizes

markrmiller avatar Aug 22 '24 01:08 markrmiller