diffrax
safe_zip() error when passing arrays into args
Hi @patrick-kidger, thank you for all your work. I have an ODE in JAX, implemented with scan, that I am trying to port over to Diffrax as a way to learn the library.
I have a vector field in a single variable. I define the field (or the Equinox module's call) with the signature ([self], t, y, args), and it returns new_y. The setup is:
x1 = Arr
x2 = Arr
assert len(x1) == len(x2)
t0 = 0
t1 = len(x1) - 1
y0: float
args: tuple[Arr, Arr, float, float]
I am getting this error from runge_kutta.py:
ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
....
File "/Users/alessandrolatif/watz/venv/lib/python3.11/site-packages/jax/_src/numpy/util.py", line 398, in _broadcast_to
for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: safe_zip() argument 2 is shorter than argument 1
I'm not sure what I'm missing to get this working, and would be grateful for your guidance.
Can you provide an MWE (minimal working example) demonstrating the problem?
a1, a2 and a3 are arguments. Some of them are non-static arrays required by the ODE, which has a single variable. I saw a reference in the docs to arguments needing to be static. If that is a requirement, why? And how do I get around it, or rethink my problem?
import jax.numpy as jnp
import diffrax as drx

def f(t, y, args):
    x1, x2, x3 = args
    new_y = y * x1 + x2 * jnp.sin(x3 * t)
    return new_y

a1 = jnp.array([1, 2, 3])
a2 = jnp.array([4, 5, 6])
a3 = 2

drx.diffeqsolve(
    drx.ODETerm(f),
    drx.Tsit5(),
    y0=3,
    t0=0,
    t1=len(a2) - 1,
    dt0=1,
    args=(a1, a2, a3),
)
The error message could be better, but the problem is that whilst your y0 is a scalar, the return value from f is an array of shape (3,). These must have the same shape as each other.
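One way to catch this mismatch before calling diffeqsolve is to trace the vector field abstractly with jax.eval_shape and compare its output shape with the shape of y0. The sketch below copies f and the arguments from the MWE (with float values assumed for illustration); output_shape and y0_fixed are hypothetical names, and y0_fixed shows one possible fix that gives y0 shape (3,).

```python
import jax
import jax.numpy as jnp

# Vector field and arguments copied from the MWE (floats assumed here).
def f(t, y, args):
    x1, x2, x3 = args
    return y * x1 + x2 * jnp.sin(x3 * t)

a1 = jnp.array([1.0, 2.0, 3.0])
a2 = jnp.array([4.0, 5.0, 6.0])
a3 = 2.0
args = (a1, a2, a3)

def output_shape(y0):
    # Trace f abstractly (no actual computation) and read off its output shape.
    return jax.eval_shape(f, 0.0, y0, args).shape

y0_scalar = jnp.array(3.0)      # as in the MWE: shape ()
y0_fixed = jnp.full((3,), 3.0)  # hypothetical fix: one state per entry of a1/a2

print(y0_scalar.shape, output_shape(y0_scalar))  # () (3,)   -> mismatch
print(y0_fixed.shape, output_shape(y0_fixed))    # (3,) (3,) -> consistent
```

Whether giving y0 shape (3,) is the right fix depends on the model: it turns the problem into three independent scalar ODEs, which is not the same as one ODE driven by time-varying inputs.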
OK, I understand.
So I am trying to get the values of the variables a1 and a2 at time t:
new_y = y * x1[t]
These values, indexed at t, are required by the equation. The problem is that I cannot index like this, as t is a float.
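Because the solver evaluates the vector field at arbitrary float times t, sampled inputs need a rule for turning a float t into a value. A minimal sketch, assuming sample k was taken at t = k seconds (the hypothetical ts array below), using plain jnp.interp for linear interpolation; x_hold is a hypothetical zero-order-hold helper that reproduces integer indexing at whole seconds. This is one common approach rather than something proposed in the thread; Diffrax also provides interpolation classes such as diffrax.LinearInterpolation for the same purpose.

```python
import jax.numpy as jnp

# Hypothetical 1 Hz sampling: sample k was taken at t = k seconds.
ts = jnp.arange(3.0)  # [0., 1., 2.]
a1 = jnp.array([1.0, 2.0, 3.0])
a2 = jnp.array([4.0, 5.0, 6.0])
a3 = 2.0

def x_hold(t, xs):
    # Zero-order hold: use the most recent sample, xs[floor(t)],
    # clipped so that t outside [0, len(xs) - 1] stays in bounds.
    idx = jnp.clip(jnp.floor(t).astype(jnp.int32), 0, xs.shape[0] - 1)
    return xs[idx]

def f(t, y, args):
    ts, x1, x2, x3 = args
    # Linearly interpolate the samples at the (float) solver time t,
    # so the output is a scalar, matching a scalar y0.
    x1_t = jnp.interp(t, ts, x1)
    x2_t = jnp.interp(t, ts, x2)
    return y * x1_t + x2_t * jnp.sin(x3 * t)

args = (ts, a1, a2, a3)
print(jnp.interp(0.5, ts, a1))       # 1.5
print(x_hold(1.7, a1))               # 2.0
print(f(0.0, jnp.array(3.0), args))  # 3.0
```

Which rule fits depends on how the sensor signal should behave between samples. Linear interpolation keeps the vector field continuous in t, which adaptive solvers generally handle better than the jumps a zero-order hold introduces at each sample.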
Why do I feel like what I am trying to do is incorrect, or not within the scope of Diffrax?
I'm working with discrete sensor data sampled at 1 Hz and am trying to implement the equation below (which I currently have implemented with JAX scan).
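For comparison, a discrete-time scan of the kind described here might look like the sketch below. The actual equation is not shown in the thread, so this hypothetically reuses the right-hand side of the MWE and advances it with forward Euler at dt = 1 s, indexing the samples by the integer step k. Inside jax.lax.scan that indexing works because k is an integer; inside diffeqsolve the solver picks float times between samples, which is why x1[t] no longer applies.

```python
import jax
import jax.numpy as jnp

a1 = jnp.array([1.0, 2.0, 3.0])
a2 = jnp.array([4.0, 5.0, 6.0])
a3 = 2.0
dt = 1.0  # 1 Hz sampling (assumed)

def step(y, inputs):
    # One forward-Euler step using the samples taken at integer step k.
    k, x1_k, x2_k = inputs
    t = k * dt
    y_next = y + dt * (y * x1_k + x2_k * jnp.sin(a3 * t))
    return y_next, y_next

# Steps 0 -> 1 and 1 -> 2; the last sample is only an endpoint.
ks = jnp.arange(a1.shape[0] - 1)
y_final, ys = jax.lax.scan(step, jnp.array(3.0), (ks, a1[:-1], a2[:-1]))
print(ys)  # ys[0] == 6.0, ys[1] == 18 + 5*sin(2) ≈ 22.55
```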
Use saveat to specify where you want to get the output.