probdiffeq

Issue with new impl.select approach

adam-hartshorne opened this issue 1 year ago • 5 comments

The newly introduced requirement (part of the code refactorisation) to state both the implementation AND the ode_shape a priori can be slightly problematic.

If one uses an ode_solve inside a training loop in which data is batched, impl.select has to be called locally within the "solve" method, because the last batch can be a different size from the rest. Unsurprisingly, this produces a warning along the following lines:

_impl.py:103: UserWarning: An implementation has already been selected: 'isotropic'.
  warnings.warn(msg, stacklevel=1)

This isn't code-breaking, but it makes me wonder whether the API should be slightly adjusted so that one selects the type (i.e. dense, isotropic, etc.) a priori, while the ode_shape stays dynamic.

adam-hartshorne, Oct 29 '23 17:10

Thanks for the report!

Just so I understand correctly, what exactly do you find problematic? Is the behaviour off, or is it the warning that annoys you?

pnkraemer, Oct 30 '23 07:10

I don't believe the behaviour is off, but given the warning I presumed that there is an assumption being violated here. Furthermore, I don't think it is unreasonable / uncommon to think that the size could change due to batching.

adam-hartshorne, Oct 30 '23 10:10

Ah I see, thanks for clarifying!

> I don't believe the behaviour is off, but given the warning I presumed that there is an assumption being violated here.

I understand. The warning might send the wrong signal. There is nothing wrong with changing the state-space model factorisation; the intention of the warning was to raise awareness of the fact that something like impl.new_variable() behaves differently before and after the second change. But I think that is more or less clear, so the warning should be removed. What do you think?
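Until the warning is removed, a user who deliberately re-selects the implementation per batch could silence only this message with Python's standard `warnings` module. A minimal sketch: `select_implementation` and `solve_batch` are hypothetical stand-ins (in real code the wrapped call would be ProbDiffEq's `impl.select(...)`), and the filter matches the start of the message text from the report.

```python
import warnings


def select_implementation(name):
    """Hypothetical stand-in that emits the same UserWarning as in the report."""
    warnings.warn(
        f"An implementation has already been selected: '{name}'.",
        UserWarning,
        stacklevel=1,
    )


def solve_batch(batch):
    # catch_warnings() restores the global filters on exit, so the
    # "ignore" rule only applies inside this block.
    with warnings.catch_warnings():
        # Ignore only this specific message; other warnings still surface.
        warnings.filterwarnings(
            "ignore",
            message="An implementation has already been selected",
            category=UserWarning,
        )
        select_implementation("isotropic")
    # Placeholder for the actual solve on this batch.
    return len(batch)
```

Scoping the filter to a `catch_warnings()` block, rather than calling `filterwarnings` globally, keeps the message visible elsewhere in the program.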

> Furthermore, I don't think it is unreasonable / uncommon to think that the size could change due to batching.

Probably. However, as far as I understand, the code would recompile anyway as soon as the size of the differential equation changes, in which case selecting a different backend (or rather, a backend with a different shape) would not add any further performance cost. Any thoughts on this? :)
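The recompilation is generic JAX behaviour: `jax.jit` caches one compiled version per combination of input shapes and dtypes, so a smaller final batch triggers a new trace. A minimal sketch in plain JAX (not ProbDiffEq); `batch_loss` is a hypothetical stand-in for a loss that would call an ODE solve, and the trace count is observed through a Python side effect that only runs while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def batch_loss(y):
    global trace_count
    # Python side effect: executes only during tracing, not on cached calls.
    trace_count += 1
    return jnp.sum(y**2)


data = jnp.arange(10.0)
batch_size = 4
for start in range(0, data.shape[0], batch_size):
    # Batch shapes are (4,), (4,), (2,): the last batch is smaller.
    batch = data[start:start + batch_size]
    batch_loss(batch)

print(trace_count)  # 2: one trace for shape (4,), one for shape (2,)
```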

pnkraemer, Oct 30 '23 10:10

> the intention of the warning was to raise awareness of the fact that something like impl.new_variable() behaves differently before and after the second change.

I think that is important to know.

> the code would recompile anyway as soon as the size of the differential equation changes,

Hmm, that's not ideal. I have some further questions about that, but I don't want to clutter the GitHub comments with something specific to my use case, so I will drop you an email.

adam-hartshorne, Oct 30 '23 10:10

Looking forward to your email! :)

But in case anyone else runs into a similar problem:

> the code would recompile anyway as soon as the size of the differential equation changes

that is something all JAX code has to deal with, not just ProbDiffEq:

https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables

There might be workarounds, for instance with clever padding, but these will depend on the specific problem.
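One common padding pattern, sketched here under the assumption that the varying size is a leading batch axis (for example a batch of initial conditions processed under `jax.vmap`) rather than the ODE dimension itself: pad every batch to a fixed size and mask out the padded entries, so `jax.jit` sees a single input shape. `masked_loss` and `pad_batch` are hypothetical helpers, not part of ProbDiffEq.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def masked_loss(y, mask):
    global trace_count
    # Runs only during tracing; counts how often JAX compiles.
    trace_count += 1
    # Padded entries are excluded from the sum via the boolean mask.
    return jnp.sum(jnp.where(mask, y**2, 0.0))


def pad_batch(batch, batch_size):
    """Pad a 1D batch with zeros to `batch_size` and return a validity mask."""
    n = batch.shape[0]
    padded = jnp.pad(batch, (0, batch_size - n))
    mask = jnp.arange(batch_size) < n
    return padded, mask


data = jnp.arange(10.0)
batch_size = 4
total = 0.0
for start in range(0, data.shape[0], batch_size):
    padded, mask = pad_batch(data[start:start + batch_size], batch_size)
    total += masked_loss(padded, mask)

print(trace_count)  # 1: every call sees shapes (4,) and (4,)
print(float(total))  # 285.0, the sum of squares of 0..9
```

The mask is what keeps the result correct; the padding value only needs to be something the computation can evaluate without producing NaNs or infinities.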

pnkraemer, Nov 04 '23 13:11

This issue seems resolved, so I'll close it. If there is something left to discuss, please reopen.

pnkraemer, Oct 16 '24 06:10