sd-scripts
Cache latents optionally
It would be great to make caching the image latents optional. It takes 12 hours to convert 100k images with left/right flip enabled on a 3090 :/ (2.6 it/s). If I'm doing quick test runs of 1 or 2 epochs, this is not worth the time currently.
If that is too much work, something as simple as improving the speed would be fine as well, as my GPU is never at 100% while processing.
I think your case is for `prepare_buckets_latents.py` (`train_db.py` has an option to enable/disable the caching).
You can use a larger batch size for the script; it might improve the speed. But the processing is CPU bound, so the improvement might be small.
I have a plan to refactor the Dataset for `train_db.py` and `fine_tune.py`, and the caching (npz creation) will become optional in that refactoring. Please wait a moment.
I currently have the batch size as high as it can go without OOM. And yes, the native trainer is the one I am talking about.
Good to hear, thanks for the reply.
For now I got a 2-3x speed-up when caching latents by adding a torch Dataset/DataLoader that loads the image files asynchronously. You could go even further and pre-encode them, but I wanted to change as little code as possible for now.
## prepare_buckets_latents.py

```python
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torchvision.transforms.functional as TF

...

class CustomDataset(Dataset):
    def __init__(self, image_paths):
        self.images = image_paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        try:
            image = Image.open(img_path).convert("RGB")
            # convert to tensor temporarily so the dataloader will accept it
            tensor_pil = TF.pil_to_tensor(image)
        except Exception as e:
            print("Could not load image path:", img_path, ", error:", e)
            return None
        return (tensor_pil, img_path)


def collate_fn_remove_corrupted(batch):
    """Collate function that allows removing corrupted examples in the
    dataloader. It expects the dataset to return None when that occurs;
    the Nones in the batch are filtered out.
    """
    return [sample for sample in batch if sample is not None]

...

BATCH_SIZE = 8
dataset = CustomDataset(image_paths)
dataloader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        num_workers=16,
                        collate_fn=collate_fn_remove_corrupted)
total_batches = len(dataset) // BATCH_SIZE
print("total_batches", total_batches)

for idx, batch in enumerate(tqdm(dataloader, smoothing=0.1)):
    for i in range(len(batch)):
        # convert back to a PIL image
        image = TF.to_pil_image(batch[i][0])
        image_path = batch[i][1]
        ...
        # this might be off by one when the dataset size is not a
        # multiple of BATCH_SIZE, but I don't mind
        is_last = idx >= total_batches - 1 and i >= len(batch) - 1
```
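The None-filtering collate pattern above can be checked in isolation without any image files. Here is a minimal sketch; `FlakyDataset` and its failure pattern are made up purely to simulate corrupted images:

```python
from torch.utils.data import Dataset, DataLoader


def collate_fn_remove_corrupted(batch):
    # drop samples the dataset returned as None (failed loads)
    return [sample for sample in batch if sample is not None]


class FlakyDataset(Dataset):
    """Hypothetical dataset: every third item simulates a corrupted image."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if idx % 3 == 2:
            return None  # simulate a failed Image.open()
        return (idx, f"img_{idx}.png")


loader = DataLoader(FlakyDataset(8), batch_size=4, num_workers=0,
                    collate_fn=collate_fn_remove_corrupted)
batches = list(loader)
# indices 2 and 5 are dropped, so batches shrink instead of raising
print([len(b) for b in batches])  # → [3, 3]
```

The upside of this design is that one unreadable file only shrinks its batch rather than crashing a multi-hour preprocessing run.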
This feature is supported with #49.
Thank you for the code for faster loading! I will update the preprocessing scripts in the near future, and I will refer to this :)