
Fix assert error for type of `keep_step`

Open mstoelzle opened this issue 2 years ago • 4 comments

When I am running a normal integration such as

import diffrax
import jax.numpy as jnp

# ode_fn and x_init_bt are defined elsewhere in my code
ode_term = diffrax.ODETerm(ode_fn)
sol = diffrax.diffeqsolve(
    ode_term,
    diffrax.Euler(),
    0.0,  # initial time
    1.0,  # final time
    1e-4,  # time step
    x_init_bt.astype(jnp.float64)[0, :],  # initial state
    max_steps=20000,
)

I will get an error similar to

  File "/home/mstolzle/sources/learning-representations-from-first-principle-dynamics/src/tasks/fp_dynamics.py", line 251, in forward_fn
    sol = ode_solve_fn(
          ^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/equinox/_jit.py", line 107, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/equinox/_jit.py", line 103, in _call
    out = self._cached(dynamic, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 824, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
                             ^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/adjoint.py", line 286, in loop
    final_state = self._loop(
                  ^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 424, in loop
    filter_state = eqx.filter_eval_shape(body_fun, init_state)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/lrfpd/lib/python3.11/site-packages/diffrax/integrate.py", line 252, in body_fun
    assert jnp.result_type(keep_step) is jnp.dtype(bool)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When I print jnp.result_type(keep_step), I get bool instead of jnp.dtype(bool).

I would like to stress that this issue only appears for certain ode_fn. I haven't quite figured out yet which change/property of the ode_fn causes this error to occur.

Still, this backwards-compatible change should work for any case.

mstoelzle avatar Sep 13 '23 13:09 mstoelzle

What version of JAX and what version of NumPy are you using?

patrick-kidger avatar Sep 13 '23 15:09 patrick-kidger

> What version of JAX and what version of NumPy are you using?

I am using python 3.11, numpy 1.23.5, jax 0.4.14, equinox 0.10.11, diffrax 0.4.1

mstoelzle avatar Sep 13 '23 16:09 mstoelzle

Hmm. I'm not able to easily reproduce this with those versions. It should always be the case that jnp.result_type returns a numpy dtype. You say this only arises for certain ode_fn. Can you provide a MWE?

patrick-kidger avatar Sep 13 '23 16:09 patrick-kidger

Hi @mstoelzle and @patrick-kidger,

I ran into the same error: assert jnp.result_type(keep_step) == jnp.dtype(bool) passed, while assert jnp.result_type(keep_step) is jnp.dtype(bool) failed. It only occurred when I loaded my model via pickle, so I suspect that pickle.load pulled in some code that caused the two dtypes to no longer be identical objects.

However, when I initialized the model beforehand as I did during training, loading the checkpoint and subsequently using diffrax worked without any errors. I guess this might be because in this case no additional code needed to be loaded through the pickle.load call.

I hope this helps to resolve the issues you are having, @mstoelzle.

Best regards, Vincent

VincentStimper avatar Oct 27 '23 08:10 VincentStimper
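
The difference between the `==` and `is` checks discussed in this thread can be sketched as follows. This is a minimal illustration, not the code in `diffrax/integrate.py`: the helper `is_bool_dtype` and the stand-in value for `keep_step` are hypothetical names introduced here. The sketch only shows that a value-based comparison holds for a boolean array; it does not reproduce the pickle scenario in which the identity check was reported to fail.

```python
import jax.numpy as jnp
import numpy as np


def is_bool_dtype(x) -> bool:
    """Return True if x's result type is boolean, compared by value (==)."""
    # `==` compares dtype values, so it does not rely on both sides
    # being the very same Python object, unlike `is`.
    return bool(jnp.result_type(x) == jnp.dtype(bool))


# Hypothetical stand-in for the flag checked inside the solver loop.
keep_step = jnp.asarray(True)
result = jnp.result_type(keep_step)

print(type(result))                      # the class of the returned dtype object
print(is_bool_dtype(keep_step))          # value-based check on a boolean array
print(is_bool_dtype(jnp.asarray(1.0)))   # value-based check on a float array
```

Comparing by value keeps the intent of the assertion (the flag must be boolean) without assuming that every boolean dtype object in the process is the same instance.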