CPM-Bee
qlora
This PR mainly involves the following aspects:
- QLoRA overall logic:
  - First, quantize the model parameter files.
  - Set the `int4` field in the model's config to enable QLoRA fine-tuning (see the sketch after this group).
  - The rest is consistent with basic task fine-tuning.
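A minimal sketch of the enabling step, assuming the model config under `src/config` is a JSON file; the file name below is only illustrative, not part of this PR:

```python
import json

CONFIG_PATH = "src/config/cpm-bee-10b.json"  # hypothetical file name

with open(CONFIG_PATH) as f:
    config = json.load(f)

# The PR adds a boolean `int4` switch to the config; setting it to True makes
# the model classes build their quantized (QLoRA) variants when loaded.
config["int4"] = True

with open(CONFIG_PATH, "w") as f:
    json.dump(config, f, indent=2)
```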
- Modifications to the model structure:
  - Add a bool type field `int4` to the model parameter files in the folder `src/config`, which acts as a switch controlling whether QLoRA is used. Corresponding adjustments need to be made in the other relevant structures (Attention/SelfAttentionBlock/FFNBlock/TransformerBlock/DenseGatedACT/FeedForward/Encoder/CPMBee) so that the appropriate modules are loaded based on the `int4` field.
  - In `src/cpm_live/layers/feedforward.py`, add class `Linear4bit` as the linear layer for the QLoRA method, class `Params4bit` as the weight for `Linear4bit`, and class `DistributedParameter4Int8` to meet encapsulation needs (a simplified illustration of the idea follows this group).
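The PR's `Linear4bit`/`Params4bit` classes are not reproduced here; the following is a simplified, self-contained sketch of the underlying idea (a frozen quantized base weight plus trainable low-rank LoRA adapters). The class name, the per-channel int8 absmax quantization, and the LoRA hyperparameters are illustrative stand-ins for the actual 4-bit implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizedLoRALinear(nn.Module):
    """Sketch: frozen quantized base weight + trainable low-rank adapters."""

    def __init__(self, weight: torch.Tensor, r: int = 8, alpha: int = 16):
        super().__init__()
        out_features, in_features = weight.shape
        # Quantize the frozen base weight per output channel (absmax scaling).
        scale = (weight.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
        q_weight = torch.clamp((weight / scale).round(), -127, 127).to(torch.int8)
        self.register_buffer("q_weight", q_weight)  # stored quantized, no grad
        self.register_buffer("scale", scale)
        # Trainable LoRA factors: delta_W = (alpha / r) * B @ A
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly for the frozen path.
        w = self.q_weight.to(x.dtype) * self.scale.to(x.dtype)
        base = F.linear(x, w)
        lora = F.linear(F.linear(x, self.lora_A), self.lora_B)
        return base + self.scaling * lora


if __name__ == "__main__":
    layer = QuantizedLoRALinear(torch.randn(16, 32))
    print(layer(torch.randn(4, 32)).shape)  # torch.Size([4, 16])
```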
- Add scripts/sample code/README:
  - `src/quantize_state_dict.py` is the code for compressing the initial weights; QLoRA needs to load the compressed dict as model weights (a hedged sketch of the idea follows this group).
  - `src/finetune_cpm_bee_qlora.py` is the fine-tuning sample code.
  - `src/scripts/finetune_cpm_bee_qlora.sh` is the fine-tuning sample script.
  - `tutorials/basic_task_finetune/README_qlora.md` is the fine-tuning tutorial for QLoRA.
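The actual `src/quantize_state_dict.py` is not shown here; below is a hedged sketch of the general idea of compressing a checkpoint so QLoRA can later load it. Per-channel int8 absmax quantization stands in for the real 4-bit packing, and the function and file names are illustrative:

```python
import torch


def quantize_state_dict(src_path: str, dst_path: str) -> None:
    state = torch.load(src_path, map_location="cpu")
    quantized = {}
    for name, tensor in state.items():
        if tensor.dim() == 2 and tensor.is_floating_point():
            # Quantize 2-D linear weights; keep the scales for dequantization.
            scale = (tensor.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
            q = torch.clamp((tensor / scale).round(), -127, 127).to(torch.int8)
            quantized[name] = q
            quantized[name + ".scale"] = scale
        else:
            # Leave embeddings, layer norms, and biases untouched.
            quantized[name] = tensor
    torch.save(quantized, dst_path)


if __name__ == "__main__":
    quantize_state_dict("pytorch_model.bin", "pytorch_model_quantized.bin")
```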
- Other considerations:
  - The inspect part of the code has been commented out in `src/finetune_cpm_bee_qlora.py`, as `uint8` does not support `std` and `var` (see the short demonstration below).
  - The bug in `BMTrain.blocklayer` where `requires_grad` cannot be passed in for the `uint8` type also needs to be fixed in sync.
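For reference, the `uint8` limitation can be reproduced directly in PyTorch, since `std`/`var` are only implemented for floating-point (and complex) dtypes:

```python
import torch

t = torch.randint(0, 255, (4, 4), dtype=torch.uint8)

try:
    t.std()  # raises RuntimeError: std/var not supported for integer dtypes
except RuntimeError as e:
    print("uint8 std failed:", e)

# Casting to a floating dtype first works, at the cost of materializing
# a float copy of the quantized tensor.
print(t.float().std())
```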