diffrax
Added JumpStepWrapper
Hi Patrick,
I factored the `jump_ts` and `step_ts` out of the `PIDController` into `JumpStepWrapper` (I'm not very set on this name, lmk if you have ideas). I also made it behave as we discussed in #483. In particular, the following three rules are maintained:
- We always have `t1 - t0 <= prev_dt` (this is checked via `eqx.error_if`), with strict inequality only if the step was clipped or if we hit the end of the integration interval (we do not explicitly check for that).
- If the step was accepted, then `next_dt` must be `>= prev_dt`.
- If the step was rejected, then `next_dt` must be `< t1 - t0`.
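To make the three rules concrete, here is a small plain-Python sketch that checks them for a single step. It assumes `t0`, `t1` are the endpoints of the step just attempted, `prev_dt` is the step size proposed for that step before any clipping, and `next_dt` is the size proposed for the following step. The function name `step_rule_violations` is hypothetical, and it returns the broken rules instead of raising an error inside JAX the way `eqx.error_if` does for the first rule.

```python
def step_rule_violations(
    t0: float, t1: float, prev_dt: float, next_dt: float, accepted: bool
) -> list[int]:
    """Return which of the three step-size rules (1, 2, 3) a single step breaks.

    Hypothetical helper for illustration; not part of diffrax.
    """
    violations = []
    # Rule 1: the step actually taken is never longer than the size proposed for it.
    # It is strictly shorter only when clipped or at the end of the interval.
    if t1 - t0 > prev_dt:
        violations.append(1)
    if accepted:
        # Rule 2: after an accepted step, the next step is at least prev_dt.
        if next_dt < prev_dt:
            violations.append(2)
    else:
        # Rule 3: after a rejected step, the retry is strictly shorter than the rejected step.
        if next_dt >= t1 - t0:
            violations.append(3)
    return violations
```

For example, `step_rule_violations(0.0, 0.5, 0.5, 0.25, True)` returns `[2]`: the step was accepted, but the next proposed step shrank below `prev_dt`.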
We achieve this in a very simple way here: https://github.com/patrick-kidger/diffrax/blob/78b122adf39b2f8d26a79d0ac239a2fb675653a1/diffrax/_step_size_controller/jump_step_wrapper.py#L119-L123
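The linked lines are not reproduced here. As a rough illustration only, the following plain-Python sketch shows one simple way a wrapper could maintain the three rules; it is not the PR's code. `clip_step`, `enforce_rules` and the `shrink` factor are hypothetical names, the inner controller's proposal `inner_dt` is taken as given, and the handling of the discontinuity at a jump time is ignored.

```python
import bisect
from typing import Sequence


def clip_step(t0: float, dt: float, jump_ts: Sequence[float], t_end: float) -> float:
    """Return the end of the next step, clipped to the next jump time or to t_end.

    `jump_ts` must be sorted. The unclipped `dt` is what would be stored as
    `prev_dt`, so t1 - t0 <= prev_dt holds (up to floating-point rounding).
    """
    t1 = min(t0 + dt, t_end)
    # First jump time strictly after t0.
    i = bisect.bisect_right(jump_ts, t0)
    if i < len(jump_ts):
        t1 = min(t1, jump_ts[i])
    return t1


def enforce_rules(
    step_len: float, prev_dt: float, inner_dt: float, accepted: bool, shrink: float = 0.5
) -> float:
    """Adjust the inner controller's proposed next step size.

    `step_len` is t1 - t0 of the step just attempted (assumed > 0);
    `shrink` must lie in (0, 1).
    """
    if accepted:
        # Rule 2: never go below the unclipped size of the accepted step.
        return max(inner_dt, prev_dt)
    # Rule 3: the retry is strictly shorter than the rejected step.
    return min(inner_dt, shrink * step_len)
```

Keeping the unclipped `dt` as `prev_dt` is what makes the first rule hold: a step shortened to land on a jump time has `t1 - t0 < prev_dt`, and because the second rule compares against `prev_dt`, the step after it can return to the unclipped size instead of staying small. `step_ts` could be clipped the same way.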
The next step is to add a parameter `JumpStepWrapper.revisit_rejected_steps` which does what you expect. That will appear in a future commit in this same PR.