
Add callback on save for LoRA

BabyChouSr opened this issue

Why are these changes needed?

When we save the model checkpoint, the Trainer saves the entire model in pytorch_model.bin, which is extremely large, but we often only want the adapter model, which consists of an adapter_model.bin and an adapter_config.json file. These changes add a callback that also saves the adapter model separately, so that we can apply the LoRA weights immediately using apply_lora.py.
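For context, a minimal sketch of such a callback (assuming a Hugging Face Trainer with a PEFT-wrapped model; the class and directory names are illustrative, not necessarily the exact code in this PR) could look like:

```python
# Sketch of a save callback that also writes the LoRA adapter alongside each checkpoint.
import os
from transformers import TrainerCallback

class SavePeftModelCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        # The Trainer writes checkpoints under output_dir/checkpoint-<global_step>.
        checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        adapter_dir = os.path.join(checkpoint_dir, "adapter_model")
        # For a PEFT model, save_pretrained writes only adapter_model.bin and
        # adapter_config.json, which are a few MB instead of the full model.
        kwargs["model"].save_pretrained(adapter_dir)
        return control
```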

Related issue number (if applicable)

Fixes #1249

Checks

  • [x] I've run format.sh to lint the changes in this PR.
  • [x] I've included any doc changes needed.
  • [ ] I've made sure the relevant tests are passing (if applicable).

BabyChouSr avatar May 16 '23 21:05 BabyChouSr

lgtm. Have you tested the new checkpoint size and reloading from the new checkpoint?

ZYHowell avatar May 16 '23 23:05 ZYHowell

Yes, reloading from the new checkpoint still works, and the checkpoint size for the adapter_model was 17M for me. I was able to test apply_lora and it works with the adapter model that was produced.

BabyChouSr avatar May 17 '23 02:05 BabyChouSr

I discovered and tested that the program (following this thread) will not create the pytorch_model.bin file if we set "stage3_gather_16bit_weights_on_model_save": false in ds_config.json, and resuming from the checkpoint still works. So the checkpoint size is now roughly cut in half, since we don't save the pytorch_model.bin file.
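For reference, the relevant part of a ZeRO-3 ds_config.json would look roughly like this (other fields omitted):

```json
{
  "zero_optimization": {
    "stage": 3,
    "stage3_gather_16bit_weights_on_model_save": false
  }
}
```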

What I'm curious about is how the model even loads from the checkpoint without pytorch_model.bin and whether that's correct. I think it is, because, based on the code in Hugging Face's deepspeed_init.py, loading appears to use the files generated in global_step5 and doesn't rely on pytorch_model.bin.

I think this behavior is fine, since we won't ever really use pytorch_model.bin; we can just run apply_lora with the adapter_model. You mentioned that the checkpoint should record not only the weights but also the optimizer state and RNG states, and I believe DeepSpeed does save those in the checkpoint folder, unless you are talking about something else.

Let me know what you think!

BabyChouSr avatar May 18 '23 00:05 BabyChouSr

Good to hear that pytorch_model.bin can be removed. However, the size of global_stepx should also be reduced to the size of the adapters; currently it still includes the backbone parameters (in this case 13GB). Have you tested removing the parameter-related files in global_stepx, keeping only the optimizer-related ones, and then loading from the checkpoint? If that works, can you please try deleting the parameter files in your callback, or monkey-patching some HF/DeepSpeed code to avoid recording them?
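For illustration, the deletion idea might look roughly like the sketch below (the shard file names are assumed from the usual DeepSpeed ZeRO-3 checkpoint layout, and as the next comment notes, this breaks resuming):

```python
# Hypothetical sketch: after each save, drop the ZeRO-3 parameter shards
# from the global_step folder and keep only the optimizer shards.
import glob
import os

def prune_parameter_shards(checkpoint_dir: str) -> None:
    for step_dir in glob.glob(os.path.join(checkpoint_dir, "global_step*")):
        # *_model_states.pt holds the (large) backbone parameters;
        # *_optim_states.pt holds the optimizer state we want to keep.
        for shard in glob.glob(os.path.join(step_dir, "*_model_states.pt")):
            os.remove(shard)
```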

ZYHowell avatar May 18 '23 01:05 ZYHowell

I tested resuming from the checkpoint after deleting all the zero_pp_rank_x_mp_rank_00_model_states.pt files, and I run into an AssertionError when loading from the checkpoint, presumably because DeepSpeed is looking for those files.

BabyChouSr avatar May 18 '23 03:05 BabyChouSr

Is this PR still ongoing? It seems that resuming from the LoRA adapter still doesn't restore the optimizer states and RNG states.

ZYHowell avatar Jun 14 '23 09:06 ZYHowell

I was working on this PR, but the best I can do so far is loading from an adapter; we lose the LR schedule, optimizer state, and RNG state, since that's an HF-side issue.
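To make the limitation concrete, "loading from an adapter" here amounts to something like the following sketch (paths are placeholders): the base model is reloaded, the saved adapter is attached, and training restarts without resume_from_checkpoint, so the optimizer, LR schedule, and RNG state start fresh.

```python
# Sketch with hypothetical paths: continue training from a saved adapter only.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("path/to/base-model")
model = PeftModel.from_pretrained(
    base, "path/to/checkpoint/adapter_model", is_trainable=True
)
# trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
# trainer.train()  # no resume_from_checkpoint: optimizer/scheduler/RNG start fresh
```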

BabyChouSr avatar Jun 14 '23 13:06 BabyChouSr

Closed due to inactivity. Feel free to reopen.

merrymercy avatar Jul 05 '23 09:07 merrymercy