
add quantization parameters for transformers models

Open jquesnelle opened this issue 2 years ago • 6 comments

This adds the torch_dtype and load_in_8bit quantization parameters for transformers models, which allow loading larger models with lower VRAM requirements.

The Llama and MPT subclasses load the model directly, so the parameters had to be duplicated there. MPT doesn't seem to support load_in_8bit yet, so I just included torch_dtype for that one.

jquesnelle avatar May 16 '23 15:05 jquesnelle
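
For reference, a minimal sketch of what these two kwargs do once they reach Hugging Face transformers; the model name and the device_map argument below are illustrative and not taken from this PR:

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "huggyllama/llama-7b"  # illustrative model, not from the PR

# Half-precision load: roughly halves VRAM versus the default float32 weights.
model_fp16 = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# 8-bit quantized load via bitsandbytes: needs a device_map so weights land on the GPU.
model_int8 = AutoModelForCausalLM.from_pretrained(
    model_id, load_in_8bit=True, device_map="auto"
)
```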

@microsoft-github-policy-service agree

jquesnelle avatar May 16 '23 15:05 jquesnelle

Thanks! will take a look soon

slundberg avatar May 16 '23 23:05 slundberg

Not sure how the maintainers feel, but I'd appreciate a general kwargs input for transformers models

evanmays avatar May 18 '23 21:05 evanmays

I also think a general kwargs is the right approach. It seems that we can just pass the same kwargs to both the tokenizer and the model loader for most params. @jquesnelle mind updating this PR with a general kwargs?

slundberg avatar May 20 '23 04:05 slundberg
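
As a sketch of that general-kwargs shape (load_model_and_tokenizer is a hypothetical helper, not part of guidance), the same keyword arguments can be forwarded to both loaders:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_and_tokenizer(model_name, **kwargs):
    # Forward the same kwargs (e.g. torch_dtype, load_in_8bit, device_map) to both
    # loaders; the tokenizer typically just ignores model-only options.
    tokenizer = AutoTokenizer.from_pretrained(model_name, **kwargs)
    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    return model, tokenizer

# e.g. model, tok = load_model_and_tokenizer("gpt2", load_in_8bit=True, device_map="auto")
```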

Yup, no problem! Will have it updated shortly 🙂

jquesnelle avatar May 20 '23 16:05 jquesnelle

Should be all set @slundberg

It's just a little bit ugly: since LLaMA and MPT load the models themselves, it required duplicating the arguments to Transformers.__init__ in the __init__ of those subclasses.

jquesnelle avatar May 21 '23 02:05 jquesnelle
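
To picture the duplication being described (the parameter lists below are placeholders, not the actual signatures), both the base class and the subclass end up spelling out the same loading arguments:

```python
class Transformers:
    def __init__(self, model=None, tokenizer=None, torch_dtype=None, load_in_8bit=False, **kwargs):
        # base class loads via AutoModelForCausalLM / AutoTokenizer with these options
        ...

class LLaMA(Transformers):
    def __init__(self, model=None, tokenizer=None, torch_dtype=None, load_in_8bit=False, **kwargs):
        # Same loading parameters repeated here, because the LLaMA model and tokenizer
        # are constructed in this subclass rather than in the base class.
        ...
```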

Thanks @jquesnelle! I agree it is a bit redundant. After I merge, I'll fix that by making LLaMA and MPT override the _model_and_tokenizer method instead of __init__.

slundberg avatar May 22 '23 22:05 slundberg
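
A sketch of the refactor described above, with the subclass overriding a single _model_and_tokenizer hook instead of __init__; the method bodies are assumptions based on this thread, not the merged guidance code:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer

class Transformers:
    def __init__(self, model, tokenizer=None, **kwargs):
        # All loading kwargs funnel through one overridable hook, so subclasses
        # no longer need to repeat the __init__ signature.
        self.model_obj, self.tokenizer = self._model_and_tokenizer(model, tokenizer, **kwargs)

    def _model_and_tokenizer(self, model, tokenizer, **kwargs):
        if isinstance(model, str):
            tokenizer = tokenizer or AutoTokenizer.from_pretrained(model, **kwargs)
            model = AutoModelForCausalLM.from_pretrained(model, **kwargs)
        return model, tokenizer

class LLaMA(Transformers):
    def _model_and_tokenizer(self, model, tokenizer, **kwargs):
        # Only the loading logic differs; the constructor stays in the base class.
        if isinstance(model, str):
            tokenizer = tokenizer or LlamaTokenizer.from_pretrained(model, **kwargs)
            model = LlamaForCausalLM.from_pretrained(model, **kwargs)
        return model, tokenizer
```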