diffrax

Passing additional state between subsequent RHS evaluations?

Open mbjd opened this issue 1 year ago • 2 comments

Hi there :)

I want to solve many ODEs whose right-hand side involves solving a small parameterised convex QP. The QP is of the standard form:

$$\min_x \; x^\top Q x + p^\top x \quad \text{s.t.} \quad A x \le b$$

where the parameter p is a continuous function of the ODE state and therefore changes continuously (specifically, I am solving the Hamiltonian system given by Pontryagin's minimum principle in optimal control). Up until now, I've only addressed relatively small examples, where it is easy to hand-write a brute-force active-set solver: it enumerates all possible active sets, solves the KKT system for each one, and finally selects, among the solutions that satisfy the constraints, the one with the lowest objective.
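For reference, the brute-force approach just described can be sketched as follows. This is a minimal NumPy sketch, not the author's actual code. The function name `brute_force_qp` is hypothetical, and the sketch assumes `Q` is symmetric positive definite, so that each candidate active set gives at most one KKT solution.

```python
import itertools

import numpy as np


def brute_force_qp(Q, p, A, b, tol=1e-9):
    """Minimise x'Qx + p'x subject to Ax <= b by enumerating candidate active sets.

    Hypothetical helper; assumes Q is symmetric positive definite. Each active set S
    gives an equality-constrained KKT system, and the feasible candidate with the
    lowest objective is the optimum.
    """
    n, m = Q.shape[0], A.shape[0]
    best_x, best_obj = None, np.inf
    for k in range(min(n, m) + 1):  # at most n linearly independent active rows
        for S in itertools.combinations(range(m), k):
            idx = np.array(S, dtype=int)
            A_S, b_S = A[idx], b[idx]
            # KKT conditions with the constraints in S held at equality:
            #   2Q x + A_S' lam = -p,   A_S x = b_S
            K = np.zeros((n + k, n + k))
            K[:n, :n] = 2 * Q
            K[:n, n:] = A_S.T
            K[n:, :n] = A_S
            rhs = np.concatenate([-p, b_S])
            try:
                sol = np.linalg.solve(K, rhs)
            except np.linalg.LinAlgError:
                continue  # singular KKT matrix (linearly dependent active rows)
            x = sol[:n]
            if np.all(A @ x <= b + tol):  # discard candidates that violate a constraint
                obj = x @ Q @ x + p @ x
                if obj < best_obj:
                    best_x, best_obj = x, obj
    return best_x
```

The double loop visits up to $\sum_{k \le \min(n,m)} \binom{m}{k}$ active sets. That count is why this approach stops scaling once the QP has many constraints.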

I would like to tackle more challenging systems, where the QP has enough constraints to make this brute-force approach computationally intractable. Instead, I'd like to use a general convex QP solver. Because the parameter changes continuously, we would benefit greatly from warm-starting that solver with the previous solution.
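Warm starting helps because an iterative solver started near the new solution needs fewer iterations. Below is a minimal sketch of one iterative QP method that accepts an initial iterate: projected gradient ascent on the dual problem. This method is swapped in only for brevity and is not what the issue proposes. `dual_projected_gradient_qp` is a hypothetical name, and `Q` is assumed symmetric positive definite. The multipliers `lam` are the kind of primal/dual information one would like to carry from one RHS evaluation to the next.

```python
import numpy as np


def dual_projected_gradient_qp(Q, p, A, b, lam0=None, tol=1e-10, max_iter=10_000):
    """Minimise x'Qx + p'x s.t. Ax <= b via projected gradient ascent on the dual.

    Hypothetical helper. lam0 optionally warm-starts the multipliers (e.g. with the
    multipliers from the previous RHS evaluation). Returns (x, lam, iterations).
    """
    H_inv = np.linalg.inv(2 * Q)  # Hessian of the objective is 2Q
    step = 1.0 / np.linalg.norm(A @ H_inv @ A.T, 2)  # 1/L for the dual gradient
    lam = np.zeros(A.shape[0]) if lam0 is None else np.asarray(lam0, dtype=float)
    for it in range(1, max_iter + 1):
        x = -H_inv @ (p + A.T @ lam)  # minimiser of the Lagrangian for fixed lam
        # Dual gradient is Ax - b; take an ascent step, then project onto lam >= 0.
        lam_new = np.maximum(0.0, lam + step * (A @ x - b))
        converged = np.linalg.norm(lam_new - lam) < tol
        lam = lam_new
        if converged:
            break
    x = -H_inv @ (p + A.T @ lam)
    return x, lam, it
```

Calling it a second time with `lam0` set to the previously returned multipliers, for a slightly perturbed `p`, typically converges in fewer iterations than a cold start from `lam = 0`. Any solver that accepts an initial iterate could serve the same role.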

Now my central question: is it possible in diffrax to reuse primal/dual information from the previous RHS evaluation? Mathematically, this would not break the purity of the RHS function: the solution is still the same, we just find it more efficiently. From an implementation perspective, however, it obviously would.

Some alternatives I've considered:

  1. Collect a number of pre-solved points for different parameter values, then at runtime warm-start any suitable convex solver from the closest one. Alternatively, fit a small NN that approximates the map from the QP parameter to the primal/dual solution, and use it for warm starting in the same way.
  2. Precompute and/or approximate the solution map directly (without subsequent "solution polishing"). This could be done either via explicit convex optimisation, which has somewhat less severe but ultimately similar scalability issues to my current brute-force solver, or via amortised optimisation.
  3. Accept that diffrax, or perhaps even JAX, is not the best practical tool for this, and use a different ODE solver in which carrying additional state is less of an issue.
  4. Split the ODE into segments on which the active set stays constant, so that the QP solution is affine in the parameter. Find the point where the active set changes, and start the next segment there with the corresponding RHS, a bit like this. Viewed from this angle, we essentially have a hybrid system whose RHS "jumps" every time the active set changes. Maybe general hybrid-system simulation tools could be useful here?
  5. Extend the ODE state with the primal and dual information of the QP solution, and endow the ODE with dynamics corresponding to a continuous-time convex optimiser, so that the QP solution stays (close to) optimal. Keeping the QP solution accurate would probably make the system very stiff and unpleasant, and a contraption like this would probably be better formulated as a DAE anyway.
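Option 4 can be made concrete by writing out the affine dependence on $p$. This sketch assumes $Q$ is positive definite and the active rows are linearly independent. The notation $S$ (active set), $A_S$, $b_S$, $\lambda_S$, $a_i$ (row $i$ of $A$) and $H = 2Q$ is introduced here and is not taken from the issue. For a fixed active set $S$, the KKT conditions $2Qx + p + A_S^\top \lambda_S = 0$ and $A_S x = b_S$ give:

```math
\begin{aligned}
H &= 2Q, \\
\lambda_S(p) &= -\left(A_S H^{-1} A_S^\top\right)^{-1}\left(b_S + A_S H^{-1} p\right), \\
x^\star(p) &= -H^{-1}\left(p + A_S^\top \lambda_S(p)\right), \\
\text{segment valid while}\quad & \lambda_{S,j}(p) \ge 0 \;\; (j \in S), \qquad b_i - a_i^\top x^\star(p) \ge 0 \;\; (i \notin S).
\end{aligned}
```

Since $p = p(y)$, each of these inequalities is a function of the ODE state. A sign change in any of them means either a multiplier has reached zero, so its constraint leaves the active set, or an inactive constraint has become tight, so it enters the active set. That sign change is the event that ends the segment.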

I'm grateful for any thoughts and suggestions :) (I also realise that alternative 5, and maybe others, are a bit outlandish and definitely not the practical solution here; I'm just including them for completeness' sake.)

mbjd, Dec 21 '23 10:12

Hey there!

Indeed, this is something I've been wanting to add to Diffrax for a while (see also #299, which is a variant of this). I'm going to try to get it into one of the upcoming releases, but no promises!

Your suggestion 4 sounds like a possible workaround in the meantime: you could use a `discrete_terminating_event` to stop the solve, and then resume it afterwards with the updated information.
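The stop-and-resume pattern can be sketched outside diffrax with a plain Euler loop. Each segment uses the cheap closed-form RHS for the current active set. When an event function changes sign, the loop interpolates to the crossing, updates the active set, and resumes. The problem here is a hypothetical one-dimensional toy, not from the issue: the inner QP is $\min_x x^2 - 2yx$ s.t. $x \le 1$, and the ODE is $\dot y = -x^\star(y)$. The event functions are the multiplier $\lambda = 2y - 2$ while the constraint is active, and the slack $1 - y$ while it is inactive. In diffrax, stopping the solve would be the job of the terminating event mentioned above; that API is not reproduced here.

```python
def rhs(y, active):
    # Closed-form QP solution for a fixed active set of the hypothetical toy QP
    #   min_x x**2 - 2*y*x  s.t.  x <= 1
    x = 1.0 if active else y
    return -x  # toy ODE: dy/dt = -x*(y)


def event(y, active):
    # Stays >= 0 while the current active set is valid:
    # multiplier lambda = 2y - 2 if active, slack 1 - x = 1 - y if inactive.
    return (2.0 * y - 2.0) if active else (1.0 - y)


def integrate_segments(y0, t1, dt):
    """Explicit Euler, restarted whenever the active set changes (hypothetical helper)."""
    t, y = 0.0, y0
    active = y0 >= 1.0  # initial active set from one full QP solve (trivial in this toy)
    switches = []
    while t < t1 - 1e-12:
        h = min(dt, t1 - t)
        y_new = y + h * rhs(y, active)
        g0, g1 = event(y, active), event(y_new, active)
        if g1 < 0 <= g0:
            # Event inside this step: the event is affine along the Euler step, so
            # linear interpolation finds the crossing; stop there and switch RHS.
            theta = g0 / (g0 - g1)
            t, y = t + theta * h, y + theta * h * rhs(y, active)
            active = not active  # one constraint: the active set simply toggles
            switches.append(t)
        else:
            t, y = t + h, y_new
    return y, switches


y_end, switches = integrate_segments(3.0, 3.0, 0.01)
```

For this toy, the exact solution decreases linearly to $y = 1$ at $t = 2$ and then decays as $e^{-(t-2)}$, so `switches` should contain a single time close to 2. With a real QP, the toggle would be replaced by a full re-solve at the event point.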

Another possible workaround would be to write a custom variant of `ODETerm` that returns both the vector field and the auxiliary information, together with a custom solver that operates on it (e.g. adapting just Euler wouldn't be too difficult; adapting one of the Runge–Kutta methods would be more work, though).
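The idea of a vector field returning `(dy/dt, aux)` and a solver feeding `aux` back in can be sketched in plain Python. This is not diffrax's `ODETerm` or solver API; all names below are hypothetical. To keep it short, the inner QP is replaced by a stand-in inner problem: find $z$ with $z^3 + z = y$ by Newton's method, with $\dot y = -z$. The auxiliary information is the previous root, used as the Newton warm start, plus a running iteration count.

```python
def euler_with_aux(vf_and_aux, y0, aux0, t0, t1, n_steps):
    """Explicit Euler for a vector field that returns (dy/dt, aux).

    The aux returned by each evaluation is passed to the next one. The RHS is
    still mathematically a function of (t, y) alone; aux only affects how fast
    the inner solve converges.
    """
    h = (t1 - t0) / n_steps
    t, y, aux = t0, y0, aux0
    for _ in range(n_steps):
        dy, aux = vf_and_aux(t, y, aux)
        y = y + h * dy
        t = t + h
    return y, aux


def make_vf(warm_start):
    """Hypothetical toy RHS: dy/dt = -z, where z solves z**3 + z = y (via Newton)."""
    def vf_and_aux(t, y, aux):
        z_prev, total_iters = aux
        z = z_prev if warm_start else 0.0  # warm start from the previous root, or cold
        while True:
            total_iters += 1
            step = (z**3 + z - y) / (3.0 * z**2 + 1.0)  # Newton step
            z -= step
            if abs(step) < 1e-12:
                break
        return -z, (z, total_iters)
    return vf_and_aux


y_warm, (_, warm_iters) = euler_with_aux(make_vf(True), 2.0, (0.0, 0), 0.0, 1.0, 100)
y_cold, (_, cold_iters) = euler_with_aux(make_vf(False), 2.0, (0.0, 0), 0.0, 1.0, 100)
```

Both runs should produce the same trajectory up to round-off, while the warm-started run needs fewer total Newton iterations. A multi-stage Runge–Kutta method would additionally require deciding how `aux` is threaded through the intermediate stages.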

patrick-kidger, Dec 21 '23 22:12

(Also closely related to https://github.com/patrick-kidger/diffrax/issues/199)

slishak, Feb 02 '24 10:02