
Genuinely confused about the usage of dfx.RESULTS

Open · johannahaffner opened this issue 1 year ago · 2 comments

Alright, this is a very minor issue and there is no time pressure at all. I don't understand how dfx.RESULTS[result] works, and so far I've worked around it by accessing its _value. (The same goes for optimistix.)

```python
import pytest
import diffrax as dfx

def dydt(t, y, args):
    return -y

solution = dfx.diffeqsolve(dfx.ODETerm(dydt), dfx.Tsit5(), 0, 10, 0.01, 10.)

# Each lookup gets its own pytest.raises block; otherwise the second line never runs.
with pytest.raises(ValueError):
    dfx.RESULTS[0]  # diffrax documentation specifies usage as diffrax.RESULTS[<integer>]
with pytest.raises(ValueError):
    dfx.RESULTS[solution.result._value]  # This also does not work
assert dfx.RESULTS[solution.result] == ''  # This is an empty string

status = int(solution.result._value)  # This is what I actually do
```

johannahaffner, Jun 12 '24 19:06

Ah whoops, that's something that is out of date. I've just written #443 to try and improve that.

These days we don't pass around raw integers anymore. Instead we use an eqx.Enumeration, which has the advantages of being human-readable by default and of only supporting equality checks against other members of the same enumeration (no ordering comparisons). I ran into bugs where people wrote things like if solution.result < 5 and got it wrong, and that pattern also makes it hard to change the set of results in a backward-compatible way. So please don't unwrap the ._value to get at the secret integer hiding underneath :)

If you're after a human-readable error message, you'll get one by simply printing solution.result. On success you'll see something like diffrax._solution.RESULTS<>, where the absence of a message indicates success. Otherwise you'll see diffrax._solution.RESULTS<Some informative message here!>.

If you ever want the raw message string on its own, it's accessible at diffrax.RESULTS[solution.result], although only outside of JIT. I think needing this string on its own is pretty rare; I've certainly never had a use case for it.

If you're looking to write JIT-compatible code that does one thing on success and another thing on failure, then do:

```python
jnp.where(solution.result == diffrax.RESULTS.successful, foo, bar)
```

The result of solution.result == diffrax.RESULTS.successful is a scalar JAX boolean array. It will be a tracer if you're inside of JIT.

(Also a heads-up: diffrax.RESULTS.discrete_terminating_event_occurred exists too, and if you're using events in your code at all it should usually be treated as a success state as well: (solution.result == diffrax.RESULTS.successful) | (solution.result == diffrax.RESULTS.discrete_terminating_event_occurred).)

patrick-kidger, Jun 12 '24 19:06

Thank you, amazing! I occasionally brushed up against this when conditioning on the solution status, and it made me wonder what is up with that. I'll switch to the jnp.where solution!

On a related note, I'm writing my own Solution classes, and now that I understand eqx.Enumeration better I'll base them on it!

johannahaffner, Jun 12 '24 20:06