sd-scripts
DeepSpeed Accelerator can NOT save optimizer state
Same as the title. ~I'm investigating.~ With the following modifications, the scripts can save the optimizer state, but only for a small dataset; I don't know why it does not work with a big dataset.
Error messages are attached below.
Environment:
- 4 x RTX 3090 GPUs
- sd-scripts: 71e2c91330a9d866ec05cdd10584bbb962896a99
- ZeRO stage 1
- Using train_network.py normally
Logs of saving the optimizer state:
INFO Saving DeepSpeed Model and Optimizer logging.py:61
[rank1]:[E ProcessGroupNCCL.cpp:523] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=6004, OpType=ALLREDUCE, NumelIn=126834688, NumelOut=126834688, Timeout(ms)=600000) ran for 600410 milliseconds before timing out.
[rank2]:[E ProcessGroupNCCL.cpp:523] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=6004, OpType=ALLREDUCE, NumelIn=126834688, NumelOut=126834688, Timeout(ms)=600000) ran for 600771 milliseconds before timing out.
[rank3]:[E ProcessGroupNCCL.cpp:523] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=6004, OpType=ALLREDUCE, NumelIn=126834688, NumelOut=126834688, Timeout(ms)=600000) ran for 600333 milliseconds before timing out.
[rank1]:[E ProcessGroupNCCL.cpp:537] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[rank1]:[E ProcessGroupNCCL.cpp:543] To avoid data inconsistency, we are taking the entire process down.
Here is the corresponding code in sd-scripts. The same structure is used for save_every_n_epochs.
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
    accelerator.wait_for_everyone()
    **if accelerator.is_main_process**:
        ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
        save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)

        **if args.save_state:**
            train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
Analysis: Only rank 0 (GPU 0, or cuda:0) gets to save the optimizer state. With ZeRO stages above 0, the optimizer state is sharded across all GPUs, so saving it is a collective operation. Inside the is_main_process block, rank 0 therefore waits forever for the other GPUs (ranks 1, 2, 3), which never call save_state, and the NCCL process group raises a timeout error. Saving the model itself is not a problem.
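To illustrate the mechanism with a standalone toy sketch (my own example, not sd-scripts code): a collective that only one rank enters blocks until the NCCL watchdog kills the job, which is exactly the ALLREDUCE timeout shown in the log above.

```python
# Standalone sketch of the failure mode (illustrative, not sd-scripts code).
# Launch with e.g. `torchrun --nproc_per_node=4 this_file.py`.
import time

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
t = torch.ones(1, device="cuda")

if rank == 0:
    # Only rank 0 enters the collective; ranks 1..N never do, so this call
    # blocks until the NCCL watchdog raises an ALLREDUCE timeout.
    dist.all_reduce(t)
else:
    # The other ranks are busy elsewhere (in sd-scripts: still in the
    # training loop, or waiting at the next synchronization point).
    time.sleep(3600)
```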
Related issue: get stuck when save_state using DeepSpeed backend under training train_text_to_image_lora
- Applying the related issue's method

Move the save_state call out of the is_main_process block:
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
        save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)

        if args.save_state and not args.deepspeed:
            train_util.save_and_remove_state_stepwise(args, accelerator, global_step)

    if args.save_state and args.deepspeed:
        train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
        accelerator.wait_for_everyone()
This is an ad-hoc modification to work around the problem. Here are the logs after this change: the DeepSpeed state is saved, but the run then fails with an IndexError in save_model_hook (traceback below the logs).
[2024-04-09 15:39:37,942] [INFO] [logging.py:96:log_dist] [Rank 0] Saving model checkpoint: ./training/model/temp/test-step00000001-state/pytorch_model/mp_rank_00_model_states.pt
[2024-04-09 15:39:37,942] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving ./training/model/temp/test-step00000001-state/pytorch_model/mp_rank_00_model_states.pt...
[2024-04-09 15:40:06,463] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved ./training/model/temp/test-step00000001-state/pytorch_model/mp_rank_00_model_states.pt.
[2024-04-09 15:40:07,083] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt...
[2024-04-09 15:40:07,246] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_1_mp_rank_00_optim_states.pt...
[2024-04-09 15:40:07,246] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_2_mp_rank_00_optim_states.pt...
[2024-04-09 15:40:07,249] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_3_mp_rank_00_optim_states.pt...
[2024-04-09 15:40:08,808] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_1_mp_rank_00_optim_states.pt.
[2024-04-09 15:40:08,808] [INFO] [engine.py:3477:_save_zero_checkpoint] zero checkpoint saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_1_mp_rank_00_optim_states.pt
[2024-04-09 15:40:08,808] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint pytorch_model is ready now!
[2024-04-09 15:40:08,818] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_2_mp_rank_00_optim_states.pt.
[2024-04-09 15:40:08,818] [INFO] [engine.py:3477:_save_zero_checkpoint] zero checkpoint saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_2_mp_rank_00_optim_states.pt
[2024-04-09 15:40:08,819] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint pytorch_model is ready now!
[2024-04-09 15:40:08,824] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt.
[2024-04-09 15:40:08,832] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_3_mp_rank_00_optim_states.pt.
[2024-04-09 15:40:08,832] [INFO] [engine.py:3477:_save_zero_checkpoint] zero checkpoint saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_3_mp_rank_00_optim_states.pt
[2024-04-09 15:40:08,832] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint pytorch_model is ready now!
[2024-04-09 15:40:08,883] [INFO] [engine.py:3477:_save_zero_checkpoint] zero checkpoint saved ./training/model/temp/test-step00000001-state/pytorch_model/zero_pp_rank_0_mp_rank_00_optim_states.pt
[2024-04-09 15:40:08,884] [INFO] [torch_checkpoint_engine.py:33:commit] [Torch] Checkpoint pytorch_model is ready now!
2024-04-09 15:40:08 INFO DeepSpeed Model and Optimizer saved to output dir ./training/model/temp/test-step00000001-state/pytorch_model logging.py:61
Traceback (most recent call last):
File "/home/hard2251/workspace/sd-scripts/./sdxl_train_network.py", line 185, in <module>
trainer.train(args)
File "/home/hard2251/workspace/sd-scripts/train_network.py", line 949, in train
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
File "/home/hard2251/workspace/sd-scripts/library/train_util.py", line 4845, in save_and_remove_state_stepwise
accelerator.save_state(state_dir)
File "/home/hard2251/anaconda3/envs/sd-scripts/lib/python3.10/site-packages/accelerate/accelerator.py", line 2706, in save_state
hook(self._models, weights, output_dir)
File "/home/hard2251/workspace/sd-scripts/train_network.py", line 488, in save_model_hook
weights.pop(i)
IndexError: pop from empty list
Related issue: get stuck when save_state using DeepSpeed backend under training train_text_to_image_lora.
Before
def save_model_hook(models, weights, output_dir):
    # pop weights of other models than network to save only network weights
    if accelerator.is_main_process:
        remove_indices = []
        for i, model in enumerate(models):
            if not isinstance(model, type(accelerator.unwrap_model(network))):
                remove_indices.append(i)
        for i in reversed(remove_indices):
            weights.pop(i)
After
def save_model_hook(models, weights, output_dir):
    # pop weights of other models than network to save only network weights
    if accelerator.is_main_process:
        remove_indices = []
        for i, model in enumerate(models):
            if not isinstance(model, type(accelerator.unwrap_model(network))):
                remove_indices.append(i)
        for i in reversed(remove_indices):
            **if weights:**
                weights.pop(i)
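For context, this hook is called from inside accelerator.save_state (the hook(self._models, weights, output_dir) frame in the traceback). On the DeepSpeed path, Accelerate writes the model through the DeepSpeed engine (the "Saving DeepSpeed Model and Optimizer" log line above) and does not append anything to weights, so the hook receives an empty list and the unguarded pop(i) raises IndexError. Below is a minimal sketch of the hook mechanism using Accelerate's register_save_state_pre_hook API; the hook body is illustrative, not the sd-scripts one.

```python
# Minimal sketch of Accelerate's save-state pre-hook mechanism (illustrative).
from accelerate import Accelerator

accelerator = Accelerator()

def demo_save_hook(models, weights, output_dir):
    # Under DeepSpeed, `weights` can arrive empty because the model checkpoint
    # was already written by the DeepSpeed engine; bail out instead of popping.
    if not weights:
        return
    # ...otherwise drop the state dicts that should not be written out...

# accelerator.save_state() will later call demo_save_hook(models, weights, output_dir)
accelerator.register_save_state_pre_hook(demo_save_hook)
```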
After modifying both the save_state block and the save_model_hook function, sd-scripts is ~now able to save optimizer state when deepspeed=true.~ able to save state only when using a small dataset.
Issue is updated!
After applying the changes above, I ran into another error:
TypeError: cannot pickle 'torch._C._distributed_c10d.ProcessGroup' object
Same environment, but:
- ~training an SDXL network~ When I switched to a smaller dataset, it works with the SDXL network as well
- A larger (~10K) dataset fails, but a smaller (~50) dataset succeeds
Related issue: deepspeed strategy can't save checkpoint, TypeError: cannot pickle torch._C._distributed_c10d.ProcessGroup object
I've updated the code in the same way as https://github.com/huggingface/diffusers/issues/2606, but I have no idea how to fix the cannot pickle error. The related issue still seems to be open.
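No fix yet, but a small diagnostic sketch (the helper name is mine, not from sd-scripts or DeepSpeed) can at least narrow down which entry of the state is carrying the unpicklable ProcessGroup handle before torch.save fails:

```python
# Hypothetical helper: walk a (possibly nested) dict and report every value
# that pickle refuses, e.g. objects holding a
# torch._C._distributed_c10d.ProcessGroup reference.
import pickle

def report_unpicklable(obj, path="state"):
    if isinstance(obj, dict):
        for key, value in obj.items():
            report_unpicklable(value, f"{path}[{key!r}]")
        return
    try:
        pickle.dumps(obj)
    except Exception as exc:  # e.g. TypeError: cannot pickle 'ProcessGroup' object
        print(f"{path}: {type(obj).__name__} -> {exc}")

# e.g. report_unpicklable(optimizer.state_dict()) right before accelerator.save_state()
```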