lit-llama
How to train the 13B version in 8-bit with LoRA
I want to train the 13B LLaMA, but with 8-bit quantization and LoRA. Right now it takes 70 GB of GPU RAM, which is quite a lot. I'm using 8x A100-80GB.
lora.py
import lightning as L

# Hyperparameters
learning_rate = 3e-4
batch_size = 64
micro_batch_size = 1
gradient_accumulation_iters = batch_size // micro_batch_size
assert gradient_accumulation_iters > 0
max_iters = 50000 * 3 // micro_batch_size
weight_decay = 0.0
max_seq_length = 4096 # see scripts/prepare_alpaca.py
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
warmup_iters = 100
def main(
    data_dir: str = "dataset",
    pretrained_path: str = "/scratch/checkpoints/lit-llama/13B/lit-llama.pth",
    tokenizer_path: str = "/scratch/checkpoints/lit-llama/tokenizer.model",
    out_dir: str = "out/lora",
):
    fabric = L.Fabric(accelerator="cuda", devices=8, precision="bf16-true")
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)
    ...
Training can't be done on the quantized weights themselves: a single optimizer step falls within the quantization error threshold, i.e. it is smaller than the spacing between representable weight values, so it is lost when the weight is rounded back to the quantized grid.
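A toy numerical check (not from the repo, and using an assumed gradient scale) makes this concrete: with the learning rate of 3e-4 from lora.py, a single update is far smaller than the int8 grid spacing, so re-quantizing after the step gives back the original weight.

import torch

# Quantize a typical weight tensor to int8 with a per-tensor absmax scale.
w = torch.randn(4096) * 0.02
scale = w.abs().max() / 127
q = torch.round(w / scale).to(torch.int8)

# Apply an SGD-sized update (roughly lr * gradient) and re-quantize.
update = 3e-4 * 0.01 * torch.randn_like(w)
q_new = torch.round((q.float() * scale + update) / scale).to(torch.int8)

# Nearly every weight rounds back to its old quantized value, i.e. the step is lost.
print((q_new == q).float().mean().item())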
Can I quantize and finetune a bf16 LLM with 4-bit QLoRA?
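As far as I know, lit-llama itself doesn't ship a QLoRA training path, but the QLoRA recipe sidesteps the problem above: the 4-bit base weights stay frozen and only small bf16 LoRA adapters receive gradients. A minimal sketch with Hugging Face transformers + peft + bitsandbytes (the checkpoint name, target modules, and hyperparameters below are illustrative assumptions, not lit-llama code):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the 13B base model with NF4 4-bit weights; compute runs in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-13b",          # example HF checkpoint, not the lit-llama .pth
    quantization_config=bnb_config,
    device_map="auto",
)

# Freeze the quantized base and attach trainable bf16 LoRA adapters.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],   # assumption: adapt attention projections only
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()   # only the LoRA weights carry gradients and optimizer state

Because only the adapter parameters carry optimizer state, the memory cost of finetuning is dominated by the frozen 4-bit weights plus activations, which typically brings a 13B model well under a single 80 GB A100, though activation memory at max_seq_length = 4096 still matters.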