Fix assert error for type of `keep_step`
When I run a standard integration such as
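import diffrax
import jax.numpy as jnp

# ode_fn(t, y, args) and x_init_bt are defined elsewhere in my code;
# the float64 cast only takes effect when jax_enable_x64 is on.
ode_term = diffrax.ODETerm(ode_fn)
sol = diffrax.diffeqsolve(
    ode_term,
    diffrax.Euler(),
    0.0,  # initial time
    1.0,  # final time
    1e-4,  # time step
    x_init_bt.astype(jnp.float64)[0, :],  # initial state
    max_steps=20000,
)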
I get an error similar to:
File "/home/mstolzle/sources/learning-representations-from-first-principle-dynamics/src/tasks/fp_dynamics.py", line 251, in forward_fn
sol = ode_solve_fn(
^^^^^^^^^^^^^
File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/equinox/_jit.py", line 107, in __call__
return self._call(False, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/equinox/_jit.py", line 103, in _call
out = self._cached(dynamic, static)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 824, in diffeqsolve
final_state, aux_stats = adjoint.loop(
^^^^^^^^^^^^^
File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/adjoint.py", line 286, in loop
final_state = self._loop(
^^^^^^^^^^^
File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 424, in loop
filter_state = eqx.filter_eval_shape(body_fun, init_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 252, in body_fun
assert jnp.result_type(keep_step) is jnp.dtype(bool)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
When I print `jnp.result_type(keep_step)`, I get `bool` instead of `jnp.dtype(bool)`.
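Printed output alone can't settle this. Printing a NumPy dtype shows only its name, so `bool` is exactly what a genuine `np.dtype(bool)` prints. It also can't reveal whether two dtypes are the *same object*, which is what `is` checks. A small NumPy-only sketch, using `copy=True` as one way to obtain an equal dtype that is a separate object:

```python
import numpy as np

builtin = np.dtype(bool)
# copy=True asks NumPy for a new dtype object instead of the cached built-in one.
copied = np.dtype(bool, copy=True)

print(str(builtin), str(copied))  # both print as "bool"
print(repr(builtin))              # dtype('bool')
print(copied == builtin)          # equality: compares the described type
print(copied is builtin)          # identity: compares the objects themselves
```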
I would like to stress that this issue only appears for certain choices of ode_fn. I haven't yet figured out which change or property of ode_fn causes the error.
Still, this backwards-compatible change should work in every case.
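The diff itself isn't shown here. Assuming the change replaces the identity check with an equality check, a sketch of the resulting assertion could look like this. `assert_bool_dtype` is a hypothetical helper, not part of diffrax.

```python
import jax.numpy as jnp


def assert_bool_dtype(keep_step):
    # Hypothetical helper mirroring the diffrax assertion, but comparing
    # dtypes by equality (==) rather than identity (is).
    assert jnp.result_type(keep_step) == jnp.dtype(bool)


assert_bool_dtype(jnp.array(True))
assert_bool_dtype(jnp.array([True, False]))

try:
    assert_bool_dtype(jnp.array(1.0))
except AssertionError:
    print("non-boolean keep_step rejected")
```

Equality still rejects non-boolean inputs, so the check keeps its purpose while no longer depending on dtype objects being cached singletons.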
What version of JAX and what version of NumPy are you using?
What version of JAX and what version of NumPy are you using?
I am using Python 3.11, NumPy 1.23.5, JAX 0.4.14, Equinox 0.10.11, and Diffrax 0.4.1.
Hmm. I'm not able to easily reproduce this with those versions. It should always be the case that jnp.result_type returns a numpy dtype.
You say this only arises for certain choices of ode_fn. Can you provide an MWE?
Hi @mstoelzle and @patrick-kidger,
I had the same error. `assert jnp.result_type(keep_step) == jnp.dtype(bool)` passed, while `assert jnp.result_type(keep_step) is jnp.dtype(bool)` threw an error. It only occurred when I loaded my model via pickle, so I figured that the `pickle.load` call loaded some code that caused the two dtypes not to be identical objects.
However, when I first initialized the model as I did during training, loading the checkpoint and then running diffrax worked without errors. I guess this is because in that case no additional code needed to be loaded by the `pickle.load` call.
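Pickle is one plausible route to an equal but non-identical dtype. NumPy reconstructs a dtype on unpickling, and whether the result is the cached built-in object can vary with the NumPy version. A minimal sketch, assuming (the thread doesn't confirm this) that the pickled model carries dtype objects:

```python
import pickle

import numpy as np

original = np.dtype(bool)
# Round-trip the dtype through pickle, as would happen to any dtype
# stored inside a pickled checkpoint.
restored = pickle.loads(pickle.dumps(original))

print(restored == original)  # equality check, as in the passing assert
print(restored is original)  # identity check, as in the failing assert
```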
I hope this helps to resolve the issues you are having, @mstoelzle.
Best regards, Vincent