
How to train the 13B version in 8-bit with LoRA

raj-khare opened this issue 1 year ago • 2 comments

I want to train the 13B LLaMA, but with 8-bit quantization and LoRA. Right now it takes 70 GB of GPU RAM, which is quite a lot. I'm using 8x A100-80GB.

lora.py

# Hyperparameters
learning_rate = 3e-4
batch_size = 64
micro_batch_size = 1
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_iters = 50000 * 3 // micro_batch_size
weight_decay = 0.0
max_seq_length = 4096  # see scripts/prepare_alpaca.py
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
warmup_iters = 100
def main(
    data_dir: str = "dataset", 
    pretrained_path: str = "/scratch/checkpoints/lit-llama/13B/lit-llama.pth",
    tokenizer_path: str = "/scratch/checkpoints/lit-llama/tokenizer.model",
    out_dir: str = "out/lora",
):

    fabric = L.Fabric(accelerator="cuda", devices=8, precision="bf16-true")
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)
    ...

raj-khare • Jun 30 '23 19:06
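For reference, lit-llama's lora.py did not quantize the base model at the time, so most of that 70 GB is presumably the frozen bf16 13B weights (replicated per GPU under the default setup) plus activations at max_seq_length = 4096. Outside lit-llama, the usual 8-bit-plus-LoRA recipe is to load the base model quantized and frozen and train only the adapters, e.g. with Hugging Face transformers, bitsandbytes, and PEFT. A minimal sketch, assuming that stack (the checkpoint name and target modules are illustrative, not taken from this thread):

# Sketch only: 8-bit frozen base + trainable LoRA adapters via transformers/bitsandbytes/PEFT,
# not lit-llama's lora.py. Checkpoint name and target modules are illustrative.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-13b",                  # placeholder 13B checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",                       # place the frozen 8-bit weights across GPUs
)
model = prepare_model_for_kbit_training(model)   # fp32 norms, input grads for checkpointing

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],     # same spirit as lit-llama's LoRA on attention
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()           # only the small adapter matrices require grad

With the base weights stored in int8 and only the r=8 adapters carrying gradients and optimizer state, the per-GPU footprint drops to roughly the quantized weights (~13 GB for 13B parameters) plus activations.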

Training can't be done with quantized weights, as the steps would fall within the quantization error threshold.

snake-4 • Jul 17 '23 10:07
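This is the motivation behind LoRA/QLoRA-style setups: the quantized base weights stay frozen, and all gradient updates go into small full-precision adapter matrices, so no update ever has to survive re-quantization. A minimal PyTorch sketch of the idea (not lit-llama's implementation; shapes and init are illustrative):

# Minimal sketch of the LoRA idea: the (possibly quantized) base weight is frozen,
# and gradients flow only into the small full-precision A/B adapter matrices.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)          # frozen base (quantized in practice)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x):
        # frozen path + trainable low-rank update
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

layer = LoRALinear(nn.Linear(4096, 4096))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable params: {trainable:,} / {total:,}")   # adapters are a tiny fraction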

Can I quantize and fine-tune a bf16 LLM with 4-bit QLoRA?

AjibolaPy • May 02 '24 12:05
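On the last question: that is essentially what QLoRA does. The bf16 checkpoint is quantized to 4-bit NF4 for the frozen base, matmuls are computed in bf16, and only the LoRA adapters are trained. Relative to the 8-bit sketch above, only the quantization config changes; a hedged example:

# Sketch only: 4-bit NF4 base with bf16 compute, as used in QLoRA-style fine-tuning.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",               # frozen base weights stored as 4-bit NF4
    bnb_4bit_compute_dtype=torch.bfloat16,   # matmuls run in bf16
    bnb_4bit_use_double_quant=True,          # quantize the quantization constants too
)
# Pass bnb_config as quantization_config to from_pretrained(), then attach LoRA adapters
# exactly as in the 8-bit sketch; the adapters themselves stay in bf16/fp32.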