
EsmForSequenceClassification does not support gradient checkpointing

Open Amelie-Schreiber opened this issue 1 year ago • 1 comment


Bug description

ESM-2 models do not appear to be compatible with QLoRA, because they do not support gradient checkpointing.

Reproduction steps

Code to reproduce:

!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git 
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig

model_id = "facebook/esm2_t6_8M_UR50D"

# 4-bit NF4 quantization with double quantization and bfloat16 compute,
# the standard bitsandbytes configuration for QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})

This next part produces the error:

from peft import prepare_model_for_kbit_training

# Enabling gradient checkpointing is where EsmForSequenceClassification fails
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

Expected behavior

The script should simply prepare the model for QLoRA (Quantized Low-Rank Adaptation) training. See here for an example, which is linked to in this article.
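
For reference, the intended end-to-end flow would look roughly like the sketch below. The LoRA hyperparameters and the target_modules names are illustrative assumptions (ESM-2's self-attention layers expose query/key/value Linear modules), not values taken from a working setup:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Intended flow once gradient checkpointing is supported
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# Illustrative LoRA settings; target_modules assumes ESM-2's
# query/key/value attention projection names
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["query", "key", "value"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_CLS",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()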

Logs

Command-line output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-d6b5f42e99b2> in <cell line: 3>()
      1 from peft import prepare_model_for_kbit_training
      2 
----> 3 model.gradient_checkpointing_enable()
      4 model = prepare_model_for_kbit_training(model)

/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py in gradient_checkpointing_enable(self)
   1719         """
   1720         if not self.supports_gradient_checkpointing:
-> 1721             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
   1722         self.apply(partial(self._set_gradient_checkpointing, value=True))
   1723 

ValueError: EsmForSequenceClassification does not support gradient checkpointing.

Additional context

This is a basic attempt at training a QLoRA for ESM-2 models such as facebook/esm2_t6_8M_UR50D on a sequence classification task. The error is not task-dependent, though; I get the same error when trying to train a token classifier. Any assistance in making ESM-2 models compatible with QLoRA would be greatly appreciated.
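
As an interim workaround, a sketch is below; it assumes a recent peft version in which prepare_model_for_kbit_training exposes a use_gradient_checkpointing flag. Setting it to False skips checkpointing entirely, at the cost of higher activation memory during training:

from peft import prepare_model_for_kbit_training

# Skip gradient checkpointing, so the unsupported
# gradient_checkpointing_enable() call is never made.
# Trade-off: higher activation memory during training.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

This avoids the failing call but does not address the missing gradient-checkpointing support in the ESM modeling code itself.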

Amelie-Schreiber · Aug 24 '23 04:08