diffrax

safe_zip() error when passing arrays into args

Open alexlatif opened this issue 1 year ago • 5 comments

Hi @patrick-kidger, thank you for all your work. I have an ODE implemented in JAX using scan, which I am trying to port to Diffrax as a way to learn the library.

I have a vector field in one variable. I define it (as a plain function, or as an Equinox module's __call__) with the signature ([self], t, y, args), and it returns new_y.

x1 = Arr
x2 = Arr
assert len(x1) == len(x2)

t0 = 0
t1 = len(x1) - 1

y0: float
args: tuple[Arr, Arr, float, float]

I am getting this error from runge_kutta.py:

ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
....
  File "/Users/alessandrolatif/watz/venv/lib/python3.11/site-packages/jax/_src/numpy/util.py", line 398, in _broadcast_to
    for arr_d, shape_d in safe_zip(arr_shape, shape_tail))
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: safe_zip() argument 2 is shorter than argument 1

I am not sure what I am misunderstanding, and would be grateful for your guidance.

alexlatif, Jun 24 '23 22:06

Can you provide an MWE (minimal working example) demonstrating the problem?

patrick-kidger, Jun 25 '23 01:06

a1, a2 and a3 are arguments. Some of them are non-static arrays that the one-variable ODE needs. I saw a reference in the docs to arguments needing to be static. If that is a requirement, why? And how do I get around it, or should I rethink my problem?

import jax.numpy as jnp
import diffrax as drx

def f(t, y, args):
    x1, x2, x3 = args
    # x1 and x2 are arrays of shape (3,), so new_y has shape (3,)
    new_y = y * x1 + x2 * jnp.sin(x3 * t)
    return new_y

a1 = jnp.array([1, 2, 3])
a2 = jnp.array([4, 5, 6])
a3 = 2

drx.diffeqsolve(
    drx.ODETerm(f),
    drx.Tsit5(),
    y0=3,  # scalar initial condition
    t0=0,
    t1=len(a2) - 1,
    dt0=1,
    args=(a1, a2, a3)
)

alexlatif, Jun 25 '23 02:06

The error message could be better, but the problem is that whilst your y0 is a scalar, the return value from f is an array of shape (3,). These must have the same shape as each other.

patrick-kidger, Jun 25 '23 03:06

OK, I understand.

So I am trying to get the values of the variables a1 and a2 at time t:

new_y = y * x1[t]

The equation needs these values indexed at t.

The problem is that I cannot index like this, because t is a float.

Why do I feel like what I am trying to do is incorrect, or out of scope for Diffrax?

I'm working with discrete sensor data sampled at 1 Hz, and am trying to implement the equation below (which I currently have implemented with jax scan).

[Screenshot of the equation (2023-06-25); not reproduced here]

alexlatif, Jun 25 '23 05:06

Use saveat to specify where you want to get the output.

patrick-kidger, Jun 25 '23 23:06