checkify does not work in pytree constructors
Description
```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Boxes:
    def __init__(self, arr: jnp.ndarray):
        self.arr = arr
        checkify.check((self.arr >= 0).all(), "x")

    def area(self) -> jnp.ndarray:
        # checkify.check((self.arr >= 0).all(), "x")  # this alone works
        return (self.arr[:, 2] - self.arr[:, 0]) * (self.arr[:, 3] - self.arr[:, 1])

    def tree_flatten(self):
        return ((self.arr,), None)

    @classmethod
    def tree_unflatten(cls, _, children):
        return cls(children[0])

def func(x: Boxes):
    return x.area() + 3

func = checkify.checkify(func)
b = Boxes(jnp.float32([[1, 2, 3, 4]]))
print(b)
print(func(b))  # works
print(jax.jit(func)(b))
# => ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized. This probably means you tried to stage (jit/scan/pmap/...) a `check` without functionalizing it through `checkify.checkify`.
```
The code above fails with this error.
If I call checkify.check only in area() and not in __init__(), the code works.
What jax/jaxlib version are you using?
latest
I'm pretty sure the issue is that jit flattens the Boxes argument into its leaves and then unflattens it again (calling __init__ on tracers) before passing it into the checkified function, so the check in __init__ runs inside jit's trace but outside of checkify. This is pretty unintuitive, but it is what the error refers to (the check is not "functionalized").
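One way to see this unflattening happen is to record what `__init__` receives. The sketch below uses a hypothetical `Probe` class (not part of the original report) that stores the type name of its argument on every construction; the exact type names depend on the JAX version, so the sketch only relies on the traced one containing "Tracer".

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

# Type names of every value passed to Probe.__init__, in call order.
seen_types = []

@register_pytree_node_class
class Probe:  # hypothetical class, only for illustration
    def __init__(self, arr):
        seen_types.append(type(arr).__name__)
        self.arr = arr

    def tree_flatten(self):
        return ((self.arr,), None)

    @classmethod
    def tree_unflatten(cls, _, children):
        # Called by JAX whenever it rebuilds a Probe from its leaves.
        return cls(children[0])

p = Probe(jnp.ones(3))  # eager construction with a concrete array
out = jax.jit(lambda q: q.arr.sum())(p)  # jit rebuilds p from tracers while tracing
print(seen_types)
```

The first entry comes from the eager construction; a later entry is a tracer type, which shows that `__init__` (and any `check` inside it) runs during jit's tracing.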
You could skip the check when initializing in tree_unflatten, so that it only runs when a user constructs the class but not when JAX APIs unflatten it. E.g., this works:
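```python
import jax.numpy as jnp
from jax.experimental import checkify
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Boxes:
    def __init__(self, arr: jnp.ndarray, run_check=True):
        self.arr = arr
        if run_check:
            checkify.check((self.arr >= 0).all(), "x")

    ...  # area() and tree_flatten() as before

    @classmethod
    def tree_unflatten(cls, _, children):
        # JAX-internal reconstruction skips the check.
        return cls(children[0], run_check=False)
```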
But then you wouldn't be able to catch tree operations that produce invalid data, like:
```python
jax.tree_util.tree_map(lambda x: x - 2, Boxes(jnp.float32([[1, 2, 3, 4]])))
```
Not sure if we could make this "just work" (@mattjj?), but we should probably improve the error message to make it easier to track down where the "functionalization error" happened.
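Building on the observation in the report that the check works when it lives in area(), another option is to keep __init__ check-free and validate explicitly inside the function that gets checkified. This catches invalid data no matter how the Boxes was produced, including by tree_map. The `validate` method and the error message are hypothetical names chosen for this sketch.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Boxes:
    def __init__(self, arr):
        # No check here: __init__ also runs whenever JAX unflattens the pytree.
        self.arr = arr

    def validate(self):
        # Hypothetical helper; only call it from code that is checkified.
        checkify.check((self.arr >= 0).all(), "boxes must be non-negative")

    def area(self):
        return (self.arr[:, 2] - self.arr[:, 0]) * (self.arr[:, 3] - self.arr[:, 1])

    def tree_flatten(self):
        return ((self.arr,), None)

    @classmethod
    def tree_unflatten(cls, _, children):
        return cls(children[0])

def func(x):
    x.validate()  # runs inside checkify, so the check is functionalized
    return x.area() + 3

checked = jax.jit(checkify.checkify(func))

good = Boxes(jnp.array([[1, 2, 3, 4]], dtype=jnp.float32))
bad = jax.tree_util.tree_map(lambda a: a - 2, good)  # negative coordinates

err_good, out_good = checked(good)
err_bad, out_bad = checked(bad)
print(err_good.get())  # None
print(err_bad.get())
```

The trade-off is that construction itself no longer validates; every checkified entry point has to call `validate()`.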
The issue is that checkify wraps the computation in an Error monad. This means that an instance initializer __init__ under the checkify transformation returns an Error instead of None.
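To make the "Error monad" concrete: `checkify.checkify` turns a function that calls `check` into one that returns an `(Error, output)` pair, and the error is only raised as a Python exception when `throw()` is called on it. The function name `positive_id` is hypothetical; the sketch catches a generic `Exception` rather than naming checkify's specific error class.

```python
import jax
from jax.experimental import checkify

def positive_id(x):
    # The check is a side effect here, not part of the return value.
    checkify.check(x > 0, "Should be positive.")
    return x

# checkify "boxes" the check: the jitted function returns (Error, output).
checked_id = jax.jit(checkify.checkify(positive_id))

ok_err, ok_out = checked_id(1)
print(ok_err.get())  # None, because no check failed

bad_err, bad_out = checked_id(-1)
print(bad_err.get())  # the failure message, including "Should be positive."

# "Unboxing": outside of jit, throw() turns a failed check into a Python exception.
raised = False
try:
    bad_err.throw()
except Exception as exc:
    raised = True
    print(type(exc).__name__, exc)
```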
I ran into the same issue in a much simpler situation. In the code below I want to unbox a monadic value, but the jit transformation does not allow it. This is unfortunate, since it poisons the whole jitted call stack with the Error monad. At the moment, the only way to unbox a monadic value is to do it in a non-jitted function. This is why checkify is still experimental.
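```python
from jax import jit
from jax.experimental.checkify import check, checkify

@checkify
def id(x):
    check(x > 0, 'Should be positive.')
    return x

@jit
def fn(x):
    e, r = id(x)
    e.throw()  # not functionalized: jit cannot stage this
    return r

e, _ = id(1)  # OK.
fn(1)  # ValueError: Cannot abstractly evaluate a checkify.check which ...
```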
The inability to unbox in jitted code is quite fundamental: it originates from the fact that jitted code runs on a device. There is also a related issue, #278, about out-of-bounds indexing.
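On the out-of-bounds indexing point: plain JAX indexing does not raise for out-of-bounds indices, but checkify can add such checks automatically through its `errors=` argument. A minimal sketch, assuming the `checkify.index_checks` error set; `take` is a hypothetical helper, and integer-array indices are used so the lookup goes through a gather.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def take(x, idx):
    # Plain JAX indexing does not raise on out-of-bounds indices.
    return x[idx]

# errors=checkify.index_checks instruments indexing with bounds checks.
checked_take = jax.jit(checkify.checkify(take, errors=checkify.index_checks))

x = jnp.arange(3)
in_err, in_out = checked_take(x, jnp.array([0, 2]))
print(in_err.get(), in_out)

oob_err, _ = checked_take(x, jnp.array([0, 5]))
print(oob_err.get())  # describes the out-of-bounds index
```

As with user checks, the error comes back as a value from the jitted function and can be thrown outside of jit.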
In this case you can checkify again!
The issue here is that you can't throw an error without functionalizing that error effect, but you can do the "unboxing" (checkify) and "boxing" (throwing an error, i.e. adding an error effect) in a nested fashion. You can think of it as switching between regular Python error handling (raising an error as a side effect) and status-passing style (functions return their status). You switch between these two models by "throwing errors" and "checkifying functions that throw errors".
jit is only compatible with the status-passing style (but hopefully it's easy to switch to that style when you need to). The example below has only two nested functions, but you can imagine much more complex functions with deeper nesting, where you checkify only when you need to (e.g., when you want to jit).
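```python
from jax import jit
from jax.experimental.checkify import check, checkify

@checkify
def id(x):
    check(x > 0, 'Should be positive.')
    return x

@jit
@checkify
def fn(x):
    e, r = id(x)
    e.throw()  # functionalized again by the outer checkify
    return r

e, _ = id(1)  # OK.
e, _ = fn(1)  # OK.
e, _ = fn(-1)
e.get()
# ... Should be positive.
```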