
Missing import from train/train.py for LoRA training

Open · jeremycochoy opened this issue 2 years ago

train_lora.py imports a function, smart_tokenizer_and_embedding_resize, that was deleted from train.py in a previous commit.

jeremycochoy avatar Apr 14 '23 01:04 jeremycochoy

You can refer to https://github.com/GanjinZero/RRHF/blob/529196c00656322ce861fd8262a2c452b401780f/train.py#L93 and manually add this function.
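
For reference, a minimal sketch of that helper, following the version in the linked RRHF train.py (which mirrors the original Alpaca implementation):

```python
from typing import Dict

import transformers


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens and resize the model embeddings to match.

    New token embeddings are initialized to the mean of the existing
    ones, which is more stable than random initialization.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Average of the pre-existing embeddings (everything except the new rows).
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        # Initialize the new rows with that average.
        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
```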

CiaoHe avatar Apr 15 '23 04:04 CiaoHe

That's indeed what I have done, but it seems to be insufficient to run the original LoRA configuration. I was able to reproduce the original LoRA training from the original LoRA repository, and it uses barely half the memory of a single 3090 GPU; however, I wasn't able to reproduce this with the FastChat configuration (OOM even with very low micro-batch and sequence-length settings). I didn't have a deep look at the FastChat codebase, however.

jeremycochoy avatar Apr 15 '23 22:04 jeremycochoy

https://github.com/lm-sys/FastChat/pull/441 is supposed to fix it. Could you please give the nightly fastchat a try?

ZYHowell avatar Apr 17 '23 15:04 ZYHowell

Thanks. I will have a look this evening and keep you updated 👍

jeremycochoy avatar Apr 17 '23 17:04 jeremycochoy

I tried the latest head. The code does seem to run (i.e., the same as what I got when I copy-pasted the missing functions into the file), however I immediately get an OOM error on an RTX 3090, even though I can run the original alpaca-lora repository training on this machine without problems:

```
OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB (GPU 0; 23.69 GiB total capacity; 17.55 GiB already
allocated; 132.06 MiB free; 17.55 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try
setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 103711) of binary: /home/jupyter/llm/FastChat/.venv/bin/python3
```

Here is the command I use to run the FastChat LoRA training:

```shell
CUDA_LAUNCH_BLOCKING=1 torchrun --nproc_per_node=1 --master_port=20001 \
    fastchat/train/train_lora.py \
    --model_name_or_path ./llama-7b-exp/ \
    --data_path playground/data/dummy.json \
    --bf16 True \
    --output_dir output \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 256 \
    --gradient_checkpointing True \
    --lazy_preprocess True
```

jeremycochoy avatar Apr 18 '23 09:04 jeremycochoy

If there is only one GPU, maybe you can directly run train_lora.py without FSDP (Fully Sharded Data Parallel). Besides, as mentioned here, gradient checkpointing with LoRA needs a monkey patch on transformers' internal code. I hit the same case, where OOM only appears after v1.1, and solved it by making the activation after the embedding layer require grad in the llama model's code.
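
Not the exact diff, but a sketch of the usual form of that workaround, assuming `model` is the loaded LlamaForCausalLM (recent transformers versions expose `enable_input_require_grads()` for exactly this):

```python
# Make the embedding output require grad so gradient checkpointing has a
# grad-requiring input even though the frozen base weights (LoRA) do not.
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
else:
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
```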

ZYHowell avatar Apr 18 '23 13:04 ZYHowell

Is the issue resolved?

zhisbug avatar Apr 22 '23 02:04 zhisbug

> If there is only one GPU, maybe you can directly run train_lora.py without FSDP (Fully Sharded Data Parallel). Besides, as mentioned here, gradient checkpointing with LoRA needs a monkey patch on transformers' internal code. I hit the same case, where OOM only appears after v1.1, and solved it by making the activation after the embedding layer require grad in the llama model's code.

Do you have a reference for the line in the llama model code you patched? :) Or even just a snippet of your local code diff; that would be enough for me to understand what you mean and how to reproduce it.

jeremycochoy avatar Apr 26 '23 15:04 jeremycochoy

Please check this issue: https://github.com/lm-sys/FastChat/issues/581

ZYHowell avatar Apr 28 '23 03:04 ZYHowell