# Support specifying `revision` in `load_model`
I've added support for the `revision` parameter in `load_model` and `load_compress_model`. It explicitly defaults to `"main"`, which is also the default in the Hugging Face `from_pretrained` methods, so I believe all of the changes are backwards compatible.
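For reference, the change essentially just threads `revision` through to the underlying `from_pretrained` calls. A minimal sketch of the idea (not the actual diff; the real `load_model` also handles dtype, device placement, and the various model adapters):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(model_path: str, revision: str = "main", **from_pretrained_kwargs):
    # `revision` pins the repo to a branch, tag, or commit hash.
    # Defaulting to "main" matches the Hugging Face default, so callers
    # that don't pass it see no change in behavior.
    tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, revision=revision, **from_pretrained_kwargs
    )
    return model, tokenizer
```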
I didn't touch the model loading part in `ModelWorker` because the distinction between `model_path` and `model_names` was unclear to me. Please advise.
## Why are these changes needed?
A couple of hours ago, mosaicml/mpt-7b-chat introduced a commit (f9cc150) on its main branch that broke inference: it raises `RuntimeError: expected scalar type Half but found Float` on my A40 GPU machine. Reverting to the previous commit (revision) solves the issue for now.
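For anyone hitting this before the PR lands, you can already pin the model with plain `transformers`; a rough sketch using the known-good commit hash from the test commands below:

```python
from transformers import AutoModelForCausalLM

# Pin mosaicml/mpt-7b-chat to the snapshot from before the breaking change.
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-chat",
    revision="bb4873bde98b60ef1c40d4c6c9729fe95de7dcbf",  # pre-f9cc150
    trust_remote_code=True,  # MPT ships custom modeling code
)
```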
Even without this incident, I believe there is good reason to let users pin the revision of their Hugging Face pretrained models.

I think it also makes sense for the Vicuna Leaderboard, since you don't want model weights to silently update under you. It might be better to post the revision IDs alongside the model names so that leaderboard results are clearer and more reproducible.
## Checks
- [x] I've run `format.sh` to lint the changes in this PR. ==> `pylint` screams horribly and I think CI is just running `black`, so I also just ran `black`.
- [x] I've included any doc changes needed. ==> I don't think any were needed, but please correct me if I'm wrong.
- [x] I've made sure the relevant tests are passing (if applicable). ==> I don't think there are relevant tests, but I ran the following commands manually and they all behaved as expected:
  - `python -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat` (dies with dtype error)
  - `python -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat --revision main` (dies with dtype error)
  - `python -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat --revision bb4873bde98b60ef1c40d4c6c9729fe95de7dcbf`
  - `python -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat --revision bb4873bde98b60e`
  - `python -m fastchat.serve.cli --model-path databricks/dolly-v2-12b`
  - `python -m fastchat.serve.cli --model-path lmsys/fastchat-t5-3b-v1.0`
  - `python -m fastchat.serve.cli --model-path weights/llama/7B`
  - `python -m fastchat.serve.cli --model-path weights/vicuna/13B`
  - `python -m fastchat.serve.cli --model-path weights/vicuna/13B --load-8bit`