k-diffusion
AttributeError in Oxford Flowers Demo: 'dict' object has no attribute 'convert'
Hi,
First, congratulations on this amazing repository; it's a great codebase for study.
I'm trying to run the Oxford Flowers demo but am hitting an error. I'm not sure whether it's just my setup or whether the demo is currently broken (perhaps because of a change on Hugging Face's end, or something else I'm unaware of):
```
images = [transform(image.convert(mode)) for image in examples[image_key]]
AttributeError: 'dict' object has no attribute 'convert'
```
Here's the full traceback:
```
Traceback (most recent call last):
  File "/notebooks/k-diffusion/train.py", line 525, in <module>
    main()
  File "/notebooks/k-diffusion/train.py", line 435, in main
    for batch in tqdm(train_dl, smoothing=0.1, disable=not accelerator.is_main_process):
  File "/usr/local/lib/python3.9/dist-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "/usr/local/lib/python3.9/dist-packages/accelerate/data_loader.py", line 451, in __iter__
    current_batch = next(dataloader_iter)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.9/dist-packages/torch/_utils.py", line 694, in reraise
    raise exception
AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.9/dist-packages/datasets/arrow_dataset.py", line 2165, in __getitem__
    return self._getitem(
  File "/usr/local/lib/python3.9/dist-packages/datasets/arrow_dataset.py", line 2150, in _getitem
    formatted_output = format_table(
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 532, in format_table
    return formatter(pa_table, query_type=query_type)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 281, in __call__
    return self.format_row(pa_table)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 387, in format_row
    formatted_batch = self.format_batch(pa_table)
  File "/usr/local/lib/python3.9/dist-packages/datasets/formatting/formatting.py", line 418, in format_batch
    return self.transform(batch)
  File "/notebooks/k-diffusion/k_diffusion/utils.py", line 39, in hf_datasets_augs_helper
    images = [transform(image.convert(mode)) for image in examples[image_key]]
  File "/notebooks/k-diffusion/k_diffusion/utils.py", line 39, in <listcomp>
    images = [transform(image.convert(mode)) for image in examples[image_key]]
AttributeError: 'dict' object has no attribute 'convert'
```
OK, I think I found the problem.
```python
from datasets import load_dataset

dataset = load_dataset("nelorth/oxford-flowers")
image = dataset["train"][0]["image"]
print(image.keys())
```

gives you:

```
dict_keys(['bytes', 'path'])
```
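To see what that dict holds without downloading the dataset, here is a minimal offline sketch. It builds a stand-in entry with the same two keys; the tiny red PNG generated with Pillow and `"path": None` are assumptions standing in for a real flower image. It then decodes the bytes back into a PIL Image, which is the step the original helper skips.

```python
import io

from PIL import Image

# Hypothetical stand-in for one undecoded entry: PNG-encoded bytes plus an
# optional source path, mirroring dict_keys(['bytes', 'path']) above.
original = Image.new("RGB", (4, 3), color=(255, 0, 0))
buffer = io.BytesIO()
original.save(buffer, format="PNG")
entry = {"bytes": buffer.getvalue(), "path": None}

# A plain dict has no .convert(); decode the bytes into a PIL Image first.
decoded = Image.open(io.BytesIO(entry["bytes"]))
rgb = decoded.convert("RGB")
print(type(entry).__name__, decoded.format, rgb.size, rgb.mode)  # dict PNG (4, 3) RGB
```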
This breaks `hf_datasets_augs_helper`, which expects `dataset["train"][0]["image"]` to be a PIL Image:
```python
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(image.convert(mode)) for image in examples[image_key]]
    return {image_key: images}
```
A fix for this dataset would be:
```python
import io

from PIL import Image


def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    # Each entry arrives as {'bytes': ..., 'path': ...}; decode the bytes into a PIL Image first
    images = [transform(Image.open(io.BytesIO(image["bytes"])).convert(mode))
              for image in examples[image_key]]
    return {image_key: images}
```
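A slightly more defensive variant, sketched under assumptions: it keeps working if `datasets` hands back already-decoded PIL Images, and it falls back to `path` when `bytes` is `None`. The helper name `to_pil` is hypothetical, and the path fallback assumes the path points to a readable local file. The usage below feeds it a fake batch shaped like the one the traceback shows `datasets` passing to the transform (`format_batch` → `self.transform(batch)`), i.e. a dict mapping column names to lists.

```python
import io

from PIL import Image


def to_pil(image):
    """Hypothetical helper: return a PIL Image for a decoded image or a {'bytes', 'path'} dict."""
    if isinstance(image, Image.Image):
        return image  # already decoded by datasets
    if image.get("bytes") is not None:
        return Image.open(io.BytesIO(image["bytes"]))
    return Image.open(image["path"])  # assumption: path is a readable local file


def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply passed in transforms for HuggingFace Datasets."""
    images = [transform(to_pil(image).convert(mode)) for image in examples[image_key]]
    return {image_key: images}


# Usage: a fake batch mixing an undecoded dict entry and an already-decoded image.
buffer = io.BytesIO()
Image.new("L", (8, 8)).save(buffer, format="PNG")
batch = {"image": [{"bytes": buffer.getvalue(), "path": None}, Image.new("RGB", (2, 2))]}
out = hf_datasets_augs_helper(batch, transform=lambda im: (im.mode, im.size), image_key="image")
print(out)  # {'image': [('RGB', (8, 8)), ('RGB', (2, 2))]}
```

The identity-like `transform` here only reports mode and size; in the real training script it would be the demo's augmentation pipeline.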