add quantization parameters for transformers models
This adds the `torch_dtype` and `load_in_8bit` quantization parameters for transformers models, allowing larger models to be loaded with lower VRAM requirements.
The LLaMA and MPT subclasses load the model directly, so the parameters had to be duplicated there. MPT doesn't seem to support `load_in_8bit` yet, so I just included `torch_dtype` for that one.
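For reference, here is a minimal sketch of what these parameters do when forwarded to `transformers` (illustrative only, not the actual diff; `load_in_8bit` additionally requires the `bitsandbytes` and `accelerate` packages, and the model name below is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; any causal LM checkpoint works

# Half precision roughly halves VRAM usage versus the float32 default.
model_fp16 = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16
)

# 8-bit quantization reduces VRAM further (needs bitsandbytes + accelerate).
model_int8 = AutoModelForCausalLM.from_pretrained(
    model_name, load_in_8bit=True, device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
```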
@microsoft-github-policy-service agree
Thanks! Will take a look soon.
Not sure how the maintainers feel, but I'd appreciate a general `kwargs` input for transformers models.
I also think a general `kwargs` is the right approach. It seems that we can just pass the same kwargs to both the tokenizer and the model loader for most params. @jquesnelle mind updating this PR with a general `kwargs`?
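Something like this, roughly (a sketch only, not the actual guidance code; the class and method names follow those mentioned later in this thread):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

class Transformers:
    def __init__(self, model=None, tokenizer=None, **kwargs):
        # Forward any extra kwargs (torch_dtype, load_in_8bit, ...) untouched.
        self.model_obj, self.tokenizer = self._model_and_tokenizer(
            model, tokenizer, **kwargs
        )

    def _model_and_tokenizer(self, model, tokenizer, **kwargs):
        # Pass the same kwargs to both loaders; this works for most params,
        # though some model-only kwargs may eventually need filtering.
        if isinstance(model, str):
            if tokenizer is None:
                tokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
            model = AutoModelForCausalLM.from_pretrained(model, **kwargs)
        return model, tokenizer
```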
Yup, no problem! Will have it updated shortly 🙂
Should be all set @slundberg
It's just a little bit ugly: since LLaMA and MPT load the models themselves, the arguments to `Transformers.__init__` had to be duplicated in the `__init__` for these subclasses.
Thanks @jquesnelle! I agree it is a bit redundant. After I merge, I'll fix that by making LLaMA and MPT override the `_model_and_tokenizer` method instead of `__init__`.
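For example, building on the sketch above, the override could look roughly like this (illustrative only, not the merged code; `LlamaForCausalLM` and `LlamaTokenizer` are the standard transformers classes):

```python
from transformers import LlamaForCausalLM, LlamaTokenizer

class LLaMA(Transformers):
    def _model_and_tokenizer(self, model, tokenizer, **kwargs):
        # Only the loading logic differs; Transformers.__init__ is inherited,
        # so the quantization kwargs no longer need to be duplicated.
        if isinstance(model, str):
            if tokenizer is None:
                tokenizer = LlamaTokenizer.from_pretrained(model, **kwargs)
            model = LlamaForCausalLM.from_pretrained(model, **kwargs)
        return model, tokenizer
```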