EsmForSequenceClassification does not support gradient checkpointing
Bug description

ESM-2 models do not appear to be compatible with QLoRA, because gradient checkpointing cannot be enabled on them.
Reproduction steps

Code to reproduce:
```python
# Install the latest bitsandbytes, transformers, peft, accelerate, and datasets
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q datasets

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig

model_id = "facebook/esm2_t6_8M_UR50D"

# 4-bit NF4 quantization with double quantization and bfloat16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, quantization_config=bnb_config, device_map={"": 0}
)
```
This next part produces the error:
```python
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()  # raises ValueError (see logs below)
model = prepare_model_for_kbit_training(model)
```
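Commenting out the gradient_checkpointing_enable() call is not enough on its own: if I read PEFT's source correctly, prepare_model_for_kbit_training enables checkpointing by default and hits the same error. As a stopgap, checkpointing can be disabled entirely (losing the memory savings that make QLoRA practical on larger checkpoints); a minimal sketch, assuming the use_gradient_checkpointing flag behaves as documented:

```python
from peft import prepare_model_for_kbit_training

# Stopgap: skip model.gradient_checkpointing_enable() entirely and tell
# PEFT not to re-enable checkpointing while preparing the 4-bit model.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
```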
Expected behavior

The script should simply prepare the model for QLoRA (Quantized Low-Rank Adaptation) training. See here for an example, which is linked to in this article.
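For context, this is the rest of the flow I expect to run once the failing call succeeds, following the usual QLoRA recipe. The LoRA hyperparameters are placeholders, and the target_modules names are my assumption (ESM appears to use BERT-style query/key/value projection names):

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

# Assumption: ESM uses BERT-style attention projections named "query"/"value";
# adjust target_modules if the actual module names differ.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["query", "value"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_CLS",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```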
Logs

Command line output:
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-2-d6b5f42e99b2> in <cell line: 3>()
      1 from peft import prepare_model_for_kbit_training
      2
----> 3 model.gradient_checkpointing_enable()
      4 model = prepare_model_for_kbit_training(model)

/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py in gradient_checkpointing_enable(self)
   1719         """
   1720         if not self.supports_gradient_checkpointing:
-> 1721             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
   1722         self.apply(partial(self._set_gradient_checkpointing, value=True))
   1723

ValueError: EsmForSequenceClassification does not support gradient checkpointing.
```
Additional context

This is a basic attempt at training a QLoRA adapter for ESM-2 models such as facebook/esm2_t6_8M_UR50D on a sequence classification task. The error is not task-dependent, though: I hit the same error when trying to train a token classifier. Any assistance in making ESM-2 models compatible with QLoRA would be greatly appreciated.
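For what it's worth, other models in transformers opt into gradient checkpointing by setting supports_gradient_checkpointing and implementing _set_gradient_checkpointing on their PreTrainedModel subclass. Below is a rough, untested sketch of what a monkey-patch (or upstream fix) might look like for ESM; it assumes EsmEncoder already carries the BERT-style gradient_checkpointing branch in its forward, which would otherwise need to be added too:

```python
from transformers.models.esm.modeling_esm import EsmEncoder, EsmPreTrainedModel

# Untested sketch, mirroring how BERT-style models opt in. Assumes
# EsmEncoder.forward already routes layers through torch.utils.checkpoint
# when module.gradient_checkpointing is True.
def _set_gradient_checkpointing(self, module, value=False):
    if isinstance(module, EsmEncoder):
        module.gradient_checkpointing = value

EsmPreTrainedModel.supports_gradient_checkpointing = True
EsmPreTrainedModel._set_gradient_checkpointing = _set_gradient_checkpointing
```

If this is roughly right, applying the patch before calling gradient_checkpointing_enable() should get the reproduction script past the failing line, but I have not verified checkpointed training end to end.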