
Update JAX fixed-step solver templates to be compilable w.r.t. t_span and t_eval args

Open · DanPuzzuoli opened this issue 2 years ago · 0 comments

Summary

Make the fixed-step solvers compilable and differentiable with respect to the t_span and t_eval arguments.

This isn't an extremely urgent issue; however, it would still be very nice to round out the features of these solvers.

Details

Issue #122 outlines a bug that is ultimately due to the fact that the JAX solvers in dynamics cannot be compiled if t_eval is not None. The fix PR, #125, resolves this issue by updating jax_odeint and the diffrax solver wrapper so that they can be compiled with respect to t_eval.
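For contrast, a solver whose loop is already written with JAX structured control flow accepts t_eval as a traced argument. The sketch below uses JAX's own jax.experimental.ode.odeint directly (not the qiskit-dynamics jax_odeint wrapper, and not the actual change in #125) on the toy problem dy/dt = -y; the names rhs and solve_at are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def rhs(y, t):
    # odeint expects func(y, t, *args); here dy/dt = -y
    return -y


@jax.jit
def solve_at(t_eval):
    # Under jit, t_eval is an abstract tracer. odeint's adaptive stepping is
    # built from lax control flow, so it never needs t_eval's concrete values.
    y0 = jnp.array([1.0], dtype=jnp.float32)
    return odeint(rhs, y0, t_eval, rtol=1e-6, atol=1e-6)


t_eval = jnp.array([0.0, 0.5, 1.0], dtype=jnp.float32)
ys = solve_at(t_eval)  # shape (3, 1): the solution at each time in t_eval
```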

As described in #122, however, updating the fixed-step JAX solvers built into dynamics to be compilable with respect to both t_span and t_eval is non-trivial, because their loop structure (e.g., how many steps are taken) depends on the values of t_span and t_eval, and under jax.jit those values are abstract tracers rather than concrete numbers. As a result, #125 is only a partial fix: for the JAX fixed-step solvers, the problem is simply avoided rather than fundamentally fixed.
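A minimal sketch of this failure mode, assuming a hypothetical naive_fixed_step_solve (forward Euler, not the qiskit-dynamics template) whose Python loop length is computed from t_span. It runs eagerly, but tracing it with t_span as a jit argument fails at int(...), because a tracer has no concrete value to convert.

```python
import jax
import jax.numpy as jnp


def naive_fixed_step_solve(f, t_span, y0, max_dt):
    """Hypothetical naive solver: loop length is a Python int derived from t_span."""
    t0, t1 = t_span[0], t_span[1]
    # int() requires a concrete value; under jax.jit, t_span is a tracer.
    n_steps = int(jnp.ceil((t1 - t0) / max_dt))
    h = (t1 - t0) / n_steps
    y = y0
    for i in range(n_steps):
        y = y + h * f(t0 + i * h, y)  # forward Euler step
    return y


def f(t, y):
    return -y  # dy/dt = -y


t_span = jnp.array([0.0, 1.0], dtype=jnp.float32)
y0 = jnp.array([1.0], dtype=jnp.float32)

# Eager execution works because t_span holds concrete values (10 Euler steps).
y_eager = naive_fixed_step_solve(f, t_span, y0, 0.1)

# Compiling with t_span traced fails while tracing int(...).
jitted = jax.jit(naive_fixed_step_solve, static_argnums=(0, 3))
try:
    jitted(f, t_span, y0, 0.1)
    compiled_ok = True
except (jax.errors.ConcretizationTypeError, jax.errors.TracerIntegerConversionError):
    compiled_ok = False
```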

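To make the fixed-step solvers fully compilable and differentiable with respect to the t_span and t_eval arguments, the functions fixed_step_solver_template_jax and fixed_step_lmde_solver_parallel_template_jax need to be updated to use JAX structured control flow (e.g., jax.lax.while_loop or jax.lax.scan) in place of Python loops whose length depends on t_span and t_eval. Preserving differentiability with respect to other parameters may also require defining custom differentiation rules (VJP and JVP rules), since, for example, jax.lax.while_loop supports forward-mode but not reverse-mode differentiation.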
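A minimal sketch of the structured-control-flow approach, assuming a hypothetical single-interval solver fixed_step_solve (not the actual fixed_step_solver_template_jax) that splits t_span into the smallest number of equal RK4 steps no larger than max_dt. The step count stays a traced integer consumed by jax.lax.while_loop, so the function compiles with t_span traced and supports forward-mode differentiation. It does not handle t_eval or the parallel LMDE template.

```python
import jax
import jax.numpy as jnp


def rk4_step(f, t, y, h):
    """One classical Runge-Kutta 4 step for dy/dt = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def fixed_step_solve(f, t_span, y0, max_dt):
    """Hypothetical solver whose loop length is derived from the traced t_span."""
    t0, t1 = t_span[0], t_span[1]
    # Step count stays a traced int32 array: no Python int() or range().
    n_steps = jnp.ceil((t1 - t0) / max_dt).astype(jnp.int32)
    h = (t1 - t0) / n_steps

    def cond_fun(carry):
        i, _ = carry
        return i < n_steps

    def body_fun(carry):
        i, y = carry
        return i + 1, rk4_step(f, t0 + i * h, y, h)

    _, y_final = jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), y0))
    return y_final


def f(t, y):
    return -y  # dy/dt = -y, exact solution y0 * exp(-(t - t0))


solve = jax.jit(fixed_step_solve, static_argnums=(0, 3))
t_span = jnp.array([0.0, 1.0], dtype=jnp.float32)
y0 = jnp.array([1.0], dtype=jnp.float32)

y1 = solve(f, t_span, y0, 0.01)  # compiles with t_span as a traced argument

# Forward-mode derivative of y(t1) with respect to t1 (tangent on t_span[1]).
_, dy_dt1 = jax.jvp(
    lambda ts: fixed_step_solve(f, ts, y0, 0.01),
    (t_span,),
    (jnp.array([0.0, 1.0], dtype=jnp.float32),),
)
```

Calling jax.grad on the same function is expected to fail, because lax.while_loop does not support reverse-mode differentiation. That is one place where custom VJP rules could come in; another option, not proposed in the issue, is a jax.lax.scan over a static maximum step count with masking of the unused steps.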

DanPuzzuoli, Sep 01 '22 19:09