Genuinely confused about the usage of dfx.RESULTS
Alright, this is a very minor issue and there is no time pressure at all. I don't get how dfx.RESULTS[result] works, and so far I've been working around it by accessing its _value. (Same thing for optimistix.)
```python
import pytest
import diffrax as dfx


def dydt(t, y, args):
    return -y


solution = dfx.diffeqsolve(dfx.ODETerm(dydt), dfx.Tsit5(), 0, 10, 0.01, 10.)

with pytest.raises(ValueError):
    dfx.RESULTS[0]  # diffrax documentation specifies usage as diffrax.RESULTS[<integer>]
    dfx.RESULTS[solution.result._value]  # This also does not work
assert dfx.RESULTS[solution.result] == ''  # This is an empty string
status = int(solution.result._value)  # This is what I actually do
```
Ah whoops, that part of the documentation is out of date. I've just written #443 to try and improve it.
These days we don't pass around raw integers any more. Instead we use an eqx.Enumeration, which has two advantages: it is human-readable by default, and it only supports equality checks against other members of the same enumeration. (I ran into bugs where people wrote things like if solution.result < 5, which was easy to get wrong and made it hard to change things in backward-compatible ways.) So please don't unwrap the ._value to get at the secret integer hiding underneath :)
If you're after a human-readable error message, you'll get one by simply printing solution.result. On success you'll see something like diffrax._solution.RESULTS<>, where the empty message indicates success. Otherwise you'll see diffrax._solution.RESULTS<Some informative message here!>
If you ever want the raw message string on its own, it's accessible as diffrax.RESULTS[solution.result], although only outside of JIT. I think needing this string on its own is pretty rare; I've certainly never had a use case for it.
If you're looking to write JIT-compatible code that does one thing on success and another thing on failure, then do

```python
jnp.where(solution.result == diffrax.RESULTS.successful, foo, bar)
```
The result of solution.result == diffrax.RESULTS.successful is a scalar JAX boolean array. It will be a tracer if you're inside of JIT.
(Also a heads-up: diffrax.RESULTS.discrete_terminating_event_occurred exists too, and if you're using events in your code at all, it should usually be treated as a success state as well: (solution.result == diffrax.RESULTS.successful) | (solution.result == diffrax.RESULTS.discrete_terminating_event_occurred))
Thank you, amazing! I occasionally brushed up against this when conditioning on the solution status, and it made me wonder what was going on. I'll switch to the jnp.where approach!
On a related note, I'm currently writing my own Solution classes, and now that I understand eqx.Enumeration better, I'll base them on it!