jaxopt
MAML example
TODO:
- [x] Frame it as a bi-level problem, as in the implicit MAML paper (https://arxiv.org/pdf/1909.04630.pdf)
- [x] Move to implicit diff directory
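For reference, the bi-level problem in the implicit MAML paper reads roughly as follows. The notation is mine, and the notebook's exact losses may differ. Here $\theta$ are the meta-parameters, $\phi_i$ the task-adapted parameters for task $i$, and $\lambda$ the strength of the proximal term:

```latex
\min_{\theta} \; \frac{1}{M} \sum_{i=1}^{M} \mathcal{L}\big(\phi_i^\star(\theta), \mathcal{D}_i^{\text{test}}\big)
\quad \text{s.t.} \quad
\phi_i^\star(\theta) = \operatorname*{arg\,min}_{\phi} \; \mathcal{L}\big(\phi, \mathcal{D}_i^{\text{tr}}\big) + \frac{\lambda}{2} \lVert \phi - \theta \rVert^2

% Differentiating the inner optimality condition
% \nabla_\phi \mathcal{L}(\phi_i^\star, \mathcal{D}_i^{\text{tr}}) + \lambda (\phi_i^\star - \theta) = 0
% with respect to \theta gives the implicit Jacobian:
\frac{d \phi_i^\star}{d \theta} = \Big(I + \tfrac{1}{\lambda} \nabla_\phi^2 \mathcal{L}\big(\phi_i^\star, \mathcal{D}_i^{\text{tr}}\big)\Big)^{-1}
```

This Jacobian is never formed explicitly. Only its product with $\nabla_\phi \mathcal{L}(\phi_i^\star, \mathcal{D}_i^{\text{test}})$ is needed, which can be obtained with a linear solve.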
I think this is finally ready for review 🎉!
The notebook can be seen here: https://github.com/google/jaxopt/blob/d7fc1d5407c81ab3c70bb09b818101fbcb4e1e3e/docs/notebooks/implicit_diff/maml.ipynb
There are still a couple of issues left, but I believe these can be tackled after merge:
- The notebook takes a while to run (a couple of hours on CPU). It would be nice to bring that down.
- Convergence of the outer gradient seems to plateau around 0.05. I believe this is because the gradient computed by implicit differentiation is not precise enough. It could be made more precise, but the runtime would then become extremely long.
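A toy sketch of the precision/runtime tradeoff in the last point follows. It is written in plain JAX rather than jaxopt's API and is not the notebook's code. The data, `lam`, step counts and the helper names (`inner_solve`, `implicit_meta_grad`, `exact_meta_grad`) are all hypothetical. A linear-regression task makes the inner problem quadratic, so an exact meta-gradient is available for comparison. The implicit meta-gradient solves $(I + H/\lambda)\,u = g$ with conjugate gradient (CG), and capping the CG iterations is one knob that trades gradient accuracy for speed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)  # float64 so small errors are measurable

# Hypothetical toy task: linear regression with a train/validation split.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
X_tr = jax.random.normal(k1, (20, 5))
y_tr = jax.random.normal(k2, (20,))
X_val = jax.random.normal(k3, (20, 5))
y_val = jax.random.normal(k4, (20,))
lam = 1.0  # strength of the proximal term pulling phi toward theta

def train_loss(phi):
    return 0.5 * jnp.sum((X_tr @ phi - y_tr) ** 2)

def val_loss(phi):
    return 0.5 * jnp.sum((X_val @ phi - y_val) ** 2)

def inner_objective(phi, theta):
    # Inner (task-adaptation) problem: train loss + proximal regularizer.
    return train_loss(phi) + 0.5 * lam * jnp.sum((phi - theta) ** 2)

def inner_solve(theta, num_steps=500):
    # Gradient descent with step 1/L, L = Lipschitz constant of the gradient.
    step = 1.0 / (jnp.linalg.eigvalsh(X_tr.T @ X_tr)[-1] + lam)
    grad_fn = jax.grad(inner_objective)
    body = lambda _, phi: phi - step * grad_fn(phi, theta)
    return jax.lax.fori_loop(0, num_steps, body, theta)

def implicit_meta_grad(theta, cg_maxiter):
    phi = inner_solve(theta)
    g = jax.grad(val_loss)(phi)
    # Hessian-vector product of the train loss at phi (no explicit Hessian).
    hvp = lambda v: jax.jvp(jax.grad(train_loss), (phi,), (v,))[1]
    # Approximately solve (I + H / lam) u = g with CG; u is the meta-gradient.
    matvec = lambda v: v + hvp(v) / lam
    u, _ = cg(matvec, g, tol=1e-12, maxiter=cg_maxiter)
    return u

def exact_meta_grad(theta):
    # Reference: differentiate through the closed-form inner solution.
    H = X_tr.T @ X_tr + lam * jnp.eye(X_tr.shape[1])
    phi_star = lambda t: jnp.linalg.solve(H, X_tr.T @ y_tr + lam * t)
    return jax.grad(lambda t: val_loss(phi_star(t)))(theta)

theta = jnp.zeros(5)
reference = exact_meta_grad(theta)
errors = {}
for maxiter in (1, 2, 5):
    diff = implicit_meta_grad(theta, maxiter) - reference
    errors[maxiter] = float(jnp.linalg.norm(diff))
    print(f"CG iterations: {maxiter}, meta-gradient error: {errors[maxiter]:.2e}")
```

On this 5-dimensional quadratic, CG converges within 5 iterations. On a neural-network inner problem, neither the inner solve nor the linear solve is exact, so both contribute error to the outer gradient.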
@asteroidhouse might also be interested in this 😄
Thanks for the review! I've implemented the suggestions. Please add the pull-ready label if you're satisfied :-)