datasets
Training controlnet_sdxl with the bf16 data type fails with an unsupported-dtype error in datasets
Describe the bug
Traceback (most recent call last):
  File "train_controlnet_sdxl.py", line 1252, in <module>
    main(args)
  File "train_controlnet_sdxl.py", line 1013, in main
    train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 592, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 557, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3093, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3489, in _map_single
    writer.write_batch(batch)
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_writer.py", line 557, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 248, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 113, in pyarrow.lib._handle_arrow_array_protocol
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/arrow_writer.py", line 191, in __arrow_array__
    out = pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/features/features.py", line 447, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/home/miniconda3/envs/mhh_df/lib/python3.8/site-packages/datasets/features/features.py", line 324, in _cast_to_python_objects
    for x in obj.detach().cpu().numpy()
TypeError: Got unsupported ScalarType BFloat16
Steps to reproduce the bug
Here is my training script. I train the model with diffusers and use the bf16 data type:
export MODEL_DIR="/home/mhh/sd_models/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="./control_net"
export VAE_NAME="/home/mhh/sd_models/sdxl-vae-fp16-fix"
accelerate launch train_controlnet_sdxl.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--pretrained_vae_model_name_or_path=$VAE_NAME \
--dataset_name=/home/mhh/sd_datasets/fusing/fill50k \
--mixed_precision="bf16" \
--resolution=1024 \
--learning_rate=1e-5 \
--max_train_steps=200 \
--validation_image "/home/mhh/sd_datasets/controlnet_image/conditioning_image_1.png" "/home/mhh/sd_datasets/controlnet_image/conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--validation_steps=50 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--seed=42
Expected behavior
I expected bf16 to work as well. When I changed the data type to fp16, the script ran without this error.
Environment info
datasets 2.16.1
numpy 1.24.4
I hit the same error and got past it by casting the tensor to float on the line shown at the bottom of the traceback (datasets/features/features.py), so
for x in obj.detach().cpu().numpy()
becomes
for x in obj.detach().to(torch.float).cpu().numpy()
I got the idea from this PR, where someone was facing a similar issue (in a different repository). As far as I can tell, NumPy has no native bfloat16 dtype, which is why .numpy() fails on a bf16 tensor.
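For anyone who would rather not patch the installed datasets package, the same cast can probably be done on the user side, before the tensors are handed back to datasets. Below is a minimal sketch of that idea. The helper name to_numpy_safe is hypothetical (it is not part of datasets or diffusers), and the sketch assumes only that PyTorch is installed.

```python
import torch


def to_numpy_safe(t: torch.Tensor):
    """Hypothetical helper: upcast bfloat16 to float32 before converting to NumPy.

    NumPy has no native bfloat16 dtype, so Tensor.numpy() cannot be called
    on a bf16 tensor directly. Other dtypes are passed through unchanged.
    """
    t = t.detach().cpu()
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)
    return t.numpy()


# These values are exactly representable in bfloat16, so the upcast is lossless here.
bf16 = torch.tensor([1.0, 2.5, -3.0], dtype=torch.bfloat16)
arr = to_numpy_safe(bf16)
print(arr.dtype, arr.tolist())
```

In the training script, the equivalent change would presumably be to cast whatever tensors compute_embeddings_fn returns with .to(torch.float32) before returning them from the function passed to train_dataset.map. That is an assumption about the script's structure and has not been tested against train_controlnet_sdxl.py. The cached embeddings then take twice the space of bf16, and they would need to be cast back to the training dtype when they are loaded.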