diffrax
Example of solving a system of ODEs
Hi, I am completely new to diffrax and am trying to learn the basics (great work, by the way).
I am a bit confused about the syntax for solving systems of ODEs. The documentation on Terms mentions that you can solve a Hamiltonian system (coupled ODEs) by passing a 2-tuple of `diffrax.ODETerm`. Consider, as a minimal example, a simple falling object:

$$\dot{y}(t) = v(t), \qquad \dot{v}(t) = -g,$$

with $g$ the gravitational acceleration. What is the correct way of expressing this? I tried the following, most likely wrong, form:
```python
def f1(t, y, args):
    return v

def f2(t, v, args):
    return -g

term1 = ODETerm(f1)
term2 = ODETerm(f2)
solver = SemiImplicitEuler()
solution = diffeqsolve(terms=(term1, term2), solver=solver, t0=0., t1=1., dt0=0.02, y0=(0., 5.))
```
And I get:

`TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'function'`

That suggests I am not providing the input in the right form. Could you provide an example of how to implement this correctly?
Thanks!
I'm afraid your code can't be run standalone. If you can, please provide a snippet of code that I can copy-paste to understand what you're doing. You can get code formatting and syntax highlighting by formatting it as:
```python
def my_function(x):
    ...
```
One thing that does jump out at me from what you've written is that in `f1` you are returning `v`, which is not an argument to the function. That other quantity, whatever it is, may well be the problem. Probably what you're trying to express is:
```python
def dy_dt(t, v, args):
    return v

def dv_dt(t, y, args):
    return -g
```
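Why the first function takes `v` and the second takes `y`: assuming diffrax's `SemiImplicitEuler` implements the standard semi-implicit (symplectic) Euler scheme with `y0=(y_0, v_0)` and `terms=(term1, term2)`, each vector field is evaluated at the *other* component of the state, and the second update uses the freshly updated first component. With step size $\Delta t$, $f_1$ = `dy_dt` and $f_2$ = `dv_dt`:

$$
y_{n+1} = y_n + \Delta t \, f_1(t_n, v_n), \qquad v_{n+1} = v_n + \Delta t \, f_2(t_n, y_{n+1}),
$$

so for the falling object $f_1(t, v) = v$ and $f_2(t, y) = -g$.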
Does this help?
It does help, and yes, it solved the problem. I had not understood the correct syntax for defining the input functions. Thanks for the help!
I do have a follow-up question concerning the use of `saveat`. Again, a minimal problem:
```python
import diffrax
from diffrax import ODETerm, Euler, diffeqsolve

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Euler()
saveat = diffrax.SaveAt(steps=True)
solution = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
```
When printing the result, the array's size does not correspond to the number of time steps, and it contains inf values:
DeviceArray([0.9 , 0.81 , 0.729, ..., inf, inf, inf], dtype=float32, weak_type=True)
with shape `(4096,)`. Is this expected behaviour? I understood that choosing `steps=True` in `SaveAt()` corresponds to saving the solution at each time step.
Yep, this is expected. In JAX, all arrays must have a shape known at compile (JIT) time. But the number of steps may be dynamic (due to adaptive step size controllers), which isn't known until runtime.
As a result, the approach necessarily taken is to allocate an array whose size equals the maximum number of steps (`diffeqsolve(..., max_steps=...)`), populate only the first elements that are actually used, and fill the rest with infinities.
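To make the padding pattern concrete without diffrax, here is a minimal pure-JAX sketch of the same idea. It is an illustration, not diffrax's actual implementation; `euler_padded` and `MAX_STEPS` are hypothetical names. The number of steps is a traced runtime value, so the output shape is fixed by the static `MAX_STEPS`, and unused slots stay `inf`.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 8  # hypothetical static buffer size, playing the role of max_steps


@jax.jit
def euler_padded(y0, dt, n_steps):
    """Euler for dy/dt = -y, written into a fixed-size, inf-padded buffer.

    n_steps is traced at runtime, so it cannot set an array shape;
    only the Python constant MAX_STEPS does.
    """
    y0 = jnp.asarray(y0, dtype=jnp.float32)
    dt = jnp.asarray(dt, dtype=jnp.float32)
    ys = jnp.full((MAX_STEPS,), jnp.inf, dtype=jnp.float32)

    def body(i, carry):
        y, ys = carry
        active = i < n_steps
        # Take an Euler step while active; freeze the state afterwards.
        y_next = jnp.where(active, y - dt * y, y)
        # Record the new state, or leave the slot as inf padding.
        ys = ys.at[i].set(jnp.where(active, y_next, jnp.inf))
        return y_next, ys

    _, ys = jax.lax.fori_loop(0, MAX_STEPS, body, (y0, ys))
    return ys


ys = euler_padded(1.0, 0.1, 3)
print(ys)  # first 3 entries are about 0.9, 0.81, 0.729; the other 5 are inf
```

Changing `n_steps` (up to `MAX_STEPS`) reuses the same compiled function, because the output shape never changes.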
Note that if you are using a fixed step size, you can choose to bake the number of steps in at compile time via `diffeqsolve(..., stepsize_controller=ConstantStepSize(compile_steps=True))`. This will produce a solution without the extra padding. Of course, it (a) only works for constant step sizes, and (b) means that JAX recompiles your code whenever the number of steps you need to take changes (e.g. if you change the value of `dt0`).
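Trade-off (b) is ordinary JAX behaviour for compile-time constants. As a pure-JAX analogy (not diffrax's implementation of `compile_steps`), the sketch below makes the step count a static argument of a hypothetical `euler_fixed` function: the output has exactly `n_steps` entries with no padding, but each new value of `n_steps` triggers a fresh trace and compilation. `trace_count` is a hypothetical counter added only to make the retracing visible.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented each time JAX traces the function


@partial(jax.jit, static_argnums=1)
def euler_fixed(y0, n_steps):
    """Euler for dy/dt = -y with dt = 0.1 and a compile-time number of steps."""
    global trace_count
    trace_count += 1  # Python side effect: only runs while tracing
    dt = 0.1

    def step(y, _):
        y = y - dt * y
        return y, y

    y0 = jnp.asarray(y0, dtype=jnp.float32)
    _, ys = jax.lax.scan(step, y0, None, length=n_steps)
    return ys


a = euler_fixed(1.0, 30)  # exactly 30 entries: no padding
b = euler_fixed(2.0, 30)  # same static n_steps: reuses the compiled code
c = euler_fixed(1.0, 20)  # new n_steps: triggers a retrace and recompile
print(a.shape, c.shape, trace_count)  # (30,) (20,) 2
```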
Maybe worth mentioning: `solution.stats["num_accepted_steps"]` is the number of steps the solver actually took (and accepted).
This can be useful when, for example, computing a solution to a system with a known steady-state solution (like in this example in the docs) and you want to trim the `inf` values from the solution's saved steps.
Please correct me if I'm wrong 🙂