transformers
Add quantization_config in AutoModelForCausalLM.from_config()
Feature request
Add a quantization_config argument to AutoModelForCausalLM.from_config(). I am trying to pretrain a model from scratch and use bitsandbytes so that it can be trained on less computationally expensive machines. Below is my quantization config:
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
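For reference, the same config already works when passed to from_pretrained. A minimal illustration (note that this path downloads the pretrained weights rather than giving a fresh initialization, which is what I need for pretraining):

from transformers import AutoModelForCausalLM

# Works today: from_pretrained accepts quantization_config directly,
# but it loads the pretrained checkpoint instead of a fresh model.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=bnb_config,
    device_map={"": 0},
)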
However, when I took the config of the model via AutoConfig.from_pretrained and passed the quantization config to from_config, it failed and raised the TypeError shown below.
from transformers import AutoConfig, AutoModelForCausalLM
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_config(config, quantization_config=bnb_config, device_map={"": 0})
The Error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[23], line 7
3 # Download configuration from huggingface.co and cache.
5 configy = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
----> 7 modely = AutoModelForCausalLM.from_config(configy,quantization_config=bnb_config, device_map={"":0})
File ~/miniconda3/envs/ai/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:441, in _BaseAutoModelClass.from_config(cls, config, **kwargs)
439 elif type(config) in cls._model_mapping.keys():
440 model_class = _get_model_class(config, cls._model_mapping)
--> 441 return model_class._from_config(config, **kwargs)
443 raise ValueError(
444 f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
445 f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
446 )
File ~/miniconda3/envs/ai/lib/python3.10/site-packages/transformers/modeling_utils.py:1192, in PreTrainedModel._from_config(cls, config, **kwargs)
1190 model = cls(config, **kwargs)
1191 else:
-> 1192 model = cls(config, **kwargs)
1194 # restore default dtype if it was modified
1195 if dtype_orig is not None:
TypeError: MistralForCausalLM.__init__() got an unexpected keyword argument 'quantization_config'
Motivation
As a workaround, I instantiated the model from the loaded config, saved it to disk, and then loaded that saved model back with the quantization config.
I believe this could be fixed so that quantization can be enabled directly while loading the model from the config itself.
Your contribution
from transformers import AutoConfig, AutoModelForCausalLM

# Build the architecture from the config (randomly initialized weights),
# write it to disk, then reload it with the quantization config applied.
config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_config(config)
model.save_pretrained(MODEL_NAME_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_PATH, quantization_config=bnb_config, device_map={"": 0})
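For completeness, the workaround can be wrapped into a small helper. This is a minimal sketch assuming a temporary directory is acceptable for the round trip through disk; init_quantized_from_config is a hypothetical name, not part of transformers:

import tempfile

import torch
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig


def init_quantized_from_config(model_id: str, bnb_config: BitsAndBytesConfig):
    # Hypothetical helper (not part of transformers): build the architecture
    # from the config only, with randomly initialized weights.
    config = AutoConfig.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_config(config)

    # Serialize the full-precision weights, then reload them through
    # from_pretrained, which does accept quantization_config.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        return AutoModelForCausalLM.from_pretrained(
            tmp_dir,
            quantization_config=bnb_config,
            device_map={"": 0},
        )


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = init_quantized_from_config("mistralai/Mistral-7B-v0.1", bnb_config)

The obvious downside is that the full-precision checkpoint is written to disk just to be read back and quantized, which is exactly the extra step this feature request would remove.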