DeepSpeed training of FLUX.1 DreamBooth LoRA cannot save the model
Describe the bug
When I run the script train_dreambooth_lora_flux.py, it raises ValueError: unexpected save model: <class 'deepspeed.runtime.engine.DeepSpeedEngine'>. Is there a bug in save_model_hook?
Reproduction
accelerate launch train_dreambooth_lora_flux_custom.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="bedroom, YF_CN style" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="tensorboard" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_train_epochs=30 \
--validation_prompt="bedroom, YF_CN style" \
--validation_epochs=80 \
--checkpointing_steps=500 \
--seed="0" \
--gradient_checkpointing \
--use_8bit_adam \
--rank=4
Logs
No response
System Info
torch==2.3.1 accelerate==0.34.2 deepspeed==0.15.1+8ac42ed7 diffusers==0.31.0.dev0
default_config.yaml is as follows:
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Who can help?
@sayakpaul
Refer to https://github.com/huggingface/accelerate/issues/2787 to get an idea of the adjustments needed to make it work.
Using if isinstance(unwrap_model(model), type(unwrap_model(transformer))) in place of if isinstance(model, type(unwrap_model(transformer))) makes it possible to save the checkpoint. However, this script still has many errors, including but not limited to: being unable to load a LoRA trained with DeepSpeed (training with plain PyTorch works), enabling train_text_encoder, and initializing multiple models with accelerate. I also don't know why the index starts from 1 instead of 0, which raises a list index out of range error:
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
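My guess (a minimal sketch, assuming params_to_optimize holds one parameter group for the transformer and one for the single trained text encoder; I have not confirmed this against the script) is that index 2 simply does not exist:

# Hedged sketch, not the script's actual code: with --train_text_encoder the
# Flux script would build only two parameter groups (transformer + CLIP text encoder).
transformer_group = {"params": [], "lr": 1.0}
text_encoder_group = {"params": [], "lr": 1.0}
params_to_optimize = [transformer_group, text_encoder_group]

params_to_optimize[1]["lr"] = 1.0  # index 1: the only text encoder group
params_to_optimize[2]["lr"] = 1.0  # IndexError: list index out of range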
@sayakpaul
and
However, this script still has many errors, including but not limited to: being unable to load a LoRA trained with DeepSpeed (training with plain PyTorch works), enabling train_text_encoder, and initializing multiple models with accelerate. I also don't know why the index starts from 1 instead of 0, which raises a list index out of range error:
It would have been more helpful if you had provided more information on how you're launching the training experiments, etc. We already test whether we're able to resume training:
https://github.com/huggingface/diffusers/blob/b9e2f886cd6e9182f1bf1bf7421c6363956f94c5/examples/dreambooth/test_dreambooth_flux.py#L65
enabling train_text_encoder, and initializing multiple models with accelerate. I also don't know why the index starts from 1 instead of 0, which raises a list index out of range error:
This I don't understand. Please elaborate so that we can provide further suggestions.
It seems there is no --train_text_encoder in diffusers/examples/dreambooth/test_dreambooth_flux.py.
My script is as follows:
accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="bedroom, YF_CN style" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="tensorboard" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_train_epochs=20 \
--validation_prompt="bedroom, YF_CN style" \
--validation_epochs=80 \
--checkpointing_steps=10 \
--seed="0" \
--gradient_checkpointing \
--use_8bit_adam \
--dataloader_num_workers=1 \
--train_text_encoder \
--rank=4
@sayakpaul And the inference script:
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
import numpy as np
import os
from PIL import Image
from tqdm import tqdm
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
lora_path = "lora_path"
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
pipe.fuse_lora()
pipe.to(torch.device('cuda'))
width=1024
height=800
prompt = "a bedroom with a bed, 2 night stand and a wardrobe,a bay window on the right, YF_CN style"
images = []
for i in range(10):
    generator = torch.manual_seed(i)
    image = pipe(
        prompt=prompt, num_inference_steps=20, width=width, height=height, generator=generator
    ).images[0]
    images.append(np.asarray(image))
image = Image.fromarray(np.vstack(images))
image.save("test.jpg")
pytorch_lora_weights.zip
DeepSpeed config:
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Okay so, it fails for train_text_encoder or does it fail without train_text_encoder as well?
only fails for train_text_encoder
Okay that is helpful.
The error you posted in https://github.com/huggingface/diffusers/issues/9393#issuecomment-2342750651, seems easy to solve. We should just filter out the "module" keys in the state dict and it should work. Can you try that out first?
What errors do you see in the text encoder training?
Errors caused by accelerate and deepspeed, like:
Oh that I am not sure about then. Ccing @muellerzr for advice.
For the LoRA trained with DeepSpeed, I filtered out the "module." prefix in the keys, and then it works the same as without DeepSpeed:
from safetensors.torch import save_file
from diffusers import FluxPipeline

def convert(input_dir, save_path):
    # Load the LoRA state dict written during DeepSpeed training.
    lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
    # Keep only the transformer keys and strip the DeepSpeed "module." prefix.
    transformer_state_dict = {
        k.replace("module.", ""): v for k, v in lora_state_dict.items() if k.startswith("transformer.")
    }
    save_file(transformer_state_dict, save_path)

input_dir = "pytorch_lora_weights.safetensors"
save_path = "out.safetensors"
convert(input_dir, save_path)
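If it helps, the converted file then loads like any other LoRA weight file; a minimal usage sketch (model name and paths are just the ones from my inference script above):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Point load_lora_weights at the file written by convert() above.
pipe.load_lora_weights(".", weight_name="out.safetensors")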
Yeah, of course, that is why I suggested it. Usually, you would want to always call unwrap_model() here
https://github.com/huggingface/diffusers/blob/45aa8bb1877272631ac6721bac9d04ed23372651/examples/dreambooth/train_dreambooth_lora_flux.py#L1205
and here
https://github.com/huggingface/diffusers/blob/45aa8bb1877272631ac6721bac9d04ed23372651/examples/dreambooth/train_dreambooth_lora_flux.py#L1207
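For illustration, a hedged sketch of what that suggestion amounts to (paraphrased, not the exact lines from the linked commit): unwrap the accelerate/DeepSpeed wrapper before collecting the PEFT state dict.

from peft.utils import get_peft_model_state_dict

# Sketch only: transformer, text_encoder_one, unwrap_model and args come from the
# training script's scope; the point is to always unwrap before reading weights.
transformer = unwrap_model(transformer)
transformer_lora_layers = get_peft_model_state_dict(transformer)

if args.train_text_encoder:
    text_encoder_one = unwrap_model(text_encoder_one)
    text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)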
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Is this still a problem?
no problem
Same error, and I have tried this modification:
if isinstance(unwrap_model(model), type(unwrap_model(transformer))) in place of if isinstance(model, type(unwrap_model(transformer)))
but it does not work
Possible to try out https://github.com/a-r-r-o-w/cogvideox-factory/blob/0affacb2296027fc40a6f3900ce9157b4f4ea46d/training/cogvideox_image_to_video_lora.py#L382 ?
I have the same issue. Cannot save the lora adapters with deepspeed. It times out and then errors out.
What command are you using? What is your DeepSpeed config?
Deepspeed Config
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 4
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'yes'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Accelerate Command
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-lora-081025-0735"
accelerate launch --config_file /home/ubuntu/flux-fine-tune/default_config.yaml train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0"
This is the timeout error
08/10/2025 18:25:22 - INFO - accelerate.accelerator - Saving current state to trained-flux-lora/checkpoint-500
08/10/2025 18:25:22 - INFO - accelerate.accelerator - Saving DeepSpeed Model and Optimizer
[2025-08-10 18:25:22,429] [INFO] [logging.py:107:log_dist] [Rank 0] [Torch] Checkpoint pytorch_model is begin to save!
[rank0]:[E810 18:55:22.997053236 ProcessGroupNCCL.cpp:685] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=502, OpType=ALLREDUCE, NumelIn=20, NumelOut=20, Timeout(ms)=1800000) ran for 1800074 milliseconds before timing out.
[rank0]:[E810 18:55:22.997155498 ProcessGroupNCCL.cpp:2252] [PG ID 0 PG GUID 0(default_pg) Rank 0] failure detected by watchdog at work sequence id: 502 PG status: last enqueued work: 503, last completed work: 501
[rank0]:[E810 18:55:22.997166876 ProcessGroupNCCL.cpp:732] Stack trace of the failed collective not found, potentially because FlightRecorder is disabled. You can enable it by setting TORCH_NCCL_TRACE_BUFFER_SIZE to a non-zero value.
[rank0]:[E810 18:55:22.997195998 ProcessGroupNCCL.cpp:2584] [PG ID 0 PG GUID 0(default_pg) Rank 0] First PG on this rank to signal dumping.
[rank0]:[E810 18:55:22.174117086 ProcessGroupNCCL.cpp:1870] [PG ID 0 PG GUID 0(default_pg) Rank 0] Received a dump signal due to a collective timeout from this local rank and we will try our best to dump the debug info. Last enqueued NCCL work: 503, last completed NCCL work: 501.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.
[rank0]:[E810 18:55:22.174327727 ProcessGroupNCCL.cpp:1589] [PG ID 0 PG GUID 0(default_pg) Rank 0] ProcessGroupNCCL preparing to dump debug info. Include stack trace: 1
[rank1]:[E810 18:55:22.440414633 ProcessGroupNCCL.cpp:1806] [PG ID 0 PG GUID 0(default_pg) Rank 1] Observed flight recorder dump signal from another rank via TCPStore.
[rank7]:[E810 18:55:22.440427955 ProcessGroupNCCL.cpp:1806] [PG ID 0 PG GUID 0(default_pg) Rank 7] Observed flight recorder dump signal from another rank via TCPStore.
[rank1]:[E810 18:55:22.440509834 ProcessGroupNCCL.cpp:1870] [PG ID 0 PG GUID 0(default_pg) Rank 1] Received a dump signal due to a collective timeout from rank 0 and we will try our best to dump the debug info. Last enqueued NCCL work: 501, last completed NCCL work: 501.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.
[rank7]:[E810 18:55:22.440512888 ProcessGroupNCCL.cpp:1870] [PG ID 0 PG GUID 0(default_pg) Rank 7] Received a dump signal due to a collective timeout from rank 0 and we will try our best to dump the debug info. Last enqueued NCCL work: 501, last completed NCCL work: 501.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc.
[rank1]:[E810 18:55:22.440605721 ProcessGroupNCCL.cpp:1589] [PG ID 0 PG GUID 0(default_pg) Rank 1] ProcessGroupNCCL preparing to dump debug info. Include stack trace: 1
[rank7]:[E810 18:55:22.440612553 ProcessGroupNCCL.cpp:1589] [PG ID 0 PG GUID 0(default_pg) Rank 7] ProcessGroupNCCL preparing to dump debug info. Include stack trace: 1
[2025-08-10 18:56:22,515] [WARNING] [engine.py:3326:_checkpoint_tag_validation] [rank=0] The checkpoint tag name 'pytorch_model' is not consistent across all ranks. Including rank unique information in checkpoint tag could cause issues when restoring with different world sizes.
[rank0]:[E810 18:56:23.632492219 ProcessGroupNCCL.cpp:746] [Rank 0] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank0]:[E810 18:56:23.632516737 ProcessGroupNCCL.cpp:760] [Rank 0] To avoid data inconsistency, we are taking the entire process down.
[rank0]:[E810 18:56:23.633473872 ProcessGroupNCCL.cpp:2068] [PG ID 0 PG GUID 0(default_pg) Rank 0] Process group watchdog thread terminated with exception: [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=502, OpType=ALLREDUCE, NumelIn=20, NumelOut=20, Timeout(ms)=1800000) ran for 1800074 milliseconds before timing out.
Exception raised from checkTimeout at /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:688 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x80 (0x7fb11ad7eeb0 in /home/ubuntu/flux-fine-tune/.venv/lib/python3.12/site-packages/torch/lib/libc10.so)
frame #1: c10d::ProcessGroupNCCL::WorkNCCL::checkTimeout(std::optional<std::chrono::duration<long, std::ratio<1l, 1000l> > >) + 0x247 (0x7fb0bda40147 in /home/ubuntu/flux-fine-tune/.venv/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #2: c10d::ProcessGroupNCCL::Watchdog::runLoop() + 0x1591 (0x7fb0bda43b61 in /home/ubuntu/flux-fine-tune/.venv/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #3: c10d::ProcessGroupNCCL::Watchdog::run() + 0xd2 (0x7fb0bda44ec2 in /home/ubuntu/flux-fine-tune/.venv/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0xecdb4 (0x7fb11daecdb4 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #5: <unknown function> + 0x9caa4 (0x7fb120a9caa4 in /lib/x86_64-linux-gnu/libc.so.6)
frame #6: <unknown function> + 0x129c3c (0x7fb120b29c3c in /lib/x86_64-linux-gnu/libc.so.6)
terminate called after throwing an instance of 'c10::DistBackendError'
Thanks for the investigations and sorry about my delay.
Could you check if the following patch works for you?
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 6ec532e63..658a68f70 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -44,7 +44,7 @@ import torch
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed, DistributedType
 from huggingface_hub import create_repo, upload_folder
 from huggingface_hub.utils import insecure_hashlib
 from peft import LoraConfig, set_peft_model_state_dict
@@ -1280,29 +1280,31 @@ def main(args):
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        if accelerator.is_main_process:
-            transformer_lora_layers_to_save = None
-            text_encoder_one_lora_layers_to_save = None
-            modules_to_save = {}
-            for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["transformer"] = model
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["text_encoder"] = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+        transformer_lora_layers_to_save = None
+        text_encoder_one_lora_layers_to_save = None
+        modules_to_save = {}
+        for model in models:
+            if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                model = unwrap_model(model)
+                transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                modules_to_save["transformer"] = model
+            elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                model = unwrap_model(model)
+                text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                modules_to_save["text_encoder"] = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")

-                # make sure to pop weight so that corresponding model is not saved again
+            # make sure to pop weight so that corresponding model is not saved again
+            if weights:
                 weights.pop()

-            FluxPipeline.save_lora_weights(
-                output_dir,
-                transformer_lora_layers=transformer_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
-                **_collate_lora_metadata(modules_to_save),
-            )
+        FluxPipeline.save_lora_weights(
+            output_dir,
+            transformer_lora_layers=transformer_lora_layers_to_save,
+            text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+            **_collate_lora_metadata(modules_to_save),
+        )

     def load_model_hook(models, input_dir):
         transformer_ = None
@@ -1311,12 +1313,12 @@ def main(args):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
+            if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                transformer_ = unwrap_model(model)
+            elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                text_encoder_one_ = unwrap_model(model)
             else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                raise ValueError(f"unexpected save model: {unwrap_model(model).__class__}")

         lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
@@ -1842,7 +1844,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None: