DeepSpeedExamples
step1 SFT fails when LoRA is enabled (IndexError in ZeRO stage 3 optimizer init)
env
gpu: 4*A100 80G
pytorch: 1.13.1
cuda version: 11.7
deepspeed: 0.9.0
transformers: 4.28.0.dev
run script
OUTPUT=$1
ZERO_STAGE=3
if [ "$OUTPUT" == "" ]; then
OUTPUT=./output
fi
if [ "$ZERO_STAGE" == "" ]; then
ZERO_STAGE=3
fi
mkdir -p $OUTPUT
deepspeed main.py \
--data_path path/to/local/data \
--model_name_or_path path/to/codegen-16B-multi \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--max_seq_len 2048 \
--learning_rate 1e-4 \
--weight_decay 0.1 \
--num_train_epochs 5 \
--gradient_accumulation_steps 1 \
--lr_scheduler_type cosine \
--num_warmup_steps 0 \
--seed 1234 \
--only_optimize_lora \
--zero_stage $ZERO_STAGE \
--lora_dim 128 \
--lora_module_name decoder.layers. \
--deepspeed \
--output_dir $OUTPUT
error message
Traceback (most recent call last):
File "/mnt/data/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 328, in <module>
main()
File "/mnt/data/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py", line 273, in main
model, optimizer, _, lr_scheduler = deepspeed.initialize(
File "/mnt/data/anaconda3/envs/ds-chat/lib/python3.9/site-packages/deepspeed/__init__.py", line 156, in initialize
engine = DeepSpeedEngine(args=args,
File "/mnt/data/anaconda3/envs/ds-chat/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 328, in __init__
self._configure_optimizer(optimizer, model_parameters)
File "/mnt/data/anaconda3/envs/ds-chat/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1187, in _configure_optimizer
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
File "/mnt/data/anaconda3/envs/ds-chat/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1465, in _configure_zero_optimizer
optimizer = DeepSpeedZeroOptimizer_Stage3(
File "/mnt/data/anaconda3/envs/ds-chat/lib/python3.9/site-packages/deepspeed/runtime/zero/stage3.py", line 133, in __init__
self.dtype = self.optimizer.param_groups[0]['params'][0].dtype
IndexError: list index out of range
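A possible reading of the traceback, offered as a sketch rather than a confirmed diagnosis: the failing line indexes the first parameter of the first optimizer group, so it raises IndexError if that group is empty. That could happen if --lora_module_name decoder.layers. matches no module in the model, so no LoRA layers are created, and --only_optimize_lora then leaves nothing trainable. The pure-Python example below assumes LoRA targets are picked by substring match on module names. The module names in it are hypothetical placeholders, not read from the codegen-16B-multi checkpoint, and modules_matching and first_param_dtype are illustrative helpers, not DeepSpeed functions.

```python
# Sketch only: reproduces the shape of the failure without torch or DeepSpeed.

def modules_matching(module_names, part_name):
    """Return the module names that contain part_name as a substring."""
    return [name for name in module_names if part_name in name]


def first_param_dtype(param_groups):
    """Mimic the failing expression: param_groups[0]['params'][0].dtype."""
    return param_groups[0]["params"][0].dtype


# Hypothetical names, standing in for what
# `[name for name, _ in model.named_modules()]` might return.
example_names = [
    "transformer.wte",
    "transformer.h.0.attn.qkv_proj",
    "transformer.h.0.mlp.fc_in",
    "lm_head",
]

# The pattern from the run script matches none of the hypothetical names.
matched = modules_matching(example_names, "decoder.layers.")
print("matched modules:", matched)

# If nothing is trainable, the optimizer group holds an empty parameter list,
# and indexing its first element raises the error seen in the traceback.
param_groups = [{"params": []}]
try:
    first_param_dtype(param_groups)
except IndexError as exc:
    print("IndexError:", exc)
```

To check this against the real model, print the names from model.named_modules() and confirm that the value passed to --lora_module_name appears in at least one of them.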