Add `pytensor.tensor.optimize`
Description
Implement scipy optimization routines, with implicit gradients. This PR should add:
- [ ] optimize.minimize
- [ ] optimize.root
- [ ] optimize.scalar_minimize
- [ ] optimize.scalar_root
It would also be nice to have rewrites to transform e.g. root to scalar_root when we know that there is only one input.
The implementation @ricardoV94 and I cooked up (ok ok it was mostly him) uses the graph to implicitly define the inputs to the objective function. For example:
```python
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.scalar("x")
a = pt.scalar("a")
c = pt.scalar("c")

b = a * 2
b.name = "b"
out = (x - b * c) ** 2

minimized_x, success = minimize(out, x, debug=False)
```
We optimize out with respect to x, so x becomes the control variable. By graph inspection we find that out also depends on a and c, so the generated graph includes them as parameters. In scipy lingo, we end up with:
```python
minimize(fun=out, x0=x, args=(a, c))
```
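For intuition, here is roughly what that call amounts to in plain scipy, with the objective and its gradient written out by hand (a minimal sketch; the `objective` function and starting point are illustrative, not the code the Op actually generates):

```python
import numpy as np
from scipy.optimize import minimize as scipy_minimize

def objective(x, a, c):
    # value and gradient of out = (x - b * c) ** 2 with b = 2 * a
    b = 2.0 * a
    r = x[0] - b * c
    return r ** 2, np.array([2.0 * r])

# x is the control variable; a and c ride along as extra args
res = scipy_minimize(objective, x0=np.array([0.0]), args=(1.0, 2.0), jac=True)
print(res.x)  # approaches the analytic argmin x* = 2 * a * c = 4
```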
We get the following graph. The inner graph includes the gradients of the cost function by default, which is automatically used by scipy.
```
MinimizeOp.0 [id A]
 ├─ x [id B]
 └─ Mul [id C]
    ├─ 2.0 [id D]
    ├─ a [id E]
    └─ c [id F]

Inner graphs:

MinimizeOp [id A]
 ← Pow [id G]
    ├─ Sub [id H]
    │  ├─ x [id I]
    │  └─ <Scalar(float64, shape=())> [id J]
    └─ 2 [id K]
 ← Mul [id L]
    ├─ Mul [id M]
    │  ├─ Second [id N]
    │  │  ├─ Pow [id G]
    │  │  │  └─ ···
    │  │  └─ 1.0 [id O]
    │  └─ 2 [id K]
    └─ Pow [id P]
       ├─ Sub [id H]
       │  └─ ···
       └─ Sub [id Q]
          ├─ 2 [id K]
          └─ DimShuffle{order=[]} [id R]
             └─ 1 [id S]
```
We can also ask for the gradients of the minimizing solution with respect to the parameters:
```python
x_grad, a_grad, c_grad = pt.grad(minimized_x, [x, a, c])
```

```
# x_grad.dprint()
0.0 [id A]

# a_grad.dprint()
Mul [id A]
 ├─ 2.0 [id B]
 └─ c [id C]

# c_grad.dprint()
Mul [id A]
 ├─ 2.0 [id B]
 └─ a [id C]
```
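These match what the implicit function theorem predicts: the argmin is x*(a, c) = 2ac, so ∂x*/∂a = 2c, ∂x*/∂c = 2a, and the gradient with respect to the initial value x is zero. A quick numerical sanity check with plain scipy and central finite differences (illustrative only, not part of this PR's code):

```python
import numpy as np
from scipy.optimize import minimize_scalar

def argmin(a, c):
    # numerically minimize (x - 2*a*c)**2 over x; the analytic argmin is 2*a*c
    return minimize_scalar(lambda x: (x - 2.0 * a * c) ** 2).x

a, c, eps = 1.0, 2.0, 1e-5
da = (argmin(a + eps, c) - argmin(a - eps, c)) / (2 * eps)  # ~ 2*c = 4
dc = (argmin(a, c + eps) - argmin(a, c - eps)) / (2 * eps)  # ~ 2*a = 2
```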
Related Issue
- [ ] Closes #944
- [ ] Related to #978 and https://github.com/pymc-devs/pymc-extras/issues/342
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pytensor--1182.org.readthedocs.build/en/1182/
We could swap the scipy method to powell if pt.grad raises?
Is there a path to later add other optimizers, e.g. jax or pytensor? Could then optimize on the GPU. Lasagne has a bunch of optimizers implemented in theano: https://lasagne.readthedocs.io/en/latest/modules/updates.html
> Is there a path to later add other optimizers, e.g. jax or pytensor? Could then optimize on the GPU.
Yes, the MinimizeOp can be dispatched to JaxOpt on the jax backend. On the Python/C backend we can also replace the MinimizeOp with any equivalent optimizer. As @jessegrabowski mentioned, we may even analyze the graph to decide what to use, such as scalar_root when that's adequate.
> Is there a path to later add other optimizers, e.g. jax or pytensor? Could then optimize on the GPU. Lasagne has a bunch of optimizers implemented in theano: https://lasagne.readthedocs.io/en/latest/modules/updates.html
The Lasagne optimizers are for SGD in minibatch settings, so it's slightly different from what I have in mind here. This functionality would be useful in cases where you want to solve a sub-problem and then use the result in a downstream computation. For example, I think @theorashid wanted to use this for INLA to integrate out nuisance parameters via optimization before running MCMC on the remaining parameters of interest.
Another use case would be an agent-based model where we assume agents behave optimally. For example, we could try to estimate investor risk aversion parameters, assuming some utility function. The market prices would be the result of portfolio optimization subject to risk aversion (estimated), expected return vector, and market covariance matrix. Or use it in an RL-type scheme where agents have to solve a Bellman equation to get (an approximation to) their value function. I'm looking forward to cooking up some example models using this.
> We could swap the scipy method to powell if pt.grad raises?
We can definitely change the method via rewrites. There are several gradient-free (or approximate gradient) options in that respect. Optimizers can be fussy though, so I'm a bit hesitant to take this type of configuration out of the hands of the user.
> Optimizers can be fussy though, so I'm a bit hesitant to take this type of configuration out of the hands of the user.
Let users choose but try to provide the best default?
Carlos is interested in this 👀
Would it be possible to insert a callback to the optimiser to store the position and gradient history?
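Plain scipy already supports this via the `callback` argument of `scipy.optimize.minimize`, so an Op-level option could presumably forward to it. A sketch of what recording the iterate history looks like on the scipy side (the `record` helper and toy objective are illustrative, not part of this PR):

```python
import numpy as np
from scipy.optimize import minimize

history = []  # filled with one point per BFGS iteration

def fun_and_grad(x):
    # toy objective (x - 4)**2 with its gradient
    r = x[0] - 4.0
    return r ** 2, np.array([2.0 * r])

def record(xk):
    # scipy invokes the callback after each iteration with the current point;
    # the gradient at xk can be recomputed or captured inside fun_and_grad
    history.append(xk.copy())

res = minimize(fun_and_grad, x0=np.array([0.0]), jac=True, method="BFGS", callback=record)
```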
I implemented all the things I said I would, so I'm taking this off draft.
The optimize.py file has gotten pretty unreadable, what do you think about making it a sub-module with root.py and minimize.py ? I could imagine some more functionality being requested that would essentially recycle the same implicit gradient code (fixed point iteration, for example).
We might also consider offering the scan-based pure pytensor one (maybe with an L_op override to use the implicit formula instead of backproping through the scan steps). That would have the advantage of working in other backends right out the gates, plus it could give users fine control over e.g. the convergence check function.
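For the fixed-point case, the implicit gradient such an L_op override would use follows from the implicit function theorem: if x* satisfies x* = f(x*, θ), then dx*/dθ = (∂f/∂θ) / (1 − ∂f/∂x) in the scalar case, with no backprop through the iterations. A pure-numpy sketch under that assumption (the iteration and the choice of f are illustrative):

```python
import numpy as np

def fixed_point(theta, x0=0.5, tol=1e-12, max_iter=10_000):
    # naive fixed-point iteration for x = theta * cos(x)
    x = x0
    for _ in range(max_iter):
        x_new = theta * np.cos(x)
        if abs(x_new - x) < tol:
            return x_new
        x = x_new
    return x

theta = 1.0
x_star = fixed_point(theta)

# implicit function theorem: dx*/dtheta = (df/dtheta) / (1 - df/dx)
dfdtheta = np.cos(x_star)          # df/dtheta at the fixed point
dfdx = -theta * np.sin(x_star)     # df/dx at the fixed point
implicit_grad = dfdtheta / (1.0 - dfdx)

# agrees with a central finite difference through the full iteration
eps = 1e-6
fd_grad = (fixed_point(theta + eps) - fixed_point(theta - eps)) / (2 * eps)
```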
~~Also the failing sparse tests don't look like my fault~~ ~~needed to update my local main and rebase~~ nope, it's still failing
This looks great! Left some comments about asserting that the fgraph variables are of the expected type; otherwise this looks good. Do we want to follow up with JAX dispatch? I guess there's no off-the-shelf numba stuff we can plug in? Maybe some optional third-party library, like we do with tensorflow-probability for some of the JAX dispatches?
Codecov Report
:x: Patch coverage is 90.81967% with 28 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 82.17%. Comparing base (d10f245) to head (bfa63f6).
:warning: Report is 136 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/tensor/optimize.py | 90.81% | 17 Missing and 11 partials :warning: |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #1182      +/-   ##
==========================================
+ Coverage   82.12%   82.17%   +0.05%
==========================================
  Files         211      212       +1
  Lines       49757    50062     +305
  Branches     8819     8840      +21
==========================================
+ Hits        40862    41139     +277
- Misses       6715     6732      +17
- Partials     2180     2191      +11
```
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/tensor/optimize.py | 90.81% <90.81%> (ø) |
Maybe more informative PR title as well?