probdiffeq
Issue with new impl.select approach
The newly introduced requirement (as part of the code refactorisation) of having to state both the implementation AND the ode_shape a priori can be slightly problematic.
If one uses an ode_solve inside a training loop in which the data is batched, impl.select has to be called locally within the "solve" method, because the last batch can obviously have a different size from the rest. Unsurprisingly, this produces a warning along the following lines...
_impl.py:103: UserWarning: An implementation has already been selected: 'isotropic'. warnings.warn(msg, stacklevel=1)
This isn't code-breaking, but it makes me wonder whether the API should be adjusted slightly, so that you select the type (e.g. dense, isotropic, etc.) a priori, but the ode_shape is dynamic?
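To make the batching scenario concrete: when the dataset size is not a multiple of the batch size, the last batch is smaller, so any per-shape setup has to run again for it. Below is a minimal sketch in plain Python/JAX, assuming 100 samples and a batch size of 32. `select_for_shape` is a hypothetical stand-in for per-shape setup such as selecting an implementation for a given ode_shape; it is not a ProbDiffEq function. The loop redoes the setup only when the shape actually changes, rather than on every call.

```python
import jax.numpy as jnp


def iterate_batches(data, batch_size):
    """Yield consecutive batches along the leading axis; the last may be smaller."""
    for start in range(0, data.shape[0], batch_size):
        yield data[start:start + batch_size]


def select_for_shape(shape):
    # Hypothetical stand-in for per-shape solver setup
    # (e.g. choosing a state-space implementation for this shape).
    return f"setup for shape {shape}"


data = jnp.zeros((100, 2))
batch_sizes = []
selections = []
current_shape = None
for batch in iterate_batches(data, batch_size=32):
    batch_sizes.append(batch.shape[0])
    if batch.shape != current_shape:
        # Redo the setup only if the shape differs from the previous batch.
        current_shape = batch.shape
        selections.append(select_for_shape(current_shape))

print(batch_sizes)      # [32, 32, 32, 4]
print(len(selections))  # 2
```

With 100 samples, three full batches are followed by a batch of 4, so the setup runs twice: once for shape (32, 2) and once for (4, 2).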
Thanks for the report!
Just so I understand correctly, what exactly do you find problematic? Is the behaviour off, or is it the warning that annoys you?
I don't believe the behaviour is off, but given the warning I presumed that there is an assumption being violated here. Furthermore, I don't think it is unreasonable / uncommon to think that the size could change due to batching.
Ah I see, thanks for clarifying!
I don't believe the behaviour is off, but given the warning I presumed that there is an assumption being violated here.
I understand. The warning might send the wrong signal. There is nothing wrong with changing the state-space model factorisation; the intention of the warning was to raise awareness of the fact that something like impl.new_variable() behaves differently before and after the second change. But I think that is more or less clear, so the warning should be removed. What do you think?
Furthermore, I don't think it is unreasonable / uncommon to think that the size could change due to batching.
Probably. However, as far as I understand, the code would recompile anyway as soon as the size of the differential equation changes, in which case selecting a different backend (or rather, a backend with a different shape) would not lead to any performance degradation. Any thoughts on this? :)
the intention of the warning was to raise awareness of the fact that something like impl.new_variable() behaves differently before and after the second change.
I think that is important to know.
the code would recompile anyway as soon as the size of the differential equation changes,
Hmm, that's not ideal. I have some further questions about that, but I don't want to clutter the GitHub comments with something specific to my use case, so I will drop you an email.
Looking forward to your email! :)
But in case anyone else runs into a similar problem:
the code would recompile anyway as soon as the size of the differential equation changes
that is something all JAX code has to deal with, not just ProbDiffEq:
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables
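A small sketch of the recompilation behaviour described in the linked page: `jax.jit` caches compiled code per input shape and dtype, so a Python side effect inside the function (which only runs while JAX traces it) counts how often it is traced. The vector field `-0.5 * y` is an arbitrary placeholder, not taken from the original report.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def vector_field(y):
    global trace_count
    # Python side effects execute only during tracing,
    # so this counts compilations, not calls.
    trace_count += 1
    return -0.5 * y


vector_field(jnp.ones((32, 2)))  # first shape: traced and compiled
vector_field(jnp.ones((32, 2)))  # same shape and dtype: cached, no retrace
vector_field(jnp.ones((4, 2)))   # new shape (smaller last batch): traced again

print(trace_count)  # 2
```

This is why re-selecting an implementation for a new shape adds little extra cost: the jitted code has to be recompiled for the new shape either way.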
There might be special workarounds that one can implement with clever padding, for instance, but these will depend on the specific problem.
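As an illustration of the padding idea, here is a generic JAX masking pattern (a sketch, not a ProbDiffEq feature): the smaller last batch is padded to the full batch size, and a 0/1 mask keeps the padded rows out of the loss, so the jitted function only ever sees one shape. `masked_mean_loss` and `pad_batch` are hypothetical helpers, and the squared-error loss is an arbitrary placeholder.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def masked_mean_loss(predictions, targets, mask):
    global trace_count
    trace_count += 1  # counts traces (compilations), not calls
    squared_error = jnp.sum((predictions - targets) ** 2, axis=-1)
    # Padded rows have mask == 0 and contribute nothing to the loss.
    return jnp.sum(squared_error * mask) / jnp.sum(mask)


def pad_batch(x, batch_size):
    """Zero-pad the leading axis to batch_size and return a 0/1 row mask."""
    n = x.shape[0]
    pad_width = ((0, batch_size - n),) + ((0, 0),) * (x.ndim - 1)
    padded = jnp.pad(x, pad_width)
    mask = (jnp.arange(batch_size) < n).astype(x.dtype)
    return padded, mask


batch_size = 32
data = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
targets = jnp.zeros_like(data)

losses = []
for start in range(0, data.shape[0], batch_size):
    x, mask = pad_batch(data[start:start + batch_size], batch_size)
    y, _ = pad_batch(targets[start:start + batch_size], batch_size)
    losses.append(masked_mean_loss(x, y, mask))

print(trace_count)  # 1: every call sees inputs of shape (32, 2)
```

Here the padding only touches the batch axis of the data. Whether something similar works for the state of an ODE solve depends on the problem, since padding the state itself changes the system being solved.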
This issue seems resolved, so I'll close it. If there is something left to discuss, please reopen.