
MAML example

Open • fabianp opened this issue 3 years ago • 5 comments

fabianp, Jun 07 '22 19:06

TODO:

  • [x] Frame it as a bi-level problem, as in the implicit MAML paper (https://arxiv.org/pdf/1909.04630.pdf)
  • [x] Move to implicit diff directory
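For reference, the bi-level formulation in the implicit MAML paper linked above can be sketched as follows. The notation is assumed here rather than taken from the notebook: θ are the meta-parameters, φ_i the parameters adapted to task i, λ the strength of the proximal term, and \hat{\mathcal{L}}_i and \mathcal{L}_i the training and test losses of task i. The notebook's exact losses and regularization may differ.

```latex
% Bi-level (implicit MAML) problem over M tasks
\min_{\theta} \frac{1}{M}\sum_{i=1}^{M} \mathcal{L}_i\bigl(\phi_i^\star(\theta)\bigr)
\quad \text{s.t.} \quad
\phi_i^\star(\theta) = \operatorname*{arg\,min}_{\phi} \hat{\mathcal{L}}_i(\phi) + \frac{\lambda}{2}\lVert \phi - \theta \rVert^2
% Differentiating the inner stationarity condition
% \nabla\hat{\mathcal{L}}_i(\phi_i^\star) + \lambda(\phi_i^\star - \theta) = 0
% with respect to \theta (implicit function theorem) gives the meta-gradient
\nabla_{\theta}\, \mathcal{L}_i\bigl(\phi_i^\star(\theta)\bigr)
= \Bigl(I + \tfrac{1}{\lambda}\nabla^2_{\phi}\hat{\mathcal{L}}_i(\phi_i^\star)\Bigr)^{-1} \nabla_{\phi}\mathcal{L}_i(\phi_i^\star)
```

In practice the inverse is not formed explicitly. It is applied to a vector with an iterative linear solver, which is where the trade-off between the precision of the implicit gradient and its runtime comes in.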

fabianp, Jun 07 '22 20:06

I think this is finally ready for review 🎉 !

fabianp, Sep 20 '22 17:09

The notebook can be viewed here: https://github.com/google/jaxopt/blob/d7fc1d5407c81ab3c70bb09b818101fbcb4e1e3e/docs/notebooks/implicit_diff/maml.ipynb

fabianp, Sep 20 '22 17:09

There are still a couple of issues left, but I believe these can be tackled after merge:

  1. The notebook takes a long time to run (a couple of hours on CPU). It would be nice to bring that down.
  2. Convergence of the outer gradient seems to plateau at around 0.05. I believe this is because the gradient computed by implicit differentiation is not precise enough (it could be made more precise, but the runtime would then become extremely long).
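To make the precision/runtime trade-off in point 2 concrete, here is a minimal NumPy sketch. It is not the notebook's code and does not use the jaxopt API. It computes the implicit-MAML hypergradient (I + H/λ)^{-1} ∇L_test(φ*) for a hypothetical ridge-regression task, where H is the Hessian of the training loss. The linear system is solved either exactly or with a capped number of conjugate-gradient iterations. All function names (`conjugate_gradient`, `make_task`, `inner_solution`, `outer_loss`, `implicit_meta_grad`) and the toy task are illustrative assumptions. Fewer CG iterations make each outer step cheaper but the hypergradient less accurate.

```python
import numpy as np


def conjugate_gradient(matvec, b, maxiter):
    """Conjugate gradient for a symmetric positive-definite system A x = b.

    `matvec(v)` returns A @ v. The loop runs for at most `maxiter` iterations,
    so `maxiter` is the knob that trades accuracy for cost.
    """
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x for the starting point x = 0
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if rs < 1e-30:  # residual is numerically zero: stop early
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def make_task(rng, n, d, noise=0.1):
    """Hypothetical regression task: train and test splits share one weight vector."""
    w = rng.normal(size=d)

    def split():
        X = rng.normal(size=(n, d))
        return X, X @ w + noise * rng.normal(size=n)

    return split(), split()


def inner_solution(theta, X, y, lam):
    """Closed-form argmin_phi (1/2n)||X phi - y||^2 + (lam/2)||phi - theta||^2."""
    n, d = X.shape
    A = X.T @ X / n + lam * np.eye(d)
    return np.linalg.solve(A, X.T @ y / n + lam * theta)


def outer_loss(theta, train, test, lam):
    """Outer objective: test loss evaluated at the adapted parameters phi*(theta)."""
    phi = inner_solution(theta, *train, lam)
    X_te, y_te = test
    return 0.5 * np.mean((X_te @ phi - y_te) ** 2)


def implicit_meta_grad(theta, train, test, lam, cg_iters=None):
    """Hypergradient (I + H/lam)^{-1} grad_phi L_test(phi*), with H the train-loss Hessian.

    With cg_iters=None the linear system is solved exactly; otherwise it is
    approximated by `cg_iters` conjugate-gradient iterations.
    """
    X_tr, y_tr = train
    X_te, y_te = test
    phi = inner_solution(theta, X_tr, y_tr, lam)
    g_test = X_te.T @ (X_te @ phi - y_te) / X_te.shape[0]
    H = X_tr.T @ X_tr / X_tr.shape[0]  # Hessian of the quadratic train loss
    if cg_iters is None:
        return np.linalg.solve(np.eye(theta.size) + H / lam, g_test)
    # CG only needs Hessian-vector products, never the inverse itself.
    return conjugate_gradient(lambda v: v + H @ v / lam, g_test, cg_iters)


rng = np.random.default_rng(0)
d, lam = 5, 0.5
train, test = make_task(rng, n=50, d=d)
theta = np.zeros(d)
exact = implicit_meta_grad(theta, train, test, lam)
for k in (1, 2, 3, d):
    approx = implicit_meta_grad(theta, train, test, lam, cg_iters=k)
    print(k, np.linalg.norm(approx - exact))  # error of the truncated solve
```

Because the toy inner problem is quadratic, its solution is available in closed form. This isolates the error that comes from the linear solve alone. In a real MAML setup the inner solution is itself only approximate, which adds a second source of error.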

fabianp, Sep 20 '22 17:09

@asteroidhouse might also be interested in this 😄

fabianp, Sep 20 '22 17:09

Thanks for the review! I've implemented the suggestions. Please add the pull-ready label if you're satisfied :-)

fabianp, Sep 27 '22 09:09