Running FLUX ControlNet training with ZeRO-3 fails: 'weight' must be 2-D

Open • alien-0119 opened this issue 10 months ago • 14 comments

Describe the bug

I am attempting to use ZeRO-3 for FLUX ControlNet training on 8 GPUs, following the guidance in the README. The following error occurred:

[rank0]: RuntimeError: 'weight' must be 2-D

Reproduction

accelerate config:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 8
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

training command:

accelerate launch --config_file "./accelerate_config_zero3.yaml" train_controlnet_flux_zero3.py \
  --pretrained_model_name_or_path=/srv/mindone/wty/flux.1-dev/ \
  --jsonl_for_train=/srv/mindone/wty/diffusers/examples/controlnet/train_1000.jsonl \
  --conditioning_image_column=conditioning_image \
  --image_column=image \
  --caption_column=text \
  --output_dir=/srv/mindone/wty/diffusers/examples/controlnet/single_layer \
  --mixed_precision="bf16" \
  --resolution=512 \
  --learning_rate=1e-5 \
  --max_train_steps=100 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=8 \
  --num_double_layers=4 \
  --num_single_layers=0 \
  --seed=42 \
  --gradient_checkpointing \
  --cache_dir=/srv/mindone/wty/diffusers/examples/controlnet/cache \
  --dataloader_num_workers=8 \
  --resume_from_checkpoint="latest"
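For reference, the settings above work out to the following effective batch size. This is only an arithmetic sketch using the values from the config and command; it assumes the usual data-parallel definition (per-GPU batch × accumulation steps × number of processes) and takes the dataset size of 1000 from `train_1000.jsonl` and the `Map` progress bar. It says nothing DeepSpeed-specific.

```python
# Values copied from the accelerate config and training command above.
num_processes = 8                # accelerate config: num_processes
train_batch_size = 1             # --train_batch_size (per GPU)
gradient_accumulation_steps = 8  # --gradient_accumulation_steps (also set in deepspeed_config)
max_train_steps = 100            # --max_train_steps
num_examples = 1000              # rows shown by the Map progress bar (0/1000)

# Samples consumed by one optimizer step across all ranks.
effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes

# Total samples consumed by the run and the approximate number of passes over the data.
total_samples = effective_batch_size * max_train_steps
approx_epochs = total_samples / num_examples

print(effective_batch_size, total_samples, approx_epochs)
```

With these values each optimizer step sees 64 samples, so 100 steps cover the 1000-example dataset roughly 6.4 times.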

Logs

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/srv/mindone/wty/diffusers/examples/controlnet/train_controlnet_flux_zero3.py", line 1481, in <module>
[rank0]:     main(args)
[rank0]:   File "/srv/mindone/wty/diffusers/examples/controlnet/train_controlnet_flux_zero3.py", line 1182, in main
[rank0]:     train_dataset = train_dataset.map(
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 562, in wrapper
[rank0]:     out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3079, in map
[rank0]:     for rank, done, content in Dataset._map_single(**dataset_kwargs):
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3519, in _map_single
[rank0]:     for i, batch in iter_outputs(shard_iterable):
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3469, in iter_outputs
[rank0]:     yield i, apply_function(example, i, offset=offset)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 3392, in apply_function
[rank0]:     processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
[rank0]:   File "/srv/mindone/wty/diffusers/examples/controlnet/train_controlnet_flux_zero3.py", line 1094, in compute_embeddings
[rank0]:     prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
[rank0]:   File "/srv/mindone/wty/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 396, in encode_prompt
[rank0]:     pooled_prompt_embeds = self._get_clip_prompt_embeds(
[rank0]:   File "/srv/mindone/wty/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 328, in _get_clip_prompt_embeds
[rank0]:     prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 1056, in forward
[rank0]:     return self.text_model(
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 947, in forward
[rank0]:     hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 292, in forward
[rank0]:     inputs_embeds = self.token_embedding(input_ids)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 190, in forward
[rank0]:     return F.embedding(
[rank0]:   File "/home/miniconda3/envs/flux-perf/lib/python3.9/site-packages/torch/nn/functional.py", line 2551, in embedding
[rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank0]: RuntimeError: 'weight' must be 2-D
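A likely explanation, not confirmed in this thread: with `zero3_init_flag: true`, DeepSpeed's `zero.Init` partitions parameters while models are constructed, including the frozen CLIP text encoder. Each rank then keeps only a flat 1-D slice of every parameter, and the module's own `weight` is left as an empty placeholder until the parameters are gathered, which normally happens only for models wrapped by the DeepSpeed engine. Calling `text_encoder` directly inside `compute_embeddings` would therefore hand that placeholder to `torch.embedding`. The pure-Python sketch below only mimics this flatten / shard / gather pattern to show why the shape check fails; `partition_flat`, `all_gather` and `embedding_lookup` are illustrative names, not DeepSpeed or PyTorch APIs.

```python
from math import prod


def partition_flat(weight_rows, world_size):
    """Flatten a 2-D weight (list of rows) and split it into equal 1-D shards, one per rank."""
    flat = [x for row in weight_rows for x in row]
    shard_len = -(-len(flat) // world_size)  # ceiling division
    flat += [0.0] * (shard_len * world_size - len(flat))  # pad so every rank gets the same length
    return [flat[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards, shape):
    """Concatenate all shards, drop the padding, and restore the original 2-D shape."""
    flat = [x for shard in shards for x in shard][: prod(shape)]
    rows, cols = shape
    return [flat[i * cols:(i + 1) * cols] for i in range(rows)]


def embedding_lookup(weight, token_ids):
    """Mimic the shape check in torch.embedding: the table must be 2-D (a list of rows)."""
    if not weight or not all(isinstance(row, list) for row in weight):
        raise RuntimeError("'weight' must be 2-D")
    return [weight[t] for t in token_ids]


vocab = [[float(i * 3 + j) for j in range(3)] for i in range(4)]  # 4 tokens x 3 dims
shards = partition_flat(vocab, world_size=8)

placeholder = []  # stands in for the empty 1-D tensor a partitioned module exposes
try:
    embedding_lookup(placeholder, [1, 2])
except RuntimeError as err:
    print(err)  # 'weight' must be 2-D

# After gathering the shards, the lookup works again.
print(embedding_lookup(all_gather(shards, (4, 3)), [1, 2]))  # [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
```

Passing either the empty placeholder or a single rank's flat shard fails the same 2-D check; only the gathered table works.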

System Info

  • 🤗 Diffusers version: 0.33.0.dev0 (HEAD on #10945)
  • Platform: Linux-4.15.0-156-generic-x86_64-with-glibc2.27
  • Running on Google Colab?: No
  • Python version: 3.9.21
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.1
  • Transformers version: 4.49.0
  • Accelerate version: 1.4.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: 8 x NVIDIA A100-SXM4-80GB, 81920 MiB

Who can help?

@yiyixuxu @sayakpaul

alien-0119, Mar 05 '25

I have the same issue.

haofanwang, Mar 15 '25

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], Apr 17 '25

train_controlnet_flux_zero3.py -- is this a modified script?

sayakpaul, Apr 17 '25

Yes, but I only modified it following the guidance in the README.

alien-0119, Apr 18 '25

Can I see the modifications you did?

sayakpaul, Apr 18 '25

Yes, this is my script:

train_controlnet_flux.zip

alien-0119, Apr 21 '25

We cannot control changes that lie outside of our scripts. Can you try the official script and let me know the errors you face when launching it with ZeRO-3?

sayakpaul, Apr 21 '25

I didn't change the script on my own initiative; I only changed it according to the guidance written in the README:

https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_flux.md#2precompute-all-inputs-latent-embeddings
https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_flux.md#3redefine-the-behavior-of-getting-batchsize

alien-0119, Apr 21 '25

Cc: @PromeAIpro

sayakpaul, Apr 21 '25

follow

Johnson-yue, Apr 25 '25

Using DeepSpeed was an experimental attempt we tested while writing this script. It was not included in the code because it is not guaranteed to work on every GPU device, so we wrote some tips in the README for reference. We are open to further contributions if they are tested and work well. @alien-0119 @Johnson-yue

PromeAIpro, Apr 28 '25
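Regarding the experimental DeepSpeed support mentioned above: one workaround used elsewhere in diffusers (`examples/text_to_image/train_text_to_image.py`) is to load the frozen models with `zero.Init` disabled, so only the trained model is partitioned and the frozen encoders keep full 2-D weights. Whether this resolves the FLUX ControlNet case is untested in this thread, and the trade-off is that each GPU holds a full copy of the frozen models. The helper below mirrors the one in that script; the `ImportError` guard and the `load_frozen_models` wrapper are additions for this sketch, and the loader callables are hypothetical placeholders for whatever `from_pretrained` calls the training script makes.

```python
from contextlib import ExitStack


def deepspeed_zero_init_disabled_context_manager():
    """Return [context manager that disables zero.Init] under DeepSpeed, else an empty list.

    Mirrors the helper in diffusers' examples/text_to_image/train_text_to_image.py.
    The ImportError guard is an addition so the sketch runs without accelerate installed.
    """
    try:
        import accelerate
        from accelerate.state import AcceleratorState
    except ImportError:
        return []
    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    if deepspeed_plugin is None:
        return []
    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]


def load_frozen_models(loaders):
    """Hypothetical wrapper: call each loader while zero.Init is disabled (if DeepSpeed is active)."""
    with ExitStack() as stack:
        for cm in deepspeed_zero_init_disabled_context_manager():
            stack.enter_context(cm)
        # e.g. loaders for the CLIP/T5 text encoders and the VAE; the trained model is loaded outside.
        return [load() for load in loaders]
```

Outside an initialized `Accelerator` with a DeepSpeed plugin, the helper returns an empty list and the loaders run unchanged, so the same code path works with and without ZeRO-3.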

I gave up on using DeepSpeed for training because I couldn't get good results.

Johnson-yue, Apr 28 '25

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot], May 22 '25

follow

luchaoqi, Jun 24 '25

Have you solved this error?

Daryu-Fan, Sep 01 '25

[email protected] That's my email address; please feel free to contact me if convenient.

Daryu-Fan, Sep 01 '25