jaxopt
Expression tree API like CVXPY
Does jaxopt have an expression tree API like CVXPY does?
If not, would it be possible to create one? This would make it easier to build up optimization problems.
Hello @carlosgmartin, wouldn't jaxprs be what you are looking for? JAX traces the whole computational graph, so you could work on it directly. I may have misunderstood what you are looking for, though. Feel free to describe your objective in more detail.
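For reference, a jaxpr is what JAX records when it traces a function. The short sketch below (the function `objective` is only an illustration) uses `jax.make_jaxpr` to print the traced graph and then lists the primitive behind each equation. The result is a graph of low-level primitives such as `mul` and `add`. It does not mark which inputs are decision variables or which expressions are constraints, and that is the structure a CVXPY-style API would add on top.

```python
import jax
import jax.numpy as jnp


def objective(x):
    # An ordinary JAX function; no special expression objects are involved.
    return jnp.sum(x * 2.0 + 1.0)


# make_jaxpr traces `objective` with abstract values shaped like the example input.
closed_jaxpr = jax.make_jaxpr(objective)(jnp.ones(3))
print(closed_jaxpr)

# Each equation stores its primitive, so the graph can be walked programmatically.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```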
@vroulet I meant something like the example shown in the second link (more examples here). Here's another example:
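import cvxpy as cp
import numpy as np


def find_nash_equilibrium(u):
    """Find the Nash equilibrium of a two-player zero-sum normal-form game.

    u is the payoff matrix for the row player."""
    u = np.asarray(u, dtype=float)
    x = cp.Variable(u.shape[0])  # row player's mixed strategy
    v = cp.Variable()  # value of the game for the row player
    objective = cp.Maximize(v)
    constraints = [
        v <= x @ u,  # x earns at least v against every pure column strategy
        x >= 0,
        x.sum() == 1,
    ]
    problem = cp.Problem(objective, constraints)
    problem.solve()
    # The dual values of the first constraint give the column player's mixed strategy.
    return {"v": v.value, "x": x.value, "y": constraints[0].dual_value}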
That is, letting users create variables and combine them into expressions that define the objective and constraints of a problem. This makes writing linear/quadratic programs much easier than manually assembling the A, b, G, h, Q, c arrays.
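To make the request concrete, here is a minimal, hypothetical sketch of what such an expression layer might do under the hood: scalar variables combine into affine expressions, comparisons produce constraints, and everything is compiled into the c, G, h, A, b arrays of the form minimize cᵀz subject to G z ≤ h, A z = b (Q = 0 for a linear program). The names `Expr`, `Constraint`, `Problem` and `to_arrays` are invented for illustration; this is not part of jaxopt or CVXPY, and it handles only scalar variables and affine terms, with no quadratic terms or convexity checking.

```python
class Expr:
    """Affine expression: sum of coeffs[i] * z[i] plus a constant (hypothetical)."""

    def __init__(self, coeffs=None, const=0.0):
        self.coeffs = dict(coeffs or {})  # variable index -> coefficient
        self.const = float(const)

    @staticmethod
    def lift(value):
        # Wrap plain numbers as constant expressions.
        return value if isinstance(value, Expr) else Expr(const=value)

    def __add__(self, other):
        other = Expr.lift(other)
        coeffs = dict(self.coeffs)
        for i, a in other.coeffs.items():
            coeffs[i] = coeffs.get(i, 0.0) + a
        return Expr(coeffs, self.const + other.const)

    __radd__ = __add__  # lets sum() start from the integer 0

    def __neg__(self):
        return Expr({i: -a for i, a in self.coeffs.items()}, -self.const)

    def __sub__(self, other):
        return self + (-Expr.lift(other))

    def __rsub__(self, other):
        return Expr.lift(other) + (-self)

    def __mul__(self, scalar):
        if isinstance(scalar, Expr):
            raise TypeError("product of two expressions is not affine")
        return Expr({i: scalar * a for i, a in self.coeffs.items()}, scalar * self.const)

    __rmul__ = __mul__

    # Comparisons build constraints normalized to "expr <= 0" or "expr == 0".
    def __le__(self, other):
        return Constraint(self - other, "<=")

    def __ge__(self, other):
        return Constraint(Expr.lift(other) - self, "<=")

    def __eq__(self, other):
        return Constraint(self - other, "==")


class Constraint:
    def __init__(self, expr, sense):
        self.expr = expr    # the constraint reads: expr <= 0  or  expr == 0
        self.sense = sense  # "<=" or "=="


class Problem:
    """Hands out variables and compiles to min c@z s.t. G@z <= h, A@z == b."""

    def __init__(self):
        self.num_vars = 0

    def variable(self):
        var = Expr({self.num_vars: 1.0})
        self.num_vars += 1
        return var

    def to_arrays(self, objective, constraints, maximize=False):
        def row(expr):
            return [expr.coeffs.get(i, 0.0) for i in range(self.num_vars)]

        sign = -1.0 if maximize else 1.0  # maximize f  <=>  minimize -f
        c = [sign * a for a in row(objective)]
        G, h, A, b = [], [], [], []
        for con in constraints:
            lhs, rhs = (G, h) if con.sense == "<=" else (A, b)
            lhs.append(row(con.expr))
            rhs.append(-con.expr.const)  # move the constant to the right-hand side
        return c, G, h, A, b


# Rock-paper-scissors payoff matrix for the row player.
u = [[0, -1, 1], [1, 0, -1], [-1, 1, 0]]
prob = Problem()
x = [prob.variable() for _ in range(3)]  # row player's mixed strategy
v = prob.variable()                      # game value
constraints = [v <= sum(x[i] * u[i][j] for i in range(3)) for j in range(3)]
constraints += [xi >= 0 for xi in x]
constraints.append(sum(x) == 1)
c, G, h, A, b = prob.to_arrays(v, constraints, maximize=True)
```

For the Nash problem above, each of the first three rows of G encodes v − (uᵀx)ⱼ ≤ 0, i.e. [−u[0][j], −u[1][j], −u[2][j], 1] over z = (x, v), and A = [[1, 1, 1, 0]], b = [1] encodes the simplex constraint: exactly the bookkeeping the expression layer spares the user.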
So no, we don't have that now, and we weren't planning on adding it. Wouldn't a package like https://github.com/cvxgrp/cvxpylayers be a good starting point?