Support saving and loading 8-bit block weights
This PR relies on https://github.com/TimDettmers/bitsandbytes/pull/159 and makes it possible to call convert_model with the int8 data type, so that a server started with load_in_8bit=True can later download the 8-bit checkpoint instead of the 16-bit one. This can save up to 2x bandwidth when starting a server, as shown by this comparison of model sizes for bloom-560m:
~/petals$ du -sh converted_model*
802M converted_model
515M converted_model_int8
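As a back-of-the-envelope sanity check (not taken from the PR), the per-matrix saving approaches 2x because int8 uses one byte per weight instead of two, plus a small amount of quantization statistics; the overall checkpoint shrinks by less than 2x because only the transformer block weights are quantized. The sketch below assumes per-row fp16 scaling constants; the exact overhead depends on bitsandbytes internals.

```python
import numpy as np

# Hypothetical size estimate for one linear-layer weight matrix:
# fp16 stores 2 bytes per value; int8 stores 1 byte per value
# plus (assumed) per-row fp16 scaling constants.
rows, cols = 1024, 1024
fp16_bytes = rows * cols * np.dtype(np.float16).itemsize
int8_bytes = rows * cols * np.dtype(np.int8).itemsize + rows * np.dtype(np.float16).itemsize
ratio = fp16_bytes / int8_bytes
print(f"per-matrix compression ratio: {ratio:.2f}x")
```

For bloom-560m above, the measured end-to-end ratio is 802M / 515M, i.e. roughly 1.6x, consistent with the block weights shrinking ~2x while embeddings and other non-block tensors stay 16-bit.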
The command used for conversion was python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model_int8 --torch_dtype int8 --resize_token_embeddings 50000 --block_branch_prefix int8_block. To test that the checkpoint loads correctly, install bitsandbytes from the branch in the PR above and run python -m petals.cli.run_server bigscience/test-bloomd --new_swarm --skip_reachability_check --throughput 100 --device cuda (note that I had to change BLOCK_BRANCH_PREFIX in this branch for the sake of testing).
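For reference, the test setup could be sketched as follows; the bitsandbytes branch name is a placeholder (the PR does not name it), so substitute the actual branch from TimDettmers/bitsandbytes#159:

```shell
# Install bitsandbytes from the dependency PR's branch
# (<pr-159-branch> is a placeholder, not the real branch name)
pip install git+https://github.com/TimDettmers/bitsandbytes.git@<pr-159-branch>

# Start a single-server test swarm serving the test model
python -m petals.cli.run_server bigscience/test-bloomd \
    --new_swarm --skip_reachability_check --throughput 100 --device cuda
```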