Support Lineax operators for the vector field of ControlTerm
By providing support for lineax.AbstractLinearOperators in the vector field of a ControlTerm, it may be possible to reduce the need for WeaklyDiagonalControlTerm and/or any other such specialised control terms.
The following gives an MWE:
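```python
import diffrax
import jax.numpy as jnp
import lineax as lx

class ControlTerm(diffrax.ControlTerm):
    def prod(self, vf, control):
        if isinstance(vf, lx.AbstractLinearOperator):
            # Apply the operator directly. A bare tree_map would descend into the
            # operator's own fields instead of treating it as a single leaf.
            return vf.mv(control)
        return super().prod(vf, control)  # otherwise, the usual array product

# These two are now equivalent; `...` stands for the control (e.g. a Brownian motion).
ControlTerm(lambda t, y, args: lx.DiagonalLinearOperator(jnp.array([1.0, 2.0, 3.0])), ...)
diffrax.WeaklyDiagonalControlTerm(lambda t, y, args: jnp.array([1.0, 2.0, 3.0]), ...)
```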
Not sure if this is something you want to support, but it occurred to me that the operator tags might also be useful for some of the diffrax solvers?
Yup, with the release of Lineax this is something I've been considering! It also dovetails nicely with #364, in that we may also wish to introduce a type parameter for the return type of the vector field.
(One very nitty concern I do have is that, mathematically speaking, we tend to interpret f(y) dx as a linear function dx -> f( . ) dx (returning a nonlinear function), rather than as a nonlinear function y -> f(y) (returning a linear operator). But that's probably not a super important distinction, to be honest.)
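Spelled out, both readings are curried forms of the same map (y, dx) ↦ f(y) dx; they differ only in which argument is supplied first:

$$
\underbrace{\mathrm{d}x \mapsto \big(y \mapsto f(y)\,\mathrm{d}x\big)}_{\text{linear in } \mathrm{d}x\text{, returns a nonlinear function}}
\qquad\text{vs.}\qquad
\underbrace{y \mapsto \big(\mathrm{d}x \mapsto f(y)\,\mathrm{d}x\big)}_{\text{nonlinear in } y\text{, returns a linear operator}}
$$

A ControlTerm whose vector field returns a Lineax operator implements the second reading: the vector field's output is the inner linear map, and prod applies it to the control increment.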
I'd want to be sure that this works correctly with:
- solvers that evaluate and store the result of .vf(...) directly;
- BacksolveAdjoint;
- returning "complicated" linear operators like lx.JacobianLinearOperator.
In principle all of those should be solvable (Diffrax allows the result of the vector field to be arbitrary); I just think we'd want to test them explicitly.
I'd be happy to take a pull request on this!
I have a branch started for this, but I want to know the scope of the change you are looking for here. Having spent (a little) time thinking about it, I see a couple of options, in increasing order of impact on the package:
- I can add support for linear operators to ControlTerm, which would really reduce the complexity of https://github.com/patrick-kidger/diffrax/pull/402 without changing much or adding many LoC (and would introduce no breaking changes). This could be a weird edge case (you could use WeaklyDiagonalControlTerm, or just a ControlTerm with a Lineax operator), but maybe that's not a big deal.
- Control terms must return linear operators (or maybe allow both operators and arrays). This means WeaklyDiagonalControlTerm is deprecated and can eventually be removed from the package entirely (a breaking change), and everything is just a ControlTerm with the appropriate operator. This might require some work in term checking, but it's probably doable.
- Everything becomes a linear operator (every f(y) dx), so that all vector fields are expressed this way. This would include removing WeaklyDiagonalControlTerm and changing the existing terms (introducing breaking changes across the board). It seems like a very substantial refactor of the core of the Term design.
I'm sure there are more nuanced or totally different options, but I'm just looking to get a feel for the scope you want with this change so I don't do a lot of unnecessary work.
For context, the simplest approach to option 1 is shown in https://github.com/patrick-kidger/diffrax/pull/434 (it still needs more tests and docs, but the core idea of making a minimal change to allow Lineax operators in control terms is there).
I really like the look of #434! I think this is pretty much exactly what I had in mind.
I think if we were designing from scratch then we'd probably go with option 2*, but since that's not the reality we live in, I think maintaining backward compatibility is worthwhile.
That said, I think it might be worth marking WeaklyDiagonalControlTerm with a PendingDeprecationWarning, just to gently encourage people to use this new approach instead. (And perhaps also removing it from the generated docs?) Subclassing terms to create new matrix-vector interactions was always an advanced thing to do, so I like that this new approach simplifies and standardises that.
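A minimal sketch of the suggested warning, using a hypothetical stand-in class rather than Diffrax's actual WeaklyDiagonalControlTerm:

```python
import warnings


class WeaklyDiagonalControlTermSketch:
    """Hypothetical stand-in; not Diffrax's actual class."""

    def __init__(self, vector_field, control):
        warnings.warn(
            "WeaklyDiagonalControlTerm is pending deprecation; use ControlTerm "
            "with a vector field returning lineax.DiagonalLinearOperator instead.",
            PendingDeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller's line
        )
        self.vector_field = vector_field
        self.control = control
```

Python's default warning filters ignore PendingDeprecationWarning, which fits the "gently encourage" intent: it only surfaces when warnings are explicitly enabled (for example with python -W default).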
* Why not option 3? I think we'd still need the AbstractTerm abstraction as a wrapper around lx.AbstractLinearOperator, because we need a way to wrap up multiple terms into one, and a place to put controls. Not a strong feeling though.