save_state() fails when using DeepSpeed with multiple models where some models are frozen some are being trained
System Info
- `Accelerate` version: 0.35.0.dev0
- Platform: Linux-5.10.0-32-cloud-amd64-x86_64-with-glibc2.31
- `accelerate` bash location: /opt/conda/envs/flux_cn_exp/bin/accelerate
- Python version: 3.10.10
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.3.0+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch MUSA available: False
- System RAM: 334.40 GB
- GPU type: NVIDIA A100-SXM4-80GB
- `Accelerate` default config:
- compute_environment: LOCAL_MACHINE
- distributed_type: MULTI_GPU
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 2
- machine_rank: 0
- num_machines: 1
- gpu_ids: all
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
Information
- [X] The official example scripts
- [X] My own modified scripts
Tasks
- [X] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [ ] My own task or dataset (give details below)
Reproduction
- Just apply the following patch (3 lines are added) to src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py
- Run the test script.
diff --git a/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py b/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py
index 3729ecf..73d937f 100644
--- a/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py
+++ b/src/accelerate/test_utils/scripts/external_deps/test_ds_multiple_model.py
@@ -181,6 +181,10 @@ def single_model_training(config, args):
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"]
+
+ accelerator.print("Saving weights...")
+ accelerator.save_state(f"./epoch-{epoch}")
+ accelerator.print("Weights are saved...")
if best_performance < eval_metric["accuracy"]:
best_performance = eval_metric["accuracy"]
Error Message:
[rank0]: File "/opt/conda/envs/flux_cn_exp/lib/python3.10/site-packages/accelerate/accelerator.py", line 3039, in save_state
[rank0]: model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)
[rank0]: File "/opt/conda/envs/flux_cn_exp/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 3136, in save_checkpoint
[rank0]: self.checkpoint_engine.makedirs(save_dir, exist_ok=True)
[rank0]: AttributeError: 'NoneType' object has no attribute 'makedirs'
This happens for the frozen inference models: since their optimizer is a DummyOptim, DeepSpeed initializes their engine with a DeepSpeedZeRoOffload instead of a real optimizer. Because of the check below, _configure_checkpointing() is then never called for them, so checkpoint_engine stays None:
https://github.com/microsoft/DeepSpeed/blob/8cded575a94e296fee751072e862304676c95316/deepspeed/runtime/engine.py#L340
if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
self._configure_checkpointing(dist_init_required)
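A quick way to confirm this diagnosis is to inspect the prepared engines after accelerator.prepare(). The snippet below is a hypothetical diagnostic sketch of mine (it relies on the private accelerator._models list, just like the workaround further down), not something from the report:

# Hypothetical diagnostic sketch: run after accelerator.prepare() to see which
# DeepSpeed engines actually own a checkpoint engine. Frozen models prepared
# with a DummyOptim end up with checkpoint_engine set to None.
for i, engine in enumerate(accelerator._models):
    status = "set" if getattr(engine, "checkpoint_engine", None) is not None else "None"
    accelerator.print(f"model {i}: {type(engine).__name__}, checkpoint_engine={status}")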
Expected behavior
The accelerator should handle the case where self._models contains frozen models (DeepSpeed engines whose checkpoint_engine is None) when saving or loading state, e.g. by filtering them out as sketched below.
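As a rough illustration of what such handling could look like, here is a minimal, hypothetical helper (the function name is mine, not part of Accelerate) that keeps only the engines save_checkpoint() can actually operate on:

def deepspeed_models_to_checkpoint(models):
    """Keep only DeepSpeed engines that can be checkpointed.

    Frozen models prepared with a DummyOptim are backed by DeepSpeedZeRoOffload
    and never get a checkpoint_engine, so calling save_checkpoint() on them crashes.
    """
    return [m for m in models if getattr(m, "checkpoint_engine", None) is not None]

save_state()/load_state() could apply such a filter before iterating over self._models.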
cc: @muellerzr
Thanks, definitely will try and take a look at it!
load_state() also fails when resuming from a checkpoint:
[rank3]: File "/opt/conda/envs/flux_cn/lib/python3.10/site-packages/deepspeed/runtime/state_dict_factory.py", line 168, in check_ckpt_list
[rank3]: assert len(self.ckpt_list) > 0
I am currently circumventing the issues by wrapping the load_state() & save_state() calls as follows:
acc_models = accelerator._models
accelerator._models = [model for model in acc_models if model.checkpoint_engine is not None]
# <load_state() or save_state() here>
accelerator._models = acc_models
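For reuse, the same workaround can be wrapped in a small context manager. This is just a sketch built on the snippet above (it still touches the private _models attribute), not an official API:

from contextlib import contextmanager

@contextmanager
def skip_frozen_deepspeed_models(accelerator):
    # Temporarily hide DeepSpeed engines that have no checkpoint_engine
    # (i.e. frozen/inference-only models) from save_state()/load_state().
    original_models = accelerator._models
    accelerator._models = [
        m for m in original_models if getattr(m, "checkpoint_engine", None) is not None
    ]
    try:
        yield
    finally:
        accelerator._models = original_models

# Usage:
# with skip_frozen_deepspeed_models(accelerator):
#     accelerator.save_state(f"./epoch-{epoch}")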
Same issue here.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Same issue encountered with the same script; I think this issue needs to be reopened.
The issue still exists, and it can be reproduced even with the official example script (the example doesn't save state by itself, but once a save_state() call is added, the error occurs).
Same problem here. The issue still exists.
+1 Can't save state in LoRA training with accelerate+deepspeed