jaxquantum
Automatically jit certain functions in jaxquantum
Previously, we were worried about vmapping jitted functions, but that does not seem to be an issue, as long as there is a final jit around the outermost call.
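For reference, a minimal sketch of the pattern in question: vmapping an already-jitted function and wrapping the result in a final jit. The function `expectation` is a hypothetical stand-in, not something from jaxquantum.

```python
import jax
import jax.numpy as jnp


@jax.jit
def expectation(theta):
    # Hypothetical stand-in for a jitted library routine.
    return jnp.cos(theta) ** 2


# vmap over the jitted function, then apply one final jit on the outside.
batched = jax.jit(jax.vmap(expectation))

thetas = jnp.linspace(0.0, jnp.pi, 5)
values = batched(thetas)  # one result per entry of thetas
```

The inner `jax.jit` is traced through when the outer one compiles, so the nesting itself should not cause problems.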
In solvers.py we can jit solve() with the ODE function f marked as a static argument. I think that should cover both sesolve and mesolve.
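A rough sketch of what that could look like. The signature `solve(f, y0, ts)` and the fixed-step RK4 body are assumptions for illustration only; the actual solve() in solvers.py may take different arguments and use a different integrator.

```python
from functools import partial

import jax
import jax.numpy as jnp


# f is a Python callable, so it has to be static (hashable, not traced).
@partial(jax.jit, static_argnames=("f",))
def solve(f, y0, ts):
    """Hypothetical fixed-step RK4 integrator for dy/dt = f(t, y)."""

    def step(y, t_pair):
        t0, t1 = t_pair
        h = t1 - t0
        k1 = f(t0, y)
        k2 = f(t0 + h / 2, y + h / 2 * k1)
        k3 = f(t0 + h / 2, y + h / 2 * k2)
        k4 = f(t1, y + h * k3)
        y_next = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return y_next, y_next

    # Scan over consecutive time pairs; ys holds the state after each step.
    _, ys = jax.lax.scan(step, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])


def decay(t, y):
    # Simple test ODE: dy/dt = -y
    return -y


ts = jnp.linspace(0.0, 1.0, 101)
ys = solve(decay, jnp.array([1.0]), ts)
```

One caveat with a static f: jit caches on the identity of the function object, so passing a freshly created lambda or closure on every call triggers a recompile each time. If sesolve and mesolve build f internally, it may be worth defining it once or jitting at that level instead.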
What else?