
Why does finetuning NVILA-Lite-8B-stage2 need so much memory and run so slowly?

Open · goodstudent9 opened this issue 6 months ago · 2 comments

Hi, I am using NVILA-Lite-8B-stage2 to finetune on my downstream task.

Each input has at least 3 and at most 8 images.

But I found that 7×A100 with ZeRO-2 can't run it due to GPU OOM. ZeRO-3 works, but it is too slow.

I am confused because I have finetuned many 7B models in this setting with ZeRO-2, and the GPU utilization was fine.

Do you have any ideas on how I can use your model to train on my task?

Best wishes.

Following is my training script:

```bash
torchrun \
    --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
    llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path $STAGE_PATH \
    --data_mixture $DATA_MIXTURE \
    --vision_tower checkpoint_org/paligemma-siglip-so400m-patch14-448 \
    --mm_vision_select_feature cls_patch \
    --mm_projector mlp_downsample_3x3_fix \
    --tune_vision_tower True \
    --tune_mm_projector True \
    --tune_language_model True \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio dynamic \
    --bf16 True \
    --output_dir $OUTPUT_DIR/model \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy no \
    --save_strategy steps \
    --save_steps 100 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type cosine \
    --logging_steps 1 \
    --model_max_length 4096 \
    --gradient_checkpointing True \
    --dataloader_num_workers 16 \
    --vflan_no_system_prompt True \
    --report_to tensorboard
```

goodstudent9 · May 21 '25 06:05
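
For anyone hitting the same wall, here is a back-of-the-envelope sketch of where the memory goes. The numbers are rough assumptions read off the flags in the script above (the 448-px/patch-14 SigLIP tower, the `mlp_downsample_3x3_fix` projector, dynamic aspect-ratio tiling), not figures confirmed by the VILA authors:

```python
# Back-of-the-envelope vision-token count per training sample.
# Assumptions (inferred from the script's flags, not confirmed numbers):
#  - siglip-so400m-patch14-448: 448 px / 14 px patches -> 32 x 32 = 1024 patches per tile
#  - mlp_downsample_3x3_fix: 3x3 spatial downsampling -> roughly 1024 / 9 tokens per tile
#  - --image_aspect_ratio dynamic: each image may be split into several 448 px tiles

patches_per_tile = (448 // 14) ** 2        # 1024
tokens_per_tile = patches_per_tile // 9    # ~113 after 3x3 downsampling (approximate)

for n_images, tiles_per_image in [(3, 1), (8, 1), (8, 4)]:
    vision_tokens = n_images * tiles_per_image * tokens_per_tile
    print(f"{n_images} images x {tiles_per_image} tile(s) each: ~{vision_tokens} vision tokens")

# 3 images x 1 tile -> ~339 tokens; 8 images x 4 tiles -> ~3616 tokens, i.e. most
# of --model_max_length 4096 is image tokens, so activations (even with gradient
# checkpointing) are far larger than in a text-only 7B finetuning run.
```

So a "batch size 1" sample here can carry several times the activation memory of a typical text-only sample, which is why the familiar 7B + ZeRO-2 recipe stops fitting.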

It took so long! I don't think this is a normal situation. [screenshot attachment]

goodstudent9 · May 21 '25 09:05
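
One knob worth trying before settling for slow ZeRO-3: ZeRO-2 with the optimizer state offloaded to CPU, which removes most of the per-GPU optimizer memory while avoiding ZeRO-3's per-layer parameter gathering. A minimal sketch using standard DeepSpeed config keys (this is not a config shipped with VILA, and the file name `zero2_offload.json` is made up for the example):

```python
# Sketch: write a ZeRO-2 config with CPU optimizer offload. These are standard
# DeepSpeed config keys; the "auto" values are filled in by the HF Trainer
# integration at launch time.
import json

config = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
}

with open("zero2_offload.json", "w") as f:
    json.dump(config, f, indent=2)
```

Then pass it via `--deepspeed zero2_offload.json` in the command above. CPU offload does slow each optimizer step, but usually much less than ZeRO-3's all-gathers slow the forward and backward passes.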

Hi ~ Did you solve this problem? I wanted to fine-tune the 8B model on my dataset (the input has 8 images), but it failed with OOM even with batch size = 1 and ZeRO-3.

return-sleep · Aug 22 '25 02:08
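
For context on why even batch size 1 is tight here, a rough accounting under stated assumptions (plain Adam with fp32 master weights and moments, and essentially all ~8B parameters trainable, since the script sets `--tune_vision_tower`, `--tune_mm_projector`, and `--tune_language_model` to True):

```python
# Rough per-GPU memory for fully finetuning an ~8B-parameter model with Adam
# (bf16 weights and grads, fp32 master weights + two moments), ignoring
# activations and fragmentation. Assumes 7 GPUs, as in the original post.
params = 8e9
gpus = 7

weights = 2 * params              # bf16 parameters
grads = 2 * params                # bf16 gradients
optim = (4 + 4 + 4) * params      # fp32 master copy + Adam m and v

zero2 = weights + (grads + optim) / gpus   # ZeRO-2 shards grads + optimizer states
zero3 = (weights + grads + optim) / gpus   # ZeRO-3 also shards the weights

for name, b in [("ZeRO-2", zero2), ("ZeRO-3", zero3)]:
    print(f"{name}: ~{b / 2**30:.0f} GiB per GPU before activations")

# ZeRO-2: ~30 GiB per GPU before any activations; add several thousand image
# tokens' worth of activations per sample and a 40 GiB A100 can OOM even at
# batch size 1. ZeRO-3: ~17 GiB, which fits, but every layer pays for parameter
# all-gathers, matching the slowdown reported above.
```

Under these assumptions, ZeRO-3 (or ZeRO-2 with CPU optimizer offload, as sketched earlier), freezing the vision tower, or parameter-efficient finetuning are the usual ways out.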