
4-bit model cannot be trained with the Hugging Face Trainer

Open zyzhang1130 opened this issue 2 years ago • 1 comment

When I train a quantized model in the following way:

model_id = "openlm-research/open_llama_3b_600bt_preview"
# model_id = "EleutherAI/gpt-neo-1.3B"
# model_id = "gpt2-xl"
qlora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
)

training_args = TrainingArguments(
        output_dir="/content/drive/MyDrive/Colab Notebooks/GPT_GAN/my_awesome_model",
        learning_rate=2e-5,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=2,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        push_to_hub=False,
        fp16=True,
        optim="paged_adamw_8bit",
    )

    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_combined_qna["train"],
        eval_dataset=tokenized_combined_qna["test"],
        tokenizer=tokenizer,
        # data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()

it gives the following error:

ValueError: `.to` is not supported for `4-bit` or `8-bit` models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
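
For context: bitsandbytes-quantized models block any explicit device or dtype move, and the stock Trainer performs such a move internally when it places a plain (non-PEFT) model. A minimal sketch that reproduces the same ValueError by hand, assuming the 4-bit model loaded above:

import torch

# Any explicit cast/move on a 4-bit bitsandbytes model raises the error quoted
# above; Trainer hits the same check when it tries to place the model on a device.
model.to(torch.float16)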

Is there a way to disable the default `.to` call in transformers.Trainer?

zyzhang1130 avatar Jun 02 '23 14:06 zyzhang1130

It seems you're trying to train the original model rather than the LoRA adapter. Based on the presence of qlora_config, I'm guessing you actually want to train the adapter. Try adding the following code after loading the model:

from peft import prepare_model_for_kbit_training, get_peft_model

# Prepare the quantized model for training (casts non-quantized modules such as
# layer norms to full precision and makes the inputs require gradients)
model = prepare_model_for_kbit_training(model)

# Wrap the base model so that only the LoRA adapter weights are trainable
model = get_peft_model(model, qlora_config)
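
As a sanity check, you can verify afterwards that only the adapter weights are trainable; print_trainable_parameters is a standard PeftModel method:

# After get_peft_model, only the LoRA matrices should require grads
model.print_trainable_parameters()
# prints something like: trainable params: ... || all params: ... || trainable%: ...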

pmysl avatar Jun 02 '23 16:06 pmysl

It worked, thank you so much!

zyzhang1130 avatar Jun 03 '23 04:06 zyzhang1130