Optimize DoRA in `eval` and `no dropout`
Fixes #2107
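For context, here is a minimal sketch of the idea (illustrative function and argument names, not the actual PEFT internals): when the dropout module is an `nn.Identity` (i.e. `lora_dropout=0`) or the layer runs in eval mode, `dropout(x)` equals `x`, so the DoRA branch can reuse the output already computed by the base layer instead of doing a second forward pass through the frozen weight.

```python
# Sketch of the control flow only; `dora_branch` stands in for the DoRA
# computation and is assumed to accept an optional precomputed base result.
import torch.nn as nn


def lora_dora_forward(x, base_layer, dropout, dora_branch, training):
    result = base_layer(x)
    if isinstance(dropout, nn.Identity) or not training:
        # dropout(x) == x here, so the base output can be reused inside the
        # DoRA branch -> one matmul through the frozen weight is saved.
        result = result + dora_branch(x, base_result=result)
    else:
        # Active dropout changes the input, so the DoRA branch has to compute
        # its own base output on the dropped-out activations.
        result = result + dora_branch(dropout(x), base_result=None)
    return result
```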
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@BenjaminBossan could you help me with an initial idea for an example that demonstrates the efficiency of this change? As you said, you would like to see lower memory usage and faster training. I am not sure how best to demonstrate the memory savings in a Colab notebook. Any resource would be very helpful.
@ariG23498 Could you please run `make style` so that the CI can pass?
Hi @ariG23498 and @BenjaminBossan, thanks for the effort, I have reviewed the code and found no issue with it. Please let me know or tag me if you need my help, thanks!
Thanks @ariG23498 for the latest fixes and @nbasyl for the review.
I did a small test using this DoRA script by calling:
`CUDA_VISIBLE_DEVICES=0 time python dora_finetuning.py --quantize --lora_dropout 0 --use_dora`
(I changed grad acc steps from 16 to 2). For this to work, I had to propagate the DoRA changes from this PR to the bitsandbytes layers.
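(For reference, the quantized setup behind `--quantize` should be roughly the following; this is a sketch of what the example script presumably does, with a placeholder model name, not code copied from it.)

```python
# Rough sketch of a 4-bit DoRA setup similar to what the example script runs.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder; the script has its own default
    quantization_config=bnb_config,
    device_map="auto",
)
peft_config = LoraConfig(use_dora=True, lora_dropout=0.0)  # mirrors --use_dora --lora_dropout 0
model = get_peft_model(model, peft_config)  # DoRA on top of bnb 4-bit layers
```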
What I found is:
- PEFT main: `{'train_runtime': 14.0129, 'train_samples_per_second': 0.714, 'train_steps_per_second': 0.357, 'train_loss': 10.531291198730468, 'epoch': 0.0}`
- This PR: `{'train_runtime': 11.8011, 'train_samples_per_second': 0.847, 'train_steps_per_second': 0.424, 'train_loss': 10.531893920898437, 'epoch': 0.0}`
I also monitored memory and it went down from 7557MiB to 7325MiB.
So the final losses are not 100% identical, but I think it's within rounding error. Runtime was improved and memory usage slightly decreased with this PR.
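(If you want to reproduce the memory comparison: the MiB numbers above come from monitoring the GPU; one alternative is PyTorch's allocator statistics, e.g. the sketch below, which assumes `trainer` is the Trainer built by the script.)

```python
# Capture peak allocated GPU memory around a training run for a
# before/after comparison (allocator statistics, not nvidia-smi).
import torch

torch.cuda.reset_peak_memory_stats()
trainer.train()  # `trainer` is assumed to come from the example script
peak_mib = torch.cuda.max_memory_allocated() / 2**20
print(f"peak allocated: {peak_mib:.0f} MiB")
```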
Overall, I believe these are nice results and we can continue with this PR. @ariG23498 could you please propagate the changes to the quantized LoRA layer types that support it. We could probably also document this to let users know that they should consider disabling dropout for DoRA training to benefit from this optimization, with some numbers to underline this.
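The user-facing part is just the config, so something along these lines could go into the docs (values are illustrative):

```python
# To benefit from the optimization, enable DoRA and disable the LoRA dropout.
from peft import LoraConfig

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # illustrative; depends on the model
    lora_dropout=0.0,  # no dropout -> the base output can be reused in the DoRA path
    use_dora=True,
)
```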
Thanks for the detailed reply, @BenjaminBossan. I am glad that the memory usage went down and the runtime also improved.
> @ariG23498 could you please propagate the changes to the quantized LoRA layer types that support it.
Do you mean all the variants found here? Also, I think it would be better to limit the current change to DoRA and then create another PR for the rest of the layers, WDYT?
@ariG23498 @BenjaminBossan very nice PR. I learned a lot. @ariG23498 let me know if I can be of help to propagate the changes, maybe in a separate PR.
> Do you mean all the variants found here? Also, I think it would be better to limit the current change to DoRA and then create another PR for the rest of the layers, WDYT?
Yes, so what I mean is that e.g. in `lora/bnb.py`, the DoRA call has to be adjusted in the same way as in `lora/layer.py`. It should be quite straightforward, because you can reuse the same code everywhere. I'd prefer this to be consistent before merging the PR.
Other than that, let's add a mention of this optimization in the docs. Ideally, we can add some numbers. I can re-run the experiment mentioned above and give some definitive numbers if you want.
> let me know if I can be of help to propagate the changes, maybe in a separate PR.
Thanks for the offer. As mentioned, let's try to get it into this PR. @ariG23498 up to you if/how you want to split up the work.
@BenjaminBossan I have made the changes.
@charchit7 thank you for the offer, but as this is a redundant piece of code, I thought it was better to make the changes myself. Please feel free to take up other issues and comment for collaboration 🤗
Thanks for the update. Let's also add it here:
https://github.com/huggingface/peft/blob/93ddb1015a637e72c6e61a82852c7bb127b13d66/src/peft/tuners/lora/tp_layer.py#L214-L221
The other layer types don't seem to properly implement DoRA yet, so we can leave those for a separate PR.
Would you also be so kind to add to the docs?
> @BenjaminBossan I have made the changes.
> @charchit7 thank you for the offer, but as this is a redundant piece of code, I thought it was better to make the changes myself. Please feel free to take up other issues and comment for collaboration 🤗
Yes, I completely understand.
Thank you, yes, will do :)
Thanks for the updates. I'll re-run the script later to get some final numbers to report, as the first test was only very short.
Update: So I re-ran the script for a while longer and with a higher batch size (before it was 1), using:
`CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
I also set gradient_accumulation_steps=2 and max_steps=20.
What I found is that training was 20% (wall time) to 23% (Transformers-reported runtime) faster. However, there was no longer any memory advantage; I'm not quite sure what was different the first time around.
- before: `{'train_runtime': 359.7298, 'train_samples_per_second': 1.779, 'train_steps_per_second': 0.056, 'total_flos': 1.303253870444544e+16, 'train_loss': 9.653419399261475, 'epoch': 0.06493506493506493, 'step': 20}`
- after: `{'train_runtime': 279.2676, 'train_samples_per_second': 2.292, 'train_steps_per_second': 0.072, 'total_flos': 1.303253870444544e+16, 'train_loss': 9.643538236618042, 'epoch': 0.06493506493506493, 'step': 20}`
Losses aligned quite nicely, with only a rounding-error level of difference.
@ariG23498 Could you please update the docs accordingly (no need to mention the loss difference, as this is expected).
@BenjaminBossan the docs have been updated! The benchmark results are crazy.