Bringing laplace-torch to the foundation-model era

Open wiseodd opened this pull request 1 year ago

Main features of this pull request:

  1. Support applying Laplace only to parameters that require grad. Use case: PEFT (e.g., LoRA) on top of a frozen foundation model. This is more efficient than SubnetLaplace, since the latter still computes the full Jacobians. (See the first sketch after this list.)
  2. Add support for multiple leading dimensions in the classification likelihood, e.g., logits of shape (batch_size, seq_len, n_classes). Useful for language modeling and reward modeling. (See the flattening sketch after this list.)
    1. This PR also contains the integrate-latest-asdl changes. I tested it with my ASDL fork (only a couple of light changes to support the weight-sharing dim and ignore_index, both crucial in language modeling): https://github.com/wiseodd/asdl/commits/dev/. Please also check this and let me know what the most elegant approach is.
  3. Support Hugging Face datasets. The assumption is that x is a UserDict containing input_ids, attention_mask, etc., i.e., the things produced by an HF dataloader. (See the UserDict sketch after this list.)
  4. Add a new likelihood called reward_modeling, where the classification likelihood is used during training and the regression likelihood is used during prediction.
  5. Add support for torchmetrics in the grid search. The benefit is that it supports running metrics, which means less memory overhead (compared to gathering all the logits first).
  6. Add Jacobian computation with torch.func (functorch, really) as a general Jacobian computation for the GLM predictive. Useful for Bayesian optimization and invariance learning, where you need to backprop through the predictive variance. Much more elegant than changing ASDL. (See the torch.func sketch after this list.)
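A minimal sketch of the PEFT use case from item 1. The small MLP and the frozen first layer are stand-ins for a foundation model with LoRA adapters; the `Laplace` factory and `fit` calls follow laplace-torch's public API, while the behavior that frozen (requires_grad=False) parameters are excluded from the curvature is what this PR adds:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace

model = nn.Sequential(nn.Linear(10, 50), nn.ReLU(), nn.Linear(50, 2))
for p in model[0].parameters():  # freeze the "backbone"
    p.requires_grad = False

X, y = torch.randn(64, 10), torch.randint(0, 2, (64,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=16)

la = Laplace(model, likelihood="classification",
             subset_of_weights="all", hessian_structure="kron")
la.fit(train_loader)  # curvature only over params that still require grad
```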
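Item 2 in practice amounts to flattening the extra leading dimensions before applying the usual cross-entropy. A sketch below; the -100 ignore_index follows the Hugging Face convention and is an assumption here, not necessarily the PR's default:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 16, 100)          # (batch_size, seq_len, n_classes)
targets = torch.randint(0, 100, (4, 16))  # (batch_size, seq_len)
targets[:, -2:] = -100                    # pretend the last two tokens are padding

nll = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),  # (batch_size * seq_len, n_classes)
    targets.reshape(-1),
    ignore_index=-100,  # padding tokens don't contribute to the loss
)
```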
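And the data assumption from item 3, illustrated: each batch's x is a UserDict holding the tokenizer outputs (Hugging Face's BatchEncoding subclasses UserDict), with labels passed separately. The shapes below are arbitrary:

```python
from collections import UserDict
import torch

x = UserDict({
    "input_ids": torch.randint(0, 30000, (8, 128)),
    "attention_mask": torch.ones(8, 128, dtype=torch.long),
})
y = torch.randint(0, 2, (8,))  # labels come separately, as (x, y) pairs
```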
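Finally, a sketch of the torch.func Jacobian from item 6 (names are illustrative, not the PR's exact implementation). The per-sample Jacobian J of the network outputs w.r.t. the parameters is what the GLM predictive variance J @ Sigma @ J^T needs, and since everything is plain torch ops, you can backprop through the result:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

model = nn.Linear(10, 3)
params = dict(model.named_parameters())

def output_fn(p, x):
    # Run the model functionally on a single input so vmap can batch it.
    return functional_call(model, p, (x.unsqueeze(0),)).squeeze(0)

x = torch.randn(8, 10)
# Dict mapping each param name to a (batch, n_outputs, *param_shape) Jacobian.
jacs = vmap(jacrev(output_fn), in_dims=(None, 0))(params, x)

# Flatten into J of shape (batch, n_outputs, n_params).
J = torch.cat([j.flatten(start_dim=2) for j in jacs.values()], dim=2)
```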

Relevant unit tests are provided. All tests pass; the only failures are the pre-existing LowRankLaplace issues.

wiseodd avatar Feb 24 '24 13:02 wiseodd

Thanks Agustinus for this PR, this is very useful! I just took a glance at the parts that I'm familiar with, and they look good to me. I'll defer to Alex and Runa for a more in-depth review, given their better familiarity with the core library.

edaxberger avatar Feb 26 '24 11:02 edaxberger

This PR is currently blocked by https://github.com/f-dangel/curvlinops/issues/86. We need to make curvlinops support Hugging Face datasets first.

wiseodd avatar Mar 10 '24 21:03 wiseodd

Regarding the addition of the logit_class_idx argument: I did a "smoke test" implementation and realized that the changes would be substantial. Even without the present PR, laplace-torch always assumes that logit_class_idx == -1.

So I think this is better done in a separate PR. This will also keep the scope of each PR cleaner: in this PR, we focus on HF LLM models.

wiseodd avatar Apr 18 '24 17:04 wiseodd

So, the remaining TODOs for this PR are:

  1. Resolve the conversations about matrix.py. @runame & @aleximmer, could you please handle this?
  2. Wait until https://github.com/f-dangel/curvlinops/pull/100 is merged.

Then we can merge this PR and continue with the follow-up PRs covering what we deferred here.

wiseodd avatar Apr 18 '24 17:04 wiseodd

Sounds good, can you create issues for all comments that you want to address in future PRs?

runame avatar Apr 18 '24 18:04 runame

I replaced get_nll in crossval with RunningNLLMetric(). So this PR will close #160. (I previously thought that RunningNLLMetric hadn't been implemented, so this would have involved major work.)
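For context, the running-metric pattern looks roughly like the sketch below; this is a hypothetical torchmetrics-style metric, not laplace-torch's actual RunningNLLMetric implementation. update() only accumulates per-batch sums, so memory use stays constant instead of growing with the number of gathered logits:

```python
import torch
import torch.nn.functional as F
from torchmetrics import Metric

class SketchRunningNLL(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("nll_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, targets: torch.Tensor) -> None:
        # Accumulate the summed NLL; the logits themselves are never stored.
        self.nll_sum += F.cross_entropy(logits, targets, reduction="sum")
        self.count += targets.numel()

    def compute(self) -> torch.Tensor:
        return self.nll_sum / self.count
```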

wiseodd avatar Apr 25 '24 22:04 wiseodd

Might as well fix #156 while we're at it.

wiseodd avatar Apr 25 '24 23:04 wiseodd

All tasks finished!

wiseodd avatar Apr 26 '24 16:04 wiseodd

Double-checked and everything looks fine! Merging this.

wiseodd avatar Apr 27 '24 18:04 wiseodd