FastChat
# Add callback on save for LoRA
## Why are these changes needed?
When we save a model checkpoint, the Trainer writes the entire model to `pytorch_model.bin`, which is extremely large, but we often only want the adapter model, which is defined by an `adapter_model.bin` and an `adapter_config.json` file. These changes add a callback that also saves the adapter model separately, so we can apply the LoRA weights immediately using `apply_lora.py`.
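A minimal sketch of such a callback, assuming a PEFT-wrapped model and the standard `transformers` checkpoint layout (illustrative only, not the exact code in this PR):

```python
import os

from transformers import TrainerCallback


class SavePeftModelCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # HF Trainer writes checkpoints to output_dir/checkpoint-<step>.
        checkpoint_dir = os.path.join(
            args.output_dir, f"checkpoint-{state.global_step}"
        )
        # For a PeftModel, save_pretrained() writes only the small
        # adapter_model.bin and adapter_config.json, not the backbone.
        kwargs["model"].save_pretrained(
            os.path.join(checkpoint_dir, "adapter_model")
        )
        return control
```

Registered with `trainer.add_callback(SavePeftModelCallback())`, this runs right after each regular checkpoint save.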
## Related issue number (if applicable)
Fixes #1249
## Checks
- [x] I've run `format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed.
- [ ] I've made sure the relevant tests are passing (if applicable).
lgtm. Have you tested the new checkpoint size and reloading from the new checkpoint?
Yes, reloading from the new checkpoint still works, and the checkpoint size for the `adapter_model` was 17M for me. I was able to test `apply_lora` and it works with the adapter model that was produced.
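For context, applying the adapter boils down to roughly the following, a sketch of the PEFT calls behind `apply_lora.py` (the paths here are hypothetical):

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the frozen backbone, then attach the saved LoRA adapter.
base = AutoModelForCausalLM.from_pretrained("path/to/base-model")
model = PeftModel.from_pretrained(base, "path/to/checkpoint-500/adapter_model")

# Fold the LoRA deltas into the backbone and save a standalone model.
model = model.merge_and_unload()
model.save_pretrained("path/to/merged-model")

tokenizer = AutoTokenizer.from_pretrained("path/to/base-model")
tokenizer.save_pretrained("path/to/merged-model")
```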
I discovered and tested that the program (following this thread) will not create the `pytorch_model.bin` file if we set `"stage3_gather_16bit_weights_on_model_save": false` in `ds_config.json`, and resuming from the checkpoint still works. So the disk usage per checkpoint is now roughly halved, since we no longer save the `pytorch_model.bin` file.
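For reference, the relevant fragment of `ds_config.json` looks like this (other fields omitted):

```json
{
  "zero_optimization": {
    "stage": 3,
    "stage3_gather_16bit_weights_on_model_save": false
  }
}
```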
What I'm curious about is how the model even loads from the checkpoint without `pytorch_model.bin`, and whether that's correct. I think it should be, because based on the code in Hugging Face's `deepspeed_init.py`, resuming uses the files generated in `global_step5` and doesn't rely on `pytorch_model.bin`.
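Concretely, a ZeRO-3 checkpoint directory looks roughly like this (exact shard names vary with precision and world size), and resuming reads the `global_step*` shards rather than `pytorch_model.bin`:

```
checkpoint-5/
├── global_step5/
│   ├── zero_pp_rank_0_mp_rank_00_model_states.pt   # partitioned module states
│   └── zero_pp_rank_0_mp_rank_00_optim_states.pt   # partitioned optimizer states
├── latest              # text file pointing at global_step5
├── trainer_state.json
├── rng_state.pth
└── pytorch_model.bin   # absent once the gather option above is false
```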
I think this behavior is fine, since we won't ever really use `pytorch_model.bin`: we can just run `apply_lora` with the `adapter_model`. You mentioned that the checkpoint should record not only the weights but also the optimizer state and RNG states, and I think DeepSpeed does save those in the checkpoint folder, unless you're talking about something else.
Let me know what you think!
Good to hear that `pytorch_model.bin` can be removed, but the size of `global_stepx` should also be reduced to the size of the adapters; currently it also includes the backbone parameters (in this case, 13GB).
Have you tested removing the parameter-related files in `global_stepx` while keeping only the optimizer-related ones, then loading from the checkpoint? If that works, can you please try deleting the parameter files in your callback, or monkey-patching some HF/DeepSpeed code to avoid recording them?
I tested resuming from a checkpoint after deleting all the `zero_pp_rank_x_mp_rank_00_model_states.pt` files, and I ran into an AssertionError when loading the checkpoint, presumably because DeepSpeed is looking for those files.
Is this PR still ongoing? It seems that resuming from the LoRA adapter still doesn't restore the optimizer state and RNG states.
I was working on this PR, but the best I can do so far is load from an adapter; we still lose the LR schedule, optimizer state, and RNG state, since that's an HF-side issue.
Closed due to inactivity. Feel free to reopen.