Support saving and loading 8-bit block weights

Open · mryab opened this pull request on Feb 25, 2023 · 1 comment

This PR relies on https://github.com/TimDettmers/bitsandbytes/pull/159 and makes it possible to call convert_model with the int8 data type, so that a server started with load_in_8bit=True can later download the 8-bit checkpoint instead of the 16-bit one. This can save up to 2x of the bandwidth needed to start a server, as shown by this comparison of model sizes for bloom-560m:

~/petals$ du -sh converted_model*
802M    converted_model
515M    converted_model_int8
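For intuition on where the saving comes from: an fp16 weight takes 2 bytes, while an int8 weight takes 1 byte plus a small amount of per-row metadata, so quantized blocks shrink roughly 2x (the overall checkpoint shrinks less when some tensors, such as embeddings, stay in 16-bit). The sketch below illustrates row-wise absmax int8 quantization, the general scheme that bitsandbytes-style 8-bit weights build on; it is a self-contained illustration, not the actual code used by this PR or by bitsandbytes:

```python
import torch

def quantize_rowwise_int8(w: torch.Tensor):
    """Row-wise absmax quantization: map each row's max magnitude to 127."""
    scales = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
    q = torch.round(w / scales).to(torch.int8)
    return q, scales

def dequantize_rowwise_int8(q: torch.Tensor, scales: torch.Tensor):
    return q.float() * scales

w = torch.randn(1024, 1024)  # stand-in for one block's weight matrix

q, scales = quantize_rowwise_int8(w)
fp16_bytes = w.numel() * 2                       # 2 bytes per fp16 weight
int8_bytes = q.numel() * 1 + scales.numel() * 4  # 1 byte per weight + fp32 scales
print(f"fp16: {fp16_bytes / 2**20:.2f} MiB, int8: {int8_bytes / 2**20:.2f} MiB")

w_hat = dequantize_rowwise_int8(q, scales)
print(f"max abs reconstruction error: {(w - w_hat).abs().max():.4f}")
```

Saving weights in this form means serializing both the int8 tensors and their quantization metadata, which is the support this PR pulls in from the linked bitsandbytes branch.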

The command used for conversion is python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model_int8 --torch_dtype int8 --resize_token_embeddings 50000 --block_branch_prefix int8_block. To test that the checkpoint loads correctly, install bitsandbytes from the branch in the PR above and run python -m petals.cli.run_server bigscience/test-bloomd --new_swarm --skip_reachability_check --throughput 100 --device cuda (note that I had to change BLOCK_BRANCH_PREFIX in this branch for testing purposes).
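As a quick sanity check outside Petals, the memory saving from 8-bit loading can also be observed directly with Hugging Face transformers. This is an illustrative sketch, not part of the PR; it assumes a CUDA GPU plus the transformers, accelerate, and bitsandbytes packages, and uses the same bigscience/bloom-560m model as above:

```python
import torch
from transformers import AutoModelForCausalLM

# Regular 16-bit load: roughly 2 bytes per parameter.
model_fp16 = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m", torch_dtype=torch.float16
)
print(f"fp16 footprint: {model_fp16.get_memory_footprint() / 2**20:.0f} MiB")

# 8-bit load via bitsandbytes: roughly 1 byte per quantized parameter.
model_int8 = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m", load_in_8bit=True, device_map="auto"
)
print(f"int8 footprint: {model_int8.get_memory_footprint() / 2**20:.0f} MiB")
```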

mryab · Feb 25 '23 15:02