Llama3 qlora load_state_dict takes forever
While working on my customized LoRAFinetuneRecipeSingleDevice recipe, I noticed that after upgrading torchtune from 0.1.1 to 0.2.1 and torchao from 0.1 to 0.3.1, model loading times with QLoRA went up dramatically. Loading llama3-8b now takes about 5 minutes, where it used to take only a few seconds on 0.1.1. I was able to pinpoint the slowdown to the call model.load_state_dict(base_model_state_dict, strict=False).
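For anyone trying to reproduce the pinpointing step: a profiler makes it easy to see which call inside the load path dominates. This is a minimal, self-contained sketch using Python's built-in cProfile; the two functions are stand-ins for the real model.load_state_dict call and the quantization work it triggers, not torchtune code.

```python
import cProfile
import io
import pstats
import time


def quantize_weights():
    # Stand-in for the per-parameter NF4 quantization work that runs
    # inside load_state_dict; sleep simulates the expensive part.
    time.sleep(0.05)


def load_state_dict_standin():
    # Stand-in for model.load_state_dict(base_model_state_dict, strict=False)
    for _ in range(3):
        quantize_weights()


profiler = cProfile.Profile()
profiler.enable()
load_state_dict_standin()
profiler.disable()

stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(5)
print(stream.getvalue())  # cumulative time concentrates in quantize_weights
```

Running this against the real recipe (profiling around the actual load_state_dict call) is how one can confirm where the 5 minutes are spent.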
Here are the steps I took to reproduce this issue in a fresh conda environment (PyTorch 2.4):
conda create --name tune python=3.11
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
conda activate tune
pip install torchtune
tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir . --hf-token <token>
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device checkpointer.checkpoint_dir=$CHECKPOINT_DIR tokenizer.path=$CHECKPOINT_DIR/tokenizer.model checkpointer.output_dir=$CHECKPOINT_DIR output_dir=$CHECKPOINT_DIR
Output:
...
DEBUG:torchtune.utils.logging:Setting manual seed to local seed 4061266822. Local seed is seed + rank = 4061266822 + 0
Writing logs to /.../testtune/instruct/original/log_1722353553.txt
<WAITING FOR 5 MINUTES, RAM usage slowly rising to 15GB>
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Memory stats after model init:
GPU peak memory allocation: 6.79 GB
GPU peak memory reserved: 6.97 GB
GPU peak memory active: 6.79 GB
...
When I downgrade to version 0.1.1, it's fast again:
pip install torchtune==0.1.1
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device checkpointer.checkpoint_dir=$CHECKPOINT_DIR tokenizer.path=$CHECKPOINT_DIR/tokenizer.model checkpointer.output_dir=$CHECKPOINT_DIR output_dir=$CHECKPOINT_DIR
Output:
DEBUG:torchtune.utils.logging:Setting manual seed to local seed 1435681341. Local seed is seed + rank = 1435681341 + 0
Writing logs to /.../testtune/instruct/original/log_1722354102.txt
INFO:torchtune.utils.logging:Model is initialized with precision torch.bfloat16.
INFO:torchtune.utils.logging:Memory Stats after model init:
{'peak_memory_active': 10.129489408, 'peak_memory_alloc': 10.129489408, 'peak_memory_reserved': 12.639535104}
after only a few seconds.
I am on a Slurm machine with 4 CPU threads and an NVIDIA 3090 (24 GB).
Any ideas on what might be the cause? The lora recipes without quantization work just fine.
Not a torchtune author/contributor, but from the memory usage, I'm guessing that the old version performs NF4 quantization on GPU, while the new version performs it on CPU.
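One way to test this hypothesis is to check where the model's tensors actually live after loading. This is a minimal sketch, not torchtune internals: a tiny nn.Linear stands in for the real QLoRA model, and with the actual recipe you would inspect the NF4-quantized weights instead. Parameters reported on "cpu" mean the quantization work during loading ran on the host; "cuda:0" would point at GPU-side quantization.

```python
import torch

# Stand-in model; in the real recipe this would be the QLoRA llama3 model.
model = torch.nn.Linear(4, 4)

# Stand-in for base_model_state_dict loaded from the checkpoint.
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(state_dict, strict=False)

# Report where each parameter ended up after loading.
devices = {name: str(p.device) for name, p in model.named_parameters()}
print(devices)
```

Watching RAM vs. VRAM growth during the load (as described above, with RAM slowly rising to 15 GB) is consistent with the CPU-side explanation.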
Makes sense, this was suggested by @msaroufim, too. I will confirm.
@l-berg Apologies for the late response - do you notice the slowdown on 0.2.0 as well? This will help me narrow down where these changes could be coming from.
Yes, upgrading from 0.1.1 to 0.2.0 results in the same increase from ~10s to >5min loading time on my machine.
Hi @l-berg - thanks for bringing this to our attention! The AO folks dug into this and identified a version-guarded inplace_copy function as the culprit. You can read more about it here: https://github.com/pytorch/ao/issues/642.
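For readers who don't follow the link: the failure mode is a guard that selects between a fast and a slow implementation based on a detected library version. The sketch below is a hypothetical illustration of that pattern, loosely modeled on the issue described in pytorch/ao#642; the function names and version numbers are illustrative, not the actual torchao code.

```python
def fast_inplace_copy(dst, src):
    # Fast path: one bulk in-place write.
    dst[:] = src


def slow_fallback_copy(dst, src):
    # Generic fallback: per-element work, far slower at model scale.
    for i, value in enumerate(src):
        dst[i] = value


def make_inplace_copy(detected_version, required_version=(2, 4)):
    # An over-strict (or wrong) version guard routes every weight copy
    # down the slow path, which is what turns seconds into minutes.
    if detected_version >= required_version:
        return fast_inplace_copy
    return slow_fallback_copy


copy_fn = make_inplace_copy(detected_version=(2, 3))
print(copy_fn.__name__)  # the guard routes an older version to the slow path
```

The fix in such cases is simply correcting the guard so supported versions take the fast path.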
This will be fixed by @ebsmothers in #1294