Hi,
I can't run the demo presented in the getting started section: https://docs.kidger.site/diffrax/#quick-example
I copied the code verbatim from the site and got the error below.
I installed diffrax through conda and checked that the versions are up to date:
Python - 3.11.8
diffrax - 0.5.1
jax - 0.4.27
equinox - 0.11.4
so I believe I meet the listed requirements.
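To collect these numbers in one go, a small script like the one below can help. It reads installed package metadata through the standard-library `importlib.metadata`, so it works for both conda and pip installs that ship package metadata. It also reports `jaxlib`, which is installed separately from `jax` and is worth including in bug reports. The function name `report_versions` and the package list are my own choices, not something the diffrax docs prescribe.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("diffrax", "jax", "jaxlib", "equinox")):
    """Return {name: version string} for Python and each package (None if not installed)."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    # Print in the same "name - version" style used in this post.
    for name, ver in report_versions().items():
        print(f"{name} - {ver if ver is not None else 'not installed'}")
```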
There is a JAX issue on this topic, so I was wondering if there is a way around it.
What am I missing here?
Thanks.
Traceback (most recent call last):
  File "/home/soeren-nagel/Projects/co-evo-social-dynamics/demos/test.py", line 10, in <module>
    solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/equinox/_jit.py", line 200, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/diffrax/_integrate.py", line 993, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
                             ^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/diffrax/_adjoint.py", line 292, in loop
    final_state = self._loop(
                  ^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/diffrax/_integrate.py", line 509, in loop
    _, traced_jump, traced_result = eqx.filter_eval_shape(body_fun_aux, init_state)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/diffrax/_integrate.py", line 305, in body_fun_aux
    (y, y_error, dense_info, solver_state, solver_result) = solver.step(
                                                            ^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/diffrax/_solver/runge_kutta.py", line 1149, in step
    final_val = eqxi.while_loop(
                ^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/equinox/internal/_loop/loop.py", line 107, in while_loop
    return checkpointed_while_loop(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py", line 249, in checkpointed_while_loop
    final_val = _checkpointed_while_loop(
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py", line 270, in _checkpointed_while_loop
    return while_loop(cond_fun, _body_fun, init_val)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/soeren-nagel/miniconda3/envs/social_space/lib/python3.11/site-packages/equinox/internal/_loop/checkpointed.py", line 268, in <lambda>
    _body_fun = lambda x: body_fun(x) # hashable wrapper; JAX issue #13554
                          ^^^^^^^^^^^
ValueError: safe_map() argument 2 is shorter than argument 1
Switch to a different version of JAX. You're encountering a known bug in JAX version 0.4.27. :)
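If you want a script to fail fast on the affected release instead of deep inside `diffeqsolve`, a small startup check like the one below is one option. It is only a sketch: the function name `check_jax_version` is hypothetical, the set of bad versions contains only 0.4.27 because that is the one named here, and this thread does not say which version to switch to. With pip, the same exclusion can be written as the PEP 440 requirement `jax!=0.4.27`; jaxlib may need to change alongside it.

```python
from importlib.metadata import PackageNotFoundError, version

# JAX releases with the known bug; only 0.4.27 is confirmed in this thread.
KNOWN_BAD_JAX = {"0.4.27"}


def check_jax_version(installed=None):
    """Return the installed JAX version, or raise RuntimeError if it is a known-bad release.

    Pass `installed` explicitly to check a given version string; otherwise it is
    read from package metadata. Returns None if JAX is not installed.
    """
    if installed is None:
        try:
            installed = version("jax")
        except PackageNotFoundError:
            return None
    if installed in KNOWN_BAD_JAX:
        raise RuntimeError(
            f"jax {installed} has a known bug (seen as 'safe_map() argument 2 is "
            "shorter than argument 1' in diffeqsolve); install a different JAX "
            "version, e.g. with the pip requirement jax!=0.4.27"
        )
    return installed
```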
Thanks for the quick help. Reverting to a different version fixed everything.
This looks like an amazing project. I can't wait to play around with it a bit more.