diffusers
train_controlnet.py broken local data loading
Describe the bug
The script examples/controlnet/train_controlnet.py is designed to take either a HF dataset with --dataset_name (as illustrated here) or local data with --train_data_dir, but the latter case fails out-of-the-box.
Reproduction
Run the script with the flags --train_data_dir, --image_column, --conditioning_image_column, and --caption_column referring to a dataset in standard HF format, with a metadata.csv file in the data dir. This fails for two reasons:
- metadata.csv must have a column named file_name containing the filenames of the target image files, but the flag must be set to --image_column=image and not --image_column=file_name, which is an undocumented bug.
- The script fails to load the images from the conditioning image column and instead treats the strings as PIL Images, causing an error. I had to change the current line 702 from:
  conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
  to:
  conditioning_images = [Image.open(image).convert("RGB") for image in examples[conditioning_image_column]]
  for the script to run.
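A slightly more defensive variant of the line-702 change could accept both cases, so that Hub datasets (which yield PIL Images) and local metadata.csv datasets (which yield strings) go through the same code path. This is only a sketch: the helper name load_rgb and its base_dir parameter are hypothetical and not part of the script, and it assumes the strings in the conditioning column are file paths that are either absolute or relative to the data dir.

```python
from pathlib import Path

from PIL import Image


def load_rgb(item, base_dir="."):
    """Return an RGB PIL image from either a PIL image or a path string.

    Hypothetical helper, not part of train_controlnet.py.
    """
    if isinstance(item, Image.Image):
        # Already decoded (e.g. a Hub dataset with an Image feature).
        return item.convert("RGB")
    # Assumption: anything else is a file path. An absolute path is used
    # as-is; a relative path is resolved against base_dir.
    with Image.open(Path(base_dir) / item) as img:
        # convert() loads the pixel data and returns a new image,
        # so the file can be closed afterwards.
        return img.convert("RGB")
```

In the script this might be used as conditioning_images = [load_rgb(x, args.train_data_dir) for x in examples[conditioning_image_column]], assuming the paths in metadata.csv are relative to --train_data_dir.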
Logs
No response
System Info
- diffusers version: 0.27.0.dev0
- Platform: Linux-5.15.0-82-generic-x86_64-with-glibc2.35
- Python version: 3.11.4
- PyTorch version (GPU?): 2.2.1+cu121 (True)
- Huggingface_hub version: 0.21.3
- Transformers version: 4.38.0.dev0
- Accelerate version: 0.25.0
- xFormers version: 0.0.24
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
@sayakpaul @yiyixuxu @DN6
Feel free to send over a PR to fix :)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Same issue. But after I modified the code to read the images, I get a channel dimension error like this (batch_size=2):
RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[2, 1, 512, 512] to have 3 channels
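Reading the error: the weight shape [16, 3, 3, 3] belongs to a convolution that expects 3 input channels, while the input [2, 1, 512, 512] is a batch of 2 images with a single channel. One plausible cause (an assumption, not confirmed in this thread) is that the conditioning images on disk are grayscale and were opened without .convert("RGB"). A quick way to check the channel count with Pillow, using a synthetic grayscale image as a stand-in for a real conditioning image:

```python
from PIL import Image


def channel_count(img):
    """Number of bands (channels) in a PIL image."""
    return len(img.getbands())


# Stand-in for a single-channel conditioning image (mode "L").
gray = Image.new("L", (512, 512))
# Converting to RGB replicates the single band into three.
rgb = gray.convert("RGB")

print(channel_count(gray), channel_count(rgb))  # prints "1 3"
```

If channel_count reports 1 for the files in the conditioning column, adding .convert("RGB") after Image.open(...) when loading them should give the 3 channels the convolution expects.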