Question about training with FSDP config
Hi, thanks for your great work!
I have a question about training your model on the LADD split.
I have an A100 machine with 40GB VRAM. I use the following command for training:
export PYTHONPATH=.
accelerate launch --config_file=./pipeline/accelerate_configs/accelerate_config_fsdp.yaml \
pipeline/train/instruction_following.py \
--pretrained_model_name_or_path=luodian/OTTER-LLaMA7B-INIT \
--mimicit_path="./data/LADD_instructions.json" \
--images_path="./data/LA.json" \
--train_config_path="./data/LADD_train.json" \
--external_save_dir="./checkpoints" \
--batch_size=4 \
--num_epochs=9 \
--report_to_wandb \
--wandb_entity=vu27 \
--run_name=OTTER-LLaMA7B-TEST \
--wandb_project=OTTER-LLaMA7B-TEST \
--workers=8 \
--lr_scheduler=cosine \
--learning_rate=1e-5 \
--warmup_steps_ratio=0.01
I haven't changed the FSDP config; it looks like this:
compute_environment: LOCAL_MACHINE
distributed_type: no
downcast_bf16: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: false
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 20687
However, when training on a single GPU I run into OOM issues. I verified that the accelerate config is being applied (training does run in mixed precision). In one of your issues (https://github.com/Luodian/Otter/issues/182) you mentioned loading the model in lower precision, which I tried with both fp16 and bf16. Could you clarify whether your training loaded the model in fp16/bf16, or loaded it in full precision and relied on Accelerate's autocasting? I ask because bf16 training can sometimes be less stable than full-precision training (although only about 1/9th of the model parameters are being trained, so I'm not sure this matters here).

A few related questions: were you able to train on a single GPU at all, or did you use multi-GPU setups for all your experiments? Could you also share more details about your final training run (batch size, mixed-precision settings, DDP/FSDP, gradient accumulation/checkpointing, number of GPUs, etc.) so that I have some reference guidelines for my own runs? Your help would be much appreciated!
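To clarify the precision question above, the two loading strategies I'm comparing look roughly like this (just a sketch; the import path of the Otter model class is my assumption and may differ between repo versions):

```python
import torch

# Assumption: the Otter model class lives here; adjust to your repo version.
from otter.modeling_otter import OtterForConditionalGeneration

# Option A: load the weights directly in bf16, so the model itself only needs
# ~2 bytes per parameter before training starts.
model = OtterForConditionalGeneration.from_pretrained(
    "luodian/OTTER-LLaMA7B-INIT",
    torch_dtype=torch.bfloat16,
)

# Option B: load in full precision and rely on Accelerate's autocast
# (mixed_precision: bf16 in the accelerate config). The fp32 weights still
# occupy ~4 bytes per parameter, so peak memory is higher than Option A.
model = OtterForConditionalGeneration.from_pretrained("luodian/OTTER-LLaMA7B-INIT")
```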
Another question: for fine-tuning from Otter, would you still recommend the same training pipeline, or would you recommend PEFT methods (I believe you have already implemented LoRA support)? Do you have a guideline on the best setup for fine-tuning Otter?
hi Vishaal, may I know if you are using one GPU with 40G to train? Can you decrease it to batch_size=1, or try with luodian/OTTER-MPT1B-RPJama-Init?
I am not sure whether one 40G GPU is enough. If I remember correctly, initializing Otter-LLaMA7B costs around 16G of memory, and training with batch_size=1 uses around 30-40G.
Thanks @Luodian, I was finally able to run it on a larger compute node with 4x 40GB GPUs and a batch size of 32, though it is still strange that I couldn't get it running on one GPU even with half precision. What ended up working was loading the model in bf16 and using bf16 mixed-precision training. For fine-tuning, would you still recommend this setup, or would you recommend LoRA fine-tuning? Also, do you have a sample train config for fine-tuning with LoRA?
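For reference, the setup that ended up working for me looks roughly like this (a minimal sketch, not the actual Otter training script; the import path is again an assumption):

```python
import torch
from accelerate import Accelerator

# Assumption: Otter model class import path, as above.
from otter.modeling_otter import OtterForConditionalGeneration

# bf16 mixed-precision training on top of bf16-loaded weights.
accelerator = Accelerator(mixed_precision="bf16")

model = OtterForConditionalGeneration.from_pretrained(
    "luodian/OTTER-LLaMA7B-INIT",
    torch_dtype=torch.bfloat16,
)

# Only optimize the parameters that are actually unfrozen.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-5
)

model, optimizer = accelerator.prepare(model, optimizer)
```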
We don't have promising results for LoRA fine-tuning. We tried fine-tuning the perceiver + cross_x_attn + LoRA LLM, but did not get better results than fine-tuning the perceiver + cross_x_attn alone.
If you want to LoRA fine-tune the LLM, you should first convert the Otter init model to a LoRA version using https://github.com/Luodian/Otter/blob/main/otter/converting_otter_to_lora.py
Then you can load it directly without any other modification to the training procedure.
If the LoRA LLM is loaded, you will see logs showing how many parameters are LoRA-adapted, etc.
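For a rough idea of what the conversion amounts to, here is a generic PEFT sketch (this is not the repo's converting_otter_to_lora.py; the lang_encoder attribute name and the target module names are assumptions):

```python
import torch
from peft import LoraConfig, get_peft_model

# Assumption: import path and `lang_encoder` attribute may not match the
# current repo layout.
from otter.modeling_otter import OtterForConditionalGeneration

model = OtterForConditionalGeneration.from_pretrained(
    "luodian/OTTER-LLaMA7B-INIT",
    torch_dtype=torch.bfloat16,
)

# Hypothetical LoRA settings; the target modules below are the usual LLaMA
# attention projections and may differ from what the conversion script uses.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Wrap only the language model with LoRA adapters and report how many
# parameters end up trainable.
model.lang_encoder = get_peft_model(model.lang_encoder, lora_config)
model.lang_encoder.print_trainable_parameters()
```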