
Example for solving forced ODE

Open fhchl opened this issue 2 years ago • 7 comments

Very cool work!

It would be great to have an example on how to solve an ODE that is directly forced by some signal x(t), e.g. a forced mass-spring-damper

m y'' + r y' + k y = x(t).

Do I understand correctly, that the controlled ODEs are "forced" by the derivative of x(t)?

fhchl avatar Feb 08 '22 23:02 fhchl

Right, so "controlled differential equations" are a specific notion, and are written like dy(t) = f(y(t)) dx(t); in some sense they are indeed forced by the derivative of x. This is a concept coming out of rough path theory, and is particularly interesting for the case where x(t) = [t, w(t)] is a time-augmented Brownian motion. Indeed you've probably seen stuff like dy(t) = μ(t, y(t)) dt + σ(t, y(t)) dw(t), i.e. an SDE, written down before.


If you want it to be directly forced as you describe, then the answer is to encode this time dependence directly into the vector field:

def vector_field(t, y, args):
    signal = x(t)
    ...
term = diffrax.ODETerm(vector_field)
diffrax.diffeqsolve(term, ...)

If you think this would be valuable and would be happy to add it, then I'd be very happy to accept a PR adding a short example on this. (e.g. using the Stiff ODE example as a starting point.) I definitely recognise that this kind of direct forcing appears more frequently in e.g. the engineering literature.

patrick-kidger avatar Feb 08 '22 23:02 patrick-kidger

(If you're curious to know more about CDEs then I'd recommend Chapter 3 of the recently-released On Neural Differential Equations for an introduction. There are other references available too; I just happen to have written that one recently ;) )

patrick-kidger avatar Feb 08 '22 23:02 patrick-kidger

I'll definitely have a look at the thesis, thanks!

If you want it to be directly forced as you describe, then the answer is to encode this time dependence directly into the vector field

The most straightforward way would be to include the forcing term in vector_field - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into vector_field.

Do I understand it correctly, that this would also be an option for the CDE, in case the control is differentiable? Something along

def vector_field(t, y, args):
    dx = control.derivative(t)
    ...
term = diffrax.ODETerm(vector_field)
diffrax.diffeqsolve(term, ...)

Is it correct to assume then, that the ControlTerm abstraction is especially useful for the SDEs?

fhchl avatar Feb 09 '22 16:02 fhchl

The most straightforward way would be to include the forcing term in vector_field - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into vector_field.

Yep! So you end up with

def vector_field(t, y, args):
    x = control.evaluate(t)
    ...

Do I understand it correctly, that this would also be an option for the CDE, in case the control is differentiable? Something along

Yep. In fact this is already built into Diffrax and wouldn't need to be (re)implemented manually. See ControlTerm.to_ode.

Is it correct to assume then, that the ControlTerm abstraction is especially useful for the SDEs?

I think the two main uses are (a) SDEs and (b) CDEs done via dx rather than x.

(FWIW I don't know of a strong reason to prefer doing CDEs one way over the other. I know of a short list of advantages and disadvantages, but none of them are that big of a deal.)

patrick-kidger avatar Feb 09 '22 17:02 patrick-kidger

Got it!

I have been using Jax so far to fit physical ODE models. I could cook up a little example for that use case with the approach discussed above.

fhchl avatar Feb 10 '22 14:02 fhchl

The most straightforward way would be to include the forcing term in vector_field - if one had a functional representation of it. Let's say that the forcing signal was measured so it is given as some collection of samples. One could use some of the interpolation schemes to get the functional representation and then include that into vector_field.

Yep! So you end up with

def vector_field(t, y, args):
    x = control.evaluate(t)
    ...

Great work with the library!

I have a use-case where I have an ODE w/ a (static) parameterized forcing function. I wanted to double check a few things

  • I think I can pass the parameters to the forcing function in through args and have an eval = forcing_function(args, t) call. That seems perfectly reasonable given the above discussion.
  • It also seems safe to assume that we can calculate $\partial\,\mathrm{vector\_field} / \partial\,\mathrm{args}$ because args is explicitly in the function call. Is that the case? Is the vector_field differentiable wrt args?
  • If so, then we should be able to differentiate through the diffeqsolve wrt args, correct?

Thanks in advance!

joglekara avatar May 03 '22 22:05 joglekara

Yep, that's exactly right.

patrick-kidger avatar May 04 '22 08:05 patrick-kidger