Fix: try aligning the dtype of matrices when training with deepspeed and mixed precision set to bf16 or fp16
What problem does this PR solve?
This PR mainly tries to fix the problem described in issue #1871.
When I tried to train with the script flux_train.py, I hit the same error as in the issue above.
The only workaround I found was to remove --deepspeed and run with a small train_batch_size, which makes training slow.
Solution
I added a wrapper in deepspeed_utils.py that wraps each model's forward function with torch.autocast, which provides convenience methods for mixed precision.
Changes in detail
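The idea of wrapping forward with torch.autocast can be sketched as below. This is a minimal illustration, not the code in this PR: wrap_forward_with_autocast is a hypothetical helper name, and it uses CPU autocast with bfloat16 so it runs without a GPU (in training, device_type would be "cuda"). It reproduces the "Float vs BFloat16" mismatch with bf16 weights and an fp32 input, then shows autocast casting both operands of the linear op to the same dtype.

```python
import functools

import torch
import torch.nn as nn


def wrap_forward_with_autocast(model: nn.Module, device_type: str, dtype: torch.dtype) -> nn.Module:
    """Hypothetical helper: run every call of model.forward inside torch.autocast."""
    original_forward = model.forward

    @functools.wraps(original_forward)
    def forward_with_autocast(*args, **kwargs):
        # Inside this region, autocast casts eligible ops (linear, matmul, ...) to `dtype`,
        # so an fp32 activation and a bf16 weight end up with the same dtype.
        with torch.autocast(device_type=device_type, dtype=dtype):
            return original_forward(*args, **kwargs)

    # Assigning to the instance shadows the class method; nn.Module.__call__ picks it up.
    model.forward = forward_with_autocast
    return model


layer = nn.Linear(4, 4).to(torch.bfloat16)  # bf16 weights
x = torch.randn(2, 4)  # fp32 input, as in "got Float and BFloat16"
try:
    layer(x)
except RuntimeError as err:
    print("without autocast:", err)

wrap_forward_with_autocast(layer, device_type="cpu", dtype=torch.bfloat16)
out = layer(x)
print("with autocast:", out.dtype)  # torch.bfloat16
```

Patching the instance's forward, rather than subclassing, leaves the wrapped model's type and state_dict unchanged, which matters when the model is later handed to DeepSpeed or saved.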
- Add __warp_with_torch_autocast to class DeepSpeedWrapper in deepspeed_utils.py.
- Add deepspeed==0.16.7 to the requirements.
- Do nothing in function patch_accelerator_for_fp16_training of script train_util.py when accelerator.distributed_type == DistributedType.DEEPSPEED. DeepSpeed handles loss scaling for mixed-precision training internally, so accelerator.scaler is None in that case, and patching it results in the same error as issue 476.
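The guard from the last item can be sketched as follows. This is a self-contained illustration under stated assumptions: DistributedType here is a local stand-in for accelerate's enum (only two members), the accelerators are SimpleNamespace mocks, and the patch body that forces allow_fp16=True is a hypothetical placeholder; the actual function in train_util.py may differ.

```python
from enum import Enum
from types import SimpleNamespace


class DistributedType(Enum):
    # Local stand-in for accelerate's DistributedType (assumption: only the members used here).
    MULTI_GPU = "MULTI_GPU"
    DEEPSPEED = "DEEPSPEED"


def patch_accelerator_for_fp16_training(accelerator) -> bool:
    """Sketch of the guard; returns True if the grad scaler was patched."""
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        # DeepSpeed does its own loss scaling, so accelerator.scaler is None;
        # reading an attribute from it would raise AttributeError, so skip the patch.
        return False

    org_unscale_grads = accelerator.scaler._unscale_grads_

    def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
        # Hypothetical patch body: always allow unscaling fp16 gradients.
        return org_unscale_grads(optimizer, inv_scale, found_inf, True)

    accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
    return True


# DeepSpeed run: scaler is None, the guard returns early instead of crashing.
ds_accel = SimpleNamespace(distributed_type=DistributedType.DEEPSPEED, scaler=None)
print(patch_accelerator_for_fp16_training(ds_accel))  # False

# Plain multi-GPU run: a fake scaler whose _unscale_grads_ echoes allow_fp16.
fake_scaler = SimpleNamespace(_unscale_grads_=lambda opt, inv, found, allow: allow)
ddp_accel = SimpleNamespace(distributed_type=DistributedType.MULTI_GPU, scaler=fake_scaler)
print(patch_accelerator_for_fp16_training(ddp_accel))  # True
print(ddp_accel.scaler._unscale_grads_(None, None, None, False))  # True, allow_fp16 forced
```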
After these changes, the dtype error disappeared, and train_batch_size could be increased from 2 (without deepspeed) to 12 (with deepspeed and mixed precision) on 8x NVIDIA A100 GPUs (80 GB memory each), giving a 17.54% speedup with the following command:
accelerate launch \
--num_cpu_threads_per_process=8 \
--multi_gpu \
--mixed_precision=fp16 \
--rdzv_backend=c10d \
"flux_train.py" \
--output_dir="output" \
--logging_dir="logs" \
--max_train_epochs=60 \
--learning_rate=2e-5 \
--output_name=flux_test \
--save_every_n_epochs=10 \
--save_precision=fp16 \
--seed=4242 \
--max_token_length=225 \
--caption_extension=.txt \
--vae_batch_size=4 \
--deepspeed \
--zero_stage=3 \
--ddp_timeout=120 \
--ddp_gradient_as_bucket_view \
--ddp_static_graph \
--mem_eff_save \
--clip_l="model/clip/ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors" \
--t5xxl="model/clip/t5xxl_fp16.safetensors" \
--apply_t5_attn_mask \
--discrete_flow_shift=3.185 \
--timestep_sampling=flux_shift \
--sigmoid_scale=1 \
--model_prediction_type=raw \
--guidance_scale=1 \
--ae="model/flux/ae.safetensors" \
--cache_text_encoder_outputs \
--cache_text_encoder_outputs_to_disk \
--sdpa \
--train_data_dir="data" \
--train_batch_size=12 \
--resolution=1024,1024 \
--enable_bucket \
--min_bucket_reso=256 \
--max_bucket_reso=2048 \
--bucket_no_upscale \
--pretrained_model_name_or_path="model/flux/flux1-dev.safetensors" \
--save_model_as=safetensors \
--clip_skip=2 \
--persistent_data_loader_workers \
--cache_latents \
--cache_latents_to_disk \
--gradient_checkpointing \
--use_8bit_adam \
--keep_tokens=1 \
--keep_tokens_separator="|||" \
--secondary_separator=";;;" \
--sample_every_n_epochs=200 \
--sample_sampler=euler_a \
--full_fp16 \
--mixed_precision=fp16 \
--gradient_accumulation_steps=1 \
--lr_scheduler=warmup_stable_decay \
--lr_scheduler_num_cycles=1 \
--lr_decay_steps=0.25 \
--lr_scheduler_min_lr_ratio=0.1
My questions are about the requirements. Where was deepspeed coming from before? Is updating torch to 2.6.0 and having diffusers update it automatically a good idea, given the various ways you need to install torch for backend compatibility? Is the diffusers[torch] extra holding this back? It seems like it would stay aligned the same way, with the torch version being the same. Some environments still only support torch 2.4, so requiring the latest (2.6.0) might cause some issues.
- I followed the installation steps in the README of the sd3 branch, but deepspeed is not installed by these commands:
pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124
pip3 install -r requirements.txt # in branch sd3
I tried installing deepspeed via accelerate[deepspeed]==0.33.0 in requirements.txt, but deepspeed was still not detected when launching training:
2025-04-23 16:40:55 ERROR deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed    deepspeed_utils.py:77
Then I added deepspeed==0.16.7 to requirements.txt, which works well for this PR and has no conflict with the existing requirements.
- It seems my previous comment caused some misunderstanding. My concern was that using diffusers[torch]==0.25.0 might automatically upgrade an existing torch 2.4.0 installation to the latest version. I encountered this once when I tried this repo in another environment, which is why I changed it to diffusers==0.25.0. However, since I can't reproduce the issue now, I'll change it back to diffusers[torch]==0.25.0.
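Both concerns in this exchange (deepspeed not being detected, and an extra possibly upgrading torch) can be checked by reading the installed versions from the same Python environment that accelerate launch uses. A minimal sketch using the standard library's importlib.metadata; installed_versions is a hypothetical helper, not part of sd-scripts, and it reports what is installed rather than what requirements.txt pins.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Distributions discussed in this thread.
    for name, ver in installed_versions(["torch", "deepspeed", "accelerate", "diffusers"]).items():
        print(f"{name}: {ver if ver is not None else 'NOT INSTALLED'}")
```

Running it before and after pip install -r requirements.txt shows whether torch was replaced and whether deepspeed actually landed in the environment.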
Thank you for the update!
Hi, when I switch to mixed_precision=bf16, the [mat1 and mat2 must have the same dtype, but got Float and BFloat16] error still occurs. I am running the script "flux_train_control_net.py" with this command:
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_control_net.py --pretrained_model_name_or_path flux1-dev.safetensors --clip_l clip_l.safetensors --t5xxl t5xxl_fp16.safetensors --ae ae.safetensors --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 1 --seed 42 --gradient_checkpointing --mixed_precision bf16 --optimizer_type adamw8bit --learning_rate 2e-5 --highvram --max_train_epochs 1 --save_every_n_steps 1000 --output_dir /path/to/output --output_name flux-cn --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0 --deepspeed --dataset_config dataset.toml --log_tracker_name "sd-scripts-flux-cn"
In addition, the machine only supports CUDA 12.2, so I installed PyTorch 2.4.0 from the cu121 channel. Could that cause the problem?
pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu121
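To narrow down whether the CUDA build plays a role, it helps to report which CUDA version the installed torch wheel was built with and whether the GPU supports bf16. A small diagnostic sketch; cuda_report is a hypothetical helper name, while torch.version.cuda, torch.cuda.is_available, torch.cuda.get_device_name and torch.cuda.is_bf16_supported are standard PyTorch calls.

```python
import torch


def cuda_report():
    """Collect basic facts about the torch build and the first visible GPU."""
    report = {
        "torch": torch.__version__,
        # CUDA version the wheel was built with, e.g. "12.1" for a cu121 wheel; None for CPU-only builds.
        "built_with_cuda": torch.version.cuda,
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        report["device"] = torch.cuda.get_device_name(0)
        report["bf16_supported"] = torch.cuda.is_bf16_supported()
    return report


if __name__ == "__main__":
    for key, value in cuda_report().items():
        print(f"{key}: {value}")
```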