LLaVA
[Usage] Deepspeed Zero Stage 3 not able to shard the model
Hi @haotian-liu!
Interesting work on LLaVA!
Issue:
I am trying to fine-tune LLaVA on 8x H100.
When I use DeepSpeed ZeRO Stage 3, the model seems to be replicated on every GPU instead of being sharded across them, and I get OOM errors during fine-tuning. I am using a context length of 2048 and a ViT at 336 resolution.
Could you please suggest what I might be doing wrong here?
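For context, here is a rough sketch of why replication versus sharding matters for memory. It assumes mixed-precision Adam at about 16 bytes per parameter (the usual ZeRO accounting: 2 for 16-bit weights, 2 for 16-bit gradients, 12 for fp32 master weights and the two Adam moments). The 7B and 13B sizes are hypothetical, since the thread does not state $MODEL_VERSION, and activations and the vision tower are not counted.

```python
# Back-of-the-envelope estimate of per-GPU "model state" memory
# (weights + gradients + Adam states). Activations are NOT included.
# Assumption: mixed-precision Adam, ~16 bytes per parameter.
BYTES_PER_PARAM = 2 + 2 + 12  # 16-bit weights, 16-bit grads, fp32 master + 2 Adam moments


def model_state_gib(num_params: float, num_gpus: int, sharded: bool) -> float:
    """Return per-GPU model-state memory in GiB."""
    total_bytes = num_params * BYTES_PER_PARAM
    if sharded:
        # ZeRO Stage 3 partitions weights, gradients and optimizer states.
        total_bytes /= num_gpus
    return total_bytes / 2**30


if __name__ == "__main__":
    for n_params in (7e9, 13e9):  # hypothetical model sizes
        replicated = model_state_gib(n_params, 8, sharded=False)
        sharded = model_state_gib(n_params, 8, sharded=True)
        print(f"{n_params / 1e9:.0f}B params: "
              f"replicated {replicated:.1f} GiB/GPU, "
              f"ZeRO-3 on 8 GPUs {sharded:.1f} GiB/GPU")
```

Under these assumptions, a fully replicated copy of the model states alone would exceed a single 80 GB H100 even for the smaller size, which would be consistent with the OOM described above if sharding is not taking effect.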
Command:
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path ../$MODEL_VERSION \
--version $PROMPT_VERSION \
--data_path ./finetune_data/cleaned_finetune_data.json \
--image_folder ./finetune_data/images \
--vision_tower openai/clip-vit-large-patch14-336 \
--pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True
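One quick sanity check is to confirm that the file passed to --deepspeed really requests Stage 3. The sketch below reads the standard `zero_optimization.stage` key from a DeepSpeed JSON config. The helper names are made up for illustration, and the inline config is a stand-in, not the actual contents of ./scripts/zero3.json.

```python
import json


def zero_stage(config: dict) -> int:
    """Return the ZeRO stage requested by a DeepSpeed config dict (0 if unset)."""
    return int(config.get("zero_optimization", {}).get("stage", 0))


def load_zero_stage(path: str) -> int:
    """Read a DeepSpeed JSON config from disk and return its ZeRO stage."""
    with open(path) as f:
        return zero_stage(json.load(f))


# Illustrative fragment only; NOT the real ./scripts/zero3.json.
example = json.loads(
    '{"bf16": {"enabled": "auto"}, "zero_optimization": {"stage": 3}}'
)
print(zero_stage(example))
```

In a real run, something like `load_zero_stage("./scripts/zero3.json")` should return 3. Any other value would mean the parameters are not being partitioned by design.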
When I run the model on a single GPU using CUDA_VISIBLE_DEVICES=0 bash ./scripts/sample_stage3.sh, the memory usage before training is:
[screenshot: GPU memory usage]
However, when I use DeepSpeed ZeRO Stage 3, the GPU memory usage before training is:
[screenshot: GPU memory usage]
The run then goes OOM. Could you please suggest which flag we might need to change?
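One way to tell replication from sharding inside the training process is to check whether the model's parameters carry ZeRO-3 partition metadata. The sketch below assumes that DeepSpeed attaches a `ds_numel` attribute to parameters it has partitioned (true in the versions I am aware of, but treat it as an assumption). `partition_summary` is a hypothetical helper, and the stand-in objects are there only so the snippet runs without DeepSpeed installed.

```python
from types import SimpleNamespace


def partition_summary(params):
    """Count parameters that carry ZeRO-3 partition metadata.

    Assumption: DeepSpeed sets `ds_numel` (the full, unpartitioned element
    count) on every parameter it has partitioned under ZeRO Stage 3.
    """
    total = partitioned = full_numel = 0
    for p in params:
        total += 1
        if hasattr(p, "ds_numel"):
            partitioned += 1
            full_numel += p.ds_numel
        else:
            full_numel += p.numel()
    return {"total": total, "partitioned": partitioned, "full_numel": full_numel}


# Stand-ins for real parameters; in a training run, pass model.parameters().
fake_params = [
    SimpleNamespace(ds_numel=1024),        # looks like a ZeRO-3 partitioned param
    SimpleNamespace(numel=lambda: 16),     # looks like an ordinary param
]
print(partition_summary(fake_params))
```

If `partitioned` comes back as 0 on every rank after the model is loaded, the weights were probably materialized in full on each GPU, which would match the replication seen here.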
Hi, I ran into the same problem. Did you manage to solve it?
I ran into the same problem.
With 2 H100s the code runs, but I get OOM. When I increase to more than 2 GPUs, the model is duplicated on each GPU instead of being sharded, and the run gets stuck at "Formatting inputs... Skip in lazy mode".