qiskit-dynamics
Update JAX fixed-step solver templates to be compilable w.r.t. t_span and t_eval args
Summary
Make the fixed step solvers compilable/differentiable with respect to the `t_span` and `t_eval` arguments.
This is not urgent, but it would round out the features of these solvers.
Details
Issue #122 outlines a bug that is ultimately due to the fact that the JAX solvers in dynamics cannot be compiled if `t_eval` is not `None`. The fix PR, #125, resolves this issue by updating `jax_odeint` and the diffrax solver wrapper so that they can be compiled with respect to `t_eval`.
As described in #122, however, updating the fixed step JAX solvers built into dynamics to be compilable with respect to both `t_span` and `t_eval` is non-trivial, because their looping structure depends on the values of `t_span` and `t_eval`. As a result, the fix in #125 is only partial: for the JAX fixed step solvers, the problem is simply avoided rather than fundamentally fixed.
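To illustrate the kind of failure involved, here is a minimal sketch using a toy Euler integrator. It is not the dynamics solver code: the function `euler_python_loop`, its `max_dt` argument, and the right-hand side dy/dt = -y are all assumptions made up for this example. The point is only that a step count computed from the values of `t_span` works eagerly but cannot be computed once `t_span` is a traced argument.

```python
import jax
import jax.numpy as jnp
import numpy as np


def euler_python_loop(t_span, y0, max_dt):
    """Toy fixed-step Euler solve of dy/dt = -y (hypothetical, for illustration)."""
    # The step count is derived from the *values* in t_span, so t_span
    # must be a concrete array at this point.
    n_steps = int(np.ceil((t_span[1] - t_span[0]) / max_dt))
    h = (t_span[1] - t_span[0]) / n_steps
    y = y0
    for _ in range(n_steps):  # Python loop whose length depends on t_span
        y = y + h * (-y)
    return y


t_span = jnp.array([0.0, 1.0])
y0 = jnp.array([1.0])

# Eager execution works: t_span holds concrete values.
y_eager = euler_python_loop(t_span, y0, max_dt=0.01)

# Under jit, t_span is an abstract tracer, so the step count cannot be
# converted to a Python int and tracing raises an error.
jitted = jax.jit(euler_python_loop, static_argnames=("max_dt",))
try:
    jitted(t_span, y0, max_dt=0.01)
    jit_failed = False
except Exception:
    jit_failed = True
```

Marking `t_span` itself as static would let this compile, but it would then trigger recompilation for every new value and rule out differentiation with respect to `t_span`.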
To make the fixed step solvers fully compilable/differentiable with respect to the `t_span` and `t_eval` arguments, the functions `fixed_step_solver_template_jax` and `fixed_step_lmde_solver_parallel_template_jax` need to be updated to use more advanced JAX control flow. Preserving differentiability with respect to other parameters may also require defining custom differentiation rules (`vjp` and `jvp` rules).
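One possible direction, sketched below on a toy problem, is to fix the number of steps statically and let only the step size depend on `t_span`, so the loop can be expressed with `jax.lax.scan`. This is an assumption-laden illustration, not a proposal for the actual templates: `euler_scan`, `rhs`, and the static `n_steps` argument are hypothetical names, the integrator is plain Euler, and `t_eval` handling is omitted entirely.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def rhs(t, y):
    """Hypothetical right-hand side: dy/dt = -y."""
    return -y


@partial(jax.jit, static_argnames=("n_steps",))
def euler_scan(t_span, y0, n_steps):
    """Toy Euler solve with a static step count and a traced step size."""
    t0, t1 = t_span[0], t_span[1]
    # h is a traced value, so the result stays differentiable w.r.t. t_span.
    h = (t1 - t0) / n_steps

    def step(carry, _):
        t, y = carry
        return (t + h, y + h * rhs(t, y)), None

    # The loop length is n_steps (static); it does not depend on t_span values.
    (_, y_final), _ = lax.scan(step, (t0, y0), jnp.arange(n_steps))
    return y_final


t_span = jnp.array([0.0, 1.0])
y0 = jnp.array([1.0])

# Compiles with t_span as a traced argument.
y_final = euler_scan(t_span, y0, n_steps=1000)

# Reverse-mode gradient of the final state with respect to t_span.
grad_t_span = jax.grad(lambda ts: euler_scan(ts, y0, n_steps=1000)[0])(t_span)
```

The trade-off in this sketch is that the step count becomes part of the compiled program, so a bound like a maximum step size can no longer determine it from traced values. A loop with a data-dependent trip count would need `jax.lax.while_loop`, which JAX does not support in reverse-mode differentiation; that limitation is one reason custom differentiation rules might be needed.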