
Solving for initial value conditions and final value conditions

Open enigmarikki opened this issue 1 year ago • 10 comments

Hey @patrick-kidger, I recently started exploring diffrax (I just realised how powerful it is, and kudos for the amazing work!). I was wondering if there is a way to solve for both an initial value and a final value, i.e. a system of stiff ODEs with $y(0) = k_1,\ y(T) = k_2$ (a boundary value problem rather than an initial value problem). Any kind of reference would be much appreciated.

enigmarikki avatar Jul 10 '22 18:07 enigmarikki

A standard approach for something like this would be a shooting method.
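
To illustrate the idea (a hedged sketch in plain SciPy rather than diffrax; the toy ODE and all names here are purely illustrative): integrate the IVP from a guessed initial slope, and root-find on the mismatch at the far boundary.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import brentq

# Toy BVP: y'' = -y with y(0) = 0, y(1) = 1.
# Shooting: guess the unknown initial slope s = y'(0), integrate the IVP,
# and root-find on the mismatch at the right boundary.
def residual(s):
    sol = solve_ivp(lambda t, y: [y[1], -y[0]], (0.0, 1.0), [0.0, s],
                    rtol=1e-10, atol=1e-12)
    return sol.y[0, -1] - 1.0  # want y(1) = 1

s_star = brentq(residual, 0.0, 10.0)
# Exact solution is y(t) = sin(t)/sin(1), so y'(0) = 1/sin(1).
```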

patrick-kidger avatar Jul 11 '22 08:07 patrick-kidger

Hi @patrick-kidger, I am interested in implementing a collocation-based BVP solver, since for my particular application shooting methods do not work well. Is there current work along this direction for Diffrax?

Good references might be the MIRK4 solver in Julia or scipy's solve_bvp. Are you aware of any additional resources that might be useful? Would porting any of these methods to Diffrax be a herculean task?

RicardoDominguez avatar Feb 01 '23 12:02 RicardoDominguez

So there's no current work in this direction, but I suspect it shouldn't be that hard! (I don't have any additional resources beyond the ones you link / the ones that are referenced in the scipy documentation.)

Right now, diffeqsolve specifically handles initial value problems. (E.g. it only accepts a y0.) So to solve a 2-point BVP we would want a new interface (biffeqsolve? :D ) that accepts a y1 as well. One could then imagine calling biffeqsolve(..., y0=y0, y1=y1, solver=MIRK4()), biffeqsolve(..., y0=y0, y1=y1, solver=Shooting(Tsit5())). Where appropriate this could internally lower to diffeqsolve + a nonlinear solver.

If you manage to stand up a prototype then I'd be interested to see the result. I'd be happy to see this either as a third-party library, or as a new interface in Diffrax.

(Regarding the nonlinear solver portion: for a prototype it'd probably be simplest to just use diffrax.NewtonNonlinearSolver, although that isn't necessarily well-adapted for this kind of nonlinear problem. At some point Soon™ I will have some more sophisticated handling for nonlinear solvers coming out, and we can then use that.)

patrick-kidger avatar Feb 03 '23 07:02 patrick-kidger

Hi @patrick-kidger, I have partially adapted scipy.integrate.solve_bvp (without support for unknown parameters p), which you can see here. The collocation system is solved by a damped Newton method which alternates between full Newton iterations and fixed-Jacobian iterations. Right now it is standalone from Diffrax, but it would be interesting to integrate it as a new interface in Diffrax. Three comments/questions:

  • I have tried to use diffrax.NewtonNonlinearSolver to solve the collocation system (as a drop-in replacement in line 582) but it does not seem to work well for this particular problem. I suspect that the Newton method being damped is important. Thoughts?
  • The collocation system being solved is very sparse, and scipy makes use of this fact by using scipy.sparse.linalg.splu. However, using sparse solvers in Jax (jax.scipy.sparse.linalg.cg, jax.scipy.sparse.linalg.gmres) is slower than jax.scipy.linalg.lu_factor for moderately large matrices.
  • Scipy's implementation iteratively adds mesh nodes based on the residuals. I suppose that this behavior is not easily batchable unless the mesh is always expanded by the same number of nodes at each iteration. Thus it might be more performant to use a fixed but relatively large mesh, rather than starting with a small mesh and iteratively expanding it. Thoughts?

I am not very well-versed in Jax so I am happy to take feedback :)

RicardoDominguez avatar Feb 07 '23 14:02 RicardoDominguez

I've just had a look over. I think it looks really good!

(Also you say you're new to JAX, but many of your points -- e.g. batching the mesh -- are pretty advanced things to spot.)

To answer your questions:

  • Damped Newton: I agree, this is probably important. The current Newton implementation is a simple Newton method or Chord method, as that is all that's needed for implicit RK solvers. There will soon be a much more sophisticated system in place here; it would then be easy to expand the available nonlinear solvers to include Newton with damping, backtracking etc.

  • Sparse: when you tried using the sparse solvers, did you specify the matrix as a function describing the matrix-vector product, or as a materialised matrix? The former should probably be much more efficient for this problem (at least on the forward pass, at least in terms of runtime-without-compiletime). If it was a materialised matrix, was it as a sparse or as a dense matrix?

    For what it's worth, JAX's support for sparse arrays is still pretty limited. (I suspect this is still at least a couple of years away from being anything comprehensive.) Both JAX and Diffrax place much greater focus on dense arrays.

    I do note that the system being solved should exhibit quite a lot of structure (just staring at the code I think it's block tridiagonal+dense?) One possible path forward is to use a custom linear solver that exploits this property. (This is also something that should be a lot easier with the soon-ish improvements to the linear and nonlinear solvers.)

  • Mesh nodes: the usual way to handle batching for this kind of variable-size computation is to use a computational structure that does as much work as the "most demanding" batch element.

    To give an example: vmap(lax.while_loop) works by iterating until every batch element has completed its computation; all of the other batch elements will still have their body_fun evaluated, but the result of the computation will simply be thrown away.

    Here, then: probably what you want to do is something like nodes_added = unvmap_max(nodes_added). Here unvmap_max is a function which is just the identity when run without vmap, but takes the maximum over all vmap'd dimensions when run with vmap. It's available at equinox.internal.unvmap_max. The "most demanding" batch element will get as many nodes as it requires, and the other batch elements will just be evaluated on a few more nodes than they really need. (Actually, I can see a bit more work is required here, as insert_1 and insert_2 are dynamically-sized arrays at the moment, and you'll need to find a way to express the computation using statically-sized arrays, but hopefully you get the idea.)
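
As a tiny self-contained illustration of that vmap(lax.while_loop) behaviour (a made-up toy, not the BVP code):

```python
import jax
import jax.numpy as jnp
from jax import lax

def count_to(limit):
    # Increment a counter until it reaches `limit`.
    return lax.while_loop(lambda c: c < limit, lambda c: c + 1, 0)

# Under vmap, the loop runs until the most demanding batch element
# (limit=7) has finished; the already-finished elements still have their
# body_fun evaluated, but those extra results are discarded.
out = jax.vmap(count_to)(jnp.array([3, 7, 5]))
```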

Other comments:

  • Right now solve_bvp isn't really jit'able, as it still has a Python while loop at the top level.

    This has a couple of issues: (a) it introduces a lot of overhead relative to doing the whole thing in JAX and then jit'ing solve_bvp; (b) it breaks composability with the rest of JAX i.e. a downstream user can't use solve_bvp in an arbitrary spot in their code. (Admittedly, if you're the only downstream user right now: maybe you don't mind!)

    Anyway, I'd recommend switching out this Python while loop with e.g. lax.while_loop, and making the other appropriate adjustments, e.g. no print statements. (It's for these kinds of reasons that diffeqsolve has a totally different API to solve_ivp: JAX's programming model is actually pretty different to NumPy, despite the similarity at the jax.numpy level.)

    You might rightly complain that lax.while_loop isn't reverse-mode autodifferentiable. (a) I have a reverse-mode autodifferentiable version of while_loop coming out in the next few days, as a drop-in replacement; (b) it's not clear to me that you actually want to directly differentiate the iterations anyway: I'm guessing there's probably e.g. an implicit representation of the adjoint (c.f. diffrax.ImplicitAdjoint) that would be a more efficient representation of the tangent/cotangent operations anyway.

  • I note that you call fun several times. Do be aware that each separate time you call fun, JAX will need to trace and compile it afresh (it can't know that it's already seen fun from earlier), and this can start to increase compile times.

    Up to a certain point this is inevitable (I think diffeqsolve traces through the vector field ~5 times or so), but if you do see ways to minimise the number of times that fun is called then this can be useful.
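
On the jit'ability point, the shape of that change looks something like the following (a generic fixed-point toy, not solve_bvp itself; all names here are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import lax

def fixed_point(f, x0, tol=1e-6, max_steps=100):
    # A Python `while` at the top level would break jit; lax.while_loop
    # keeps the whole computation inside JAX. Note: no print statements
    # in the body.
    def cond_fun(carry):
        x, step = carry
        return (jnp.abs(f(x) - x) > tol) & (step < max_steps)

    def body_fun(carry):
        x, step = carry
        return f(x), step + 1

    x, _ = lax.while_loop(cond_fun, body_fun, (jnp.asarray(x0), 0))
    return x

# Now jit'able and composable with the rest of JAX:
x = jax.jit(fixed_point, static_argnums=0)(jnp.cos, 1.0)
# Converges toward the fixed point of cos, approximately 0.739085.
```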

WDYT?

patrick-kidger avatar Feb 14 '23 08:02 patrick-kidger

Hi @patrick-kidger, thank you so much for the detailed feedback! I am stuck trying to make solve_bvp jittable by swapping out the Python while loop at the top level for a lax.while_loop. As far as I understand, lax.while_loop requires val to have a fixed shape, whereas we would like to expand the mesh at every iteration (i.e., x and y have increasing size). I thought that one solution would be to give x and y a static size max_number_nodes, where only the first n_mesh_nodes are passed to the Newton solver (i.e., newton_solver.solve(y[:n_mesh_nodes], x[:n_mesh_nodes])), and n_mesh_nodes is increased at every iteration. However, it is not possible to index by a dynamic n_mesh_nodes (even slice_sizes of lax.dynamic_slice must be static). Do you have any suggestions?

RicardoDominguez avatar Feb 27 '23 15:02 RicardoDominguez

Right; JAX unfortunately has poor support for dynamic shapes.

I think I can see a way to make this work, but it's pretty tricky.

To start, let's consider doing a vector-vector dot product with dynamic shapes. This can be done by representing the vectors using statically-sized buffers that we only partially fill, and then using a lax.while_loop to perform the actual computation:

import jax.numpy as jnp
from jax import lax
from jaxtyping import Array, Float, Int

def dot(a: Float[Array, "n"], b: Float[Array, "n"], size: Int[Array, ""]):
    # Only a[:size] and b[:size] should be filled; the rest is padding.
    def cond_fun(carry):
        step, _ = carry
        return step < size

    def body_fun(carry):
        step, val = carry
        val = val + a[step] * b[step]
        step = step + 1
        return step, val

    # Initialise the accumulator with the array dtype, so that the carry
    # has the same type on every iteration of the loop.
    _, out = lax.while_loop(cond_fun, body_fun, (0, jnp.zeros((), a.dtype)))
    return out

We can now upgrade that to a matrix-vector product like so:

def matvec(A: Float[Array, "n n"], b: Float[Array, "n"], size: Int[Array, ""]):
    # Only A[:size, :size] and b[:size] should be filled; the rest is padding.
    def cond_fun(carry):
        step, _ = carry
        return step < size
    def body_fun(carry):
        step, val = carry
        val = val.at[step].set(dot(A[step], b, size))
        step = step + 1
        return step, val
    out = jnp.zeros_like(b)
    _, out = lax.while_loop(cond_fun, body_fun, (0, out))
    return out

and with this we can move up to linear solvers, some of which only need matrix-vector products (without e.g. doing an LU decomposition of the entire matrix, which may be large due to all the padding):

import functools as ft
import jax

def linear_solve(A: Float[Array, "n n"], b: Float[Array, "n"], size: Int[Array, ""]):
    # Only A[:size, :size] and b[:size] should be filled; the rest is padding.
    out, _ = jax.scipy.sparse.linalg.gmres(ft.partial(matvec, A, size=size), b)
    return out

and with that we can then build a nonlinear solver, callable as newton_solver.solve(y, x, n_mesh_nodes), which uses the above linear solver. (Exercise for the reader.)
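
To sketch the shape of that exercise (a generic damped Newton iteration on a toy residual; a dense solve is used for brevity rather than the padded linear solver, and every name here is illustrative):

```python
import jax
import jax.numpy as jnp

def newton_solve(F, x0, steps=20, damping=1.0):
    # A fixed number of damped Newton steps. In the real solver, the
    # linear solve would be the padded matrix-vector one, and the fixed
    # step count would become a tolerance-based lax.while_loop.
    def body(x, _):
        J = jax.jacfwd(F)(x)
        dx = jnp.linalg.solve(J, -F(x))
        return x + damping * dx, None
    x, _ = jax.lax.scan(body, x0, None, length=steps)
    return x

# Toy residual: x^2 - 2 = 0, converging to sqrt(2).
x = newton_solve(lambda x: x**2 - 2.0, jnp.array([1.0]))
```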

Obviously threading all of that through is less than ideal. Ultimately this is a case in which we're bumping up against the limitations of JAX's model of computation. At some point there is a plan for dynamic shapes in JAX, but I don't think that's going to happen any time soon. I'm sorry I don't have something more positive to say!

patrick-kidger avatar Feb 28 '23 03:02 patrick-kidger

Hi @patrick-kidger, thanks again for the very helpful response! I see how it would be possible to make iterative mesh refinement work.

For now, I have decided to use a fixed-length mesh. In the updated implementation, the mesh points are updated based on a global strategy described in [1], Chapter 9.3. Such global strategies tend not to work as well as iterative refinement, and of course require the initial mesh to be sufficiently fine. Since my ODE function is not very expensive to evaluate, I can afford large initial meshes, allowing me to obtain very accurate solutions.

I have also made additional changes to address your earlier feedback:

  • Solving the sparse linear system: the linear system being solved is bordered almost block diagonal (BABD). I found JAX's sparse solvers to be very slow, so I implemented a simple solver for BABD matrices which is fast enough for my purposes.
  • solve_bvp is now jittable and vmappable.
  • I reduced the number of calls to the ODE function. Now fun is traced 7 times when compiling solve_bvp, which is still not great but is nonetheless a bit of an improvement.

Once I have a bit more free time, I would like to attempt your suggested approach for dynamically refining the mesh. It would also be interesting to integrate the solver within Diffrax, but maybe it is best to wait until you release the new framework for handling nonlinear solvers?

[1] U. Ascher, R. Mattheij and R. Russell "Numerical Solution of Boundary Value Problems for Ordinary Differential Equations", 1995.

RicardoDominguez avatar Mar 06 '23 21:03 RicardoDominguez

This is incredibly cool to see. I had a quick look over and your implementation looks good to me!

  • It's great that you managed to get JIT and vmap working.
  • Likewise, I'm glad to see the success of the BABD solver. Once the aforementioned linear/nonlinear package comes out then this would be a great match with that. (I've looked at your implementation and it should be compatible!)

In terms of reducing the number of times the ODE function is called:

  • In this function, vmap(fun) is called twice in quick succession. You might be able to package that up into a length-2 scan, so that it is only traced once. (This admittedly reduces readability.)
  • backtrack_cost (which calls fun) is traced twice here, once at initialisation and once again in the step. It might be possible to combine this into just one call, with some appropriate dummy initialisation of the variables. For something a little similar see the Newton implementation in Diffrax itself -- this uses a step counter to ensure that we make the minimum number of required steps for our diffsize and diffsize_prev variables to actually be meaningful, and just initialises them to dummy zeros outside the loop.

Regarding backpropagation through all of this: I can now make good on my earlier promise. Equinox now has a reverse-mode autodifferentiable while loop available as equinox.internal.while_loop(..., kind="checkpointed"). (The only real footgun is to make sure to use its buffers argument with anything that you update in-place (.at[].set()), but I don't think you're using that so you should be safe.) No idea if this is really the correct thing to do of course, as I said before there might be more elegant ways to differentiate a BVP solve.

If you manage to dynamically refine the mesh as per my previous message then I'd be very impressed; that looks pretty hellish to get working all the way through!

Integration into Diffrax: yup, I think waiting on the new linear/nonlinear solvers makes most sense. Once that happens I'd definitely be interested in considering that.

patrick-kidger avatar Mar 08 '23 22:03 patrick-kidger

(Regarding the nonlinear solver portion: for a prototype it'd probably be simplest to just use diffrax.NewtonNonlinearSolver, although that isn't necessarily well-adapted for this kind of nonlinear problem. At some point Soon™ I will have some more sophisticated handling for nonlinear solvers coming out, and we can then use that.)

Coming back to this -- we (finally!) released Optimistix, as an Equinox-based library for nonlinear optimisation. As that's such an important part of BVP solvers, perhaps it will be of interest to those here.

patrick-kidger avatar Oct 06 '23 17:10 patrick-kidger