Bringing laplace-torch to the foundation-model era
Main features of this pull request:
- Support doing Laplace only on params that require grad. Use case: PEFT (like LoRA) on top of a frozen foundation model. This is more efficient than `SubnetLaplace` since the latter still computes the full Jacobians. (See the first sketch below.)
- Add support for multiple leading dims in the `classification` likelihood, e.g. logits of shape `(batch_size, seq_len, n_classes)`. Useful for language modeling and reward modeling.
- This PR also contains the `integrate-latest-asdl` changes. I tested it with my ASDL fork (only a couple of light changes to support the weight-sharing dim and `ignore_index`, which is crucial in language modeling): https://github.com/wiseodd/asdl/commits/dev/. Please also check this and let me know what the most elegant way is.
- Support Huggingface datasets. The assumption is that `x` is a `UserDict` containing `input_ids`, `attention_mask`, etc., i.e. the things produced by a HF dataloader. (Also illustrated in the first sketch below.)
- Add a new likelihood called `reward_modeling`, where the `classification` likelihood is used during training and the `regression` likelihood is used during prediction.
- Add support for `torchmetrics` in the grid search. The benefit is that it supports running metrics, so there is less memory overhead compared to gathering all the logits first.
- Add Jacobian computation with `torch.func` (`functorch`, really) as a general Jacobian computation for the GLM predictive. Useful for Bayesian optimization/invariance learning where you need to backprop through the variance. Much more elegant than changing ASDL. (See the second sketch below.)
Relevant unit tests are provided. All tests pass; the only failures are the old `LowRankLaplace` issues.
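For concreteness, here is a minimal sketch (mine, not from the PR diff) of the intended workflow for the first and fourth bullets: fitting Laplace over only the trainable parameters of a mostly-frozen model, fed with Huggingface-style `UserDict` batches. The toy model, collate function, and data are illustrative stand-ins:

```python
from collections import UserDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace


class WrappedLM(nn.Module):
    """Toy stand-in for a wrapped HF model: forward consumes the UserDict batch."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Embedding(100, 16)  # plays the role of the frozen LM
        self.head = nn.Linear(16, 3)  # plays the role of the trainable PEFT params

    def forward(self, batch):
        h = self.backbone(batch["input_ids"]).mean(dim=1)
        return self.head(h)  # logits of shape (batch_size, n_classes)


model = WrappedLM()
for p in model.backbone.parameters():
    p.requires_grad = False  # frozen params are excluded from the Laplace approx.


def collate(batch):
    """Build the (UserDict, labels) batches this PR assumes for HF data."""
    ids, mask, ys = map(torch.stack, zip(*batch))
    return UserDict({"input_ids": ids, "attention_mask": mask}), ys


dataset = TensorDataset(
    torch.randint(0, 100, (32, 5)),       # input_ids
    torch.ones(32, 5, dtype=torch.long),  # attention_mask
    torch.randint(0, 3, (32,)),           # labels
)
train_loader = DataLoader(dataset, batch_size=8, collate_fn=collate)

la = Laplace(
    model, likelihood="classification",
    subset_of_weights="all", hessian_structure="kron",
)
la.fit(train_loader)  # per this PR, only params with requires_grad=True enter the Hessian
```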
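And a sketch of the `torch.func` route to the Jacobians needed by the GLM predictive, i.e. the per-input Jacobian of the network outputs w.r.t. the parameters. This is plain `torch.func`, independent of laplace-torch internals (all names below are illustrative); since it is composable autodiff, you can backprop through quantities built from it, such as the predictive variance:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
params = dict(model.named_parameters())


def output_fn(params, x):
    # x: a single input of shape (4,); returns logits of shape (3,)
    return functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)


# jacrev differentiates w.r.t. the params dict; vmap maps over the batch dim of x.
jac_fn = vmap(jacrev(output_fn), in_dims=(None, 0))

x = torch.randn(16, 4)
jacs = jac_fn(params, x)  # dict: param name -> tensor of shape (16, 3, *param.shape)

# Flatten into the usual (batch_size, n_outputs, n_params) Jacobian.
J = torch.cat([j.flatten(start_dim=2) for j in jacs.values()], dim=2)
print(J.shape)  # torch.Size([16, 3, 67]) for this toy model
```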
Thanks Agustinus for this PR, this is very useful! I just took a glance at the parts that I'm familiar with, and they look good to me. I'll defer to Alex and Runa for a more in-depth review, given their better familiarity with the core library.
This PR is currently blocked by https://github.com/f-dangel/curvlinops/issues/86. We need to make `curvlinops` support Huggingface datasets first.
Regarding the addition of the `logit_class_idx` argument: I did a "smoke test" of implementing this and realized that the changes would be substantial. Even without the present PR, `laplace-torch` always assumes that `logit_class_idx == -1`, i.e. that the class dimension is the last dimension of the logits.
So, I think this is better done in a separate PR. This will also keep the scope of the PRs cleaner: in this PR, we focus on HF LLM models.
So, the remaining TODOs for this PR are:
- Resolve the open conversations about `matrix.py`. @runame & @aleximmer, could you please handle this?
- Wait until https://github.com/f-dangel/curvlinops/pull/100 is merged.
Then we can merge this PR and follow up with further PRs for the things we deferred here.
Sounds good, can you create issues for all comments that you want to address in future PRs?
I replaced `get_nll` in crossval with `RunningNLLMetric()`. So this PR will close #160.
(I previously thought that `RunningNLLMetric` hadn't been implemented, so this would have involved major work.)
Might as well fix #156 while we're at it.
All tasks finished!
Double-checked and everything looks fine! Merging this.