
Add `pytensor.tensor.optimize`

jessegrabowski opened this issue 10 months ago · 8 comments

Description

Implement scipy optimization routines, with implicit gradients. This PR should add:

  • [ ] optimize.minimize
  • [ ] optimize.root
  • [ ] optimize.scalar_minimize
  • [ ] optimize.scalar_root

It would also be nice to have rewrites that transform e.g. root to scalar_root when we know there is only one input.

The implementation @ricardoV94 and I cooked up (ok ok it was mostly him) uses the graph to implicitly define the inputs to the objective function. For example:

    import pytensor.tensor as pt
    from pytensor.tensor.optimize import minimize
    x = pt.scalar("x")
    a = pt.scalar("a")
    c = pt.scalar("c")

    b = a * 2
    b.name = "b"
    out = (x - b * c) ** 2

    minimized_x, success = minimize(out, x, debug=False)

We optimize out with respect to x, so x becomes the control variable. By graph inspection we find that out also depends on a and c, so the generated graph includes them as parameters. In scipy lingo, we end up with:

    minimize(fun=out, x0=x, args=(a, c))
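
The extra arguments are found purely by inspecting the graph. Roughly, the idea is something like the following sketch (illustrative only, not the actual implementation):

    from pytensor.graph.basic import Constant, graph_inputs

    # Every root input of the objective graph, other than the control variable x
    # and any constants, becomes a scipy `args` entry
    params = [
        v for v in graph_inputs([out])
        if v is not x and not isinstance(v, Constant)
    ]
    # params should contain a and c (order may vary)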

We get the following graph. The inner graph includes the gradients of the cost function by default, which are automatically used by scipy.

    MinimizeOp.0 [id A]
     ├─ x [id B]
     └─ Mul [id C]
        ├─ 2.0 [id D]
        ├─ a [id E]
        └─ c [id F]

    Inner graphs:

    MinimizeOp [id A]
     ← Pow [id G]
        ├─ Sub [id H]
        │  ├─ x [id I]
        │  └─ <Scalar(float64, shape=())> [id J]
        └─ 2 [id K]
     ← Mul [id L]
        ├─ Mul [id M]
        │  ├─ Second [id N]
        │  │  ├─ Pow [id G]
        │  │  │  └─ ···
        │  │  └─ 1.0 [id O]
        │  └─ 2 [id K]
        └─ Pow [id P]
           ├─ Sub [id H]
           │  └─ ···
           └─ Sub [id Q]
              ├─ 2 [id K]
              └─ DimShuffle{order=[]} [id R]
                 └─ 1 [id S]
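
The compiled graph can then be evaluated like any other pytensor function. A minimal sketch, continuing the example above (the numeric values are arbitrary; with a = 1 and c = 3 the minimum of (x - 2*a*c)**2 is at x = 6):

    import pytensor

    # x plays the role of x0 (the initial value of the control variable)
    fn = pytensor.function([x, a, c], [minimized_x, success])
    x_star, ok = fn(0.0, 1.0, 3.0)  # x_star ≈ 6.0, ok == True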

We can also ask for the gradients of the minimized x with respect to the inputs:

    x_grad, a_grad, c_grad = pt.grad(minimized_x, [x, a, c])

    # x_grad.dprint()
    0.0 [id A]

    # a_grad.dprint()
    Mul [id A]
     ├─ 2.0 [id B]
     └─ c [id C]

    # c_grad.dprint()
    Mul [id A]
     ├─ 2.0 [id B]
     └─ a [id C]
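
As a sanity check (illustrative, not part of the PR): the minimizer is x* = 2*a*c, so we expect dx*/da = 2c, dx*/dc = 2a, and no dependence on the starting x:

    # x is unused in the gradient graphs (the argmin does not depend on the
    # starting point), hence on_unused_input="ignore"
    grad_fn = pytensor.function(
        [x, a, c], [x_grad, a_grad, c_grad], on_unused_input="ignore"
    )
    grad_fn(0.0, 1.0, 3.0)  # expected: 0.0, 6.0, 2.0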

Related Issue

  • [ ] Closes #944
  • [ ] Related to #978 and https://github.com/pymc-devs/pymc-extras/issues/342

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pytensor--1182.org.readthedocs.build/en/1182/

jessegrabowski avatar Jan 31 '25 16:01 jessegrabowski

We could swap the scipy method to powell if pt.grad raises?

ricardoV94 avatar Jan 31 '25 16:01 ricardoV94

Is there a path to later add other optimizers, e.g. jax or pytensor? Could then optimize on the GPU. Lasagne has a bunch of optimizers implemented in theano: https://lasagne.readthedocs.io/en/latest/modules/updates.html

twiecki avatar Jan 31 '25 18:01 twiecki

> Is there a path to later add other optimizers, e.g. jax or pytensor? Could then optimize on the GPU. Lasagne has a bunch of optimizers implemented in theano: https://lasagne.readthedocs.io/en/latest/modules/updates.html

Yes, the MinimizeOp can be dispatched to JaxOpt on the jax backend. On the Python/C backend we can also replace the MinimizeOp with any equivalent optimizer. As @jessegrabowski mentioned, we may even analyze the graph to decide what to use, such as scalar_root when that's adequate.
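
Something along these lines could work for the JAX side. This is a very rough sketch, not the final design: the import path, MinimizeOp's attributes (e.g. op.fgraph), and the inner-graph conventions are assumptions here.

    from jax.scipy.optimize import minimize as jax_minimize

    from pytensor.link.jax.dispatch import jax_funcify
    from pytensor.tensor.optimize import MinimizeOp  # assumed import path

    @jax_funcify.register(MinimizeOp)
    def jax_funcify_MinimizeOp(op, **kwargs):
        # Turn the wrapped objective graph into a JAX-callable function; assume
        # its first output is the scalar objective and its inputs are (x, *args)
        inner_fn = jax_funcify(op.fgraph, **kwargs)

        def minimize(x0, *args):
            def objective(x):
                return inner_fn(x, *args)[0]

            # jax.scipy differentiates the objective itself and expects a 1-d x0;
            # shape handling is glossed over in this sketch
            res = jax_minimize(objective, x0, method="BFGS")
            return res.x, res.success

        return minimize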

ricardoV94 avatar Jan 31 '25 18:01 ricardoV94

> Is there a path to later add other optimizers, e.g. jax or pytensor? Could then optimize on the GPU. Lasagne has a bunch of optimizers implemented in theano: https://lasagne.readthedocs.io/en/latest/modules/updates.html

The Lasagne optimizers are for SGD in minibatch settings, so it's slightly different from what I have in mind here. This functionality would be useful in cases where you want to solve a sub-problem and then use the result in a downstream computation. For example, I think @theorashid wanted to use this for INLA to integrate out nuisance parameters via optimization before running MCMC on the remaining parameters of interest.

Another use case would be an agent-based model where we assume agents behave optimally. For example, we could try to estimate investor risk aversion parameters, assuming some utility function. The market prices would then be the result of portfolio optimization given the (estimated) risk aversion, an expected return vector, and a market covariance matrix. Or we could use it in an RL-type scheme where agents have to solve a Bellman equation to get (an approximation to) their value function. I'm looking forward to cooking up some example models using this.

jessegrabowski avatar Feb 01 '25 05:02 jessegrabowski

> We could swap the scipy method to powell if pt.grad raises?

We can definitely change the method via rewrites. There are several gradient-free (or approximate gradient) options in that respect. Optimizers can be fussy though, so I'm a bit hesitant to take this type of configuration out of the hands of the user.

jessegrabowski avatar Feb 01 '25 05:02 jessegrabowski

> Optimizers can be fussy though, so I'm a bit hesitant to take this type of configuration out of the hands of the user.

Let users choose but try to provide the best default?

ricardoV94 avatar Feb 01 '25 08:02 ricardoV94

Carlos is interested in this 👀

cetagostini avatar Feb 10 '25 17:02 cetagostini

Would it be possible to insert a callback to the optimiser to store the position and gradient history?
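
For context, this is the kind of plain-scipy pattern I mean (a rough sketch; how it would be exposed through the Op is exactly the question):

    import numpy as np
    from scipy.optimize import minimize as scipy_minimize

    history = []

    def record(xk):
        # scipy calls this after each iteration with the current parameter vector
        history.append(np.copy(xk))

    scipy_minimize(lambda x: ((x - 3.0) ** 2).sum(), x0=np.zeros(1), callback=record)
    # history now holds the iterates approaching [3.0]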

aphc14 avatar Feb 16 '25 06:02 aphc14

I implemented all the things I said I would, so I'm taking this off draft.

The optimize.py file has gotten pretty unreadable; what do you think about making it a sub-module with root.py and minimize.py? I could imagine some more functionality being requested that would essentially recycle the same implicit gradient code (fixed-point iteration, for example).

We might also consider offering a scan-based, pure-pytensor implementation (maybe with an L_op override to use the implicit formula instead of backpropagating through the scan steps). That would have the advantage of working in other backends right out of the gate, plus it could give users fine control over e.g. the convergence check function.
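
Roughly what I have in mind for the scan-based version, on the same toy objective as above (an illustrative sketch with a fixed step count and a hard-coded learning rate, no convergence check):

    import pytensor
    import pytensor.tensor as pt

    x0 = pt.scalar("x0")
    a = pt.scalar("a")
    c = pt.scalar("c")

    def gd_step(x, a, c):
        # plain gradient-descent step on (x - 2*a*c)**2
        cost = (x - 2 * a * c) ** 2
        return x - 0.1 * pt.grad(cost, x)

    xs, _ = pytensor.scan(gd_step, outputs_info=[x0], non_sequences=[a, c], n_steps=200)
    x_star = xs[-1]

    fn = pytensor.function([x0, a, c], x_star)
    fn(0.0, 1.0, 3.0)  # ≈ 6.0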

jessegrabowski avatar Jun 09 '25 11:06 jessegrabowski

~~Also failing sparse tests don't look like my fault~~ ~~needed to update my local main and rebase~~ nope, it's still failing

jessegrabowski avatar Jun 09 '25 11:06 jessegrabowski

This looks great! I left some comments about asserting that the fgraph variables are of the expected type; otherwise this looks good. Do we want to follow up with JAX dispatch? I guess there's no off-the-shelf numba stuff we can plug in? Maybe some optional third-party library, like we do with tensorflow-probability for some of the JAX dispatches?

ricardoV94 avatar Jun 10 '25 08:06 ricardoV94

Codecov Report

Patch coverage is 90.81967% with 28 lines in your changes missing coverage. Please review. Project coverage is 82.17%. Comparing base (d10f245) to head (bfa63f6). Warning: report is 136 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| pytensor/tensor/optimize.py | 90.81% | 17 Missing and 11 partials |

Additional details and impacted files


    @@            Coverage Diff             @@
    ##             main    #1182      +/-   ##
    ==========================================
    + Coverage   82.12%   82.17%   +0.05%
    ==========================================
      Files         211      212       +1
      Lines       49757    50062     +305
      Branches     8819     8840      +21
    ==========================================
    + Hits        40862    41139     +277
    - Misses       6715     6732      +17
    - Partials     2180     2191      +11

| Files with missing lines | Coverage Δ |
| --- | --- |
| pytensor/tensor/optimize.py | 90.81% <90.81%> (ø) |

codecov[bot] avatar Jun 10 '25 11:06 codecov[bot]

Maybe a more informative PR title as well?

ricardoV94 avatar Jun 10 '25 13:06 ricardoV94