
checkify does not work in pytree constructors

Open: ppwwyyxx opened this issue 3 years ago • 1 comment

Description

from jax.tree_util import register_pytree_node_class
import jax.numpy as jnp
from jax.experimental import checkify
import jax

@register_pytree_node_class
class Boxes:
  def __init__(self, arr: jnp.ndarray):
    self.arr = arr
    checkify.check((self.arr >= 0).all(), "x")
  
  def area(self) -> jnp.ndarray:
    # checkify.check((self.arr >= 0).all(), "x") # this alone works
    return (self.arr[:, 2] - self.arr[:, 0]) * (self.arr[:, 3] - self.arr[:, 1])

  def tree_flatten(self):
    return ((self.arr,), None)
  
  @classmethod
  def tree_unflatten(cls, _, children):
    return cls(children[0])

def func(x: Boxes):
  return x.area() + 3

func = checkify.checkify(func)

b = Boxes(jnp.float32([[1, 2, 3, 4]]))
print(b)
print(func(b))  # works
print(jax.jit(func)(b))   
# =>  ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized. This probably means you tried to stage (jit/scan/pmap/...) a `check` without functionalizing it through `checkify.checkify`.

The code above fails with this error.

If I call checkify.check only in area() and not in __init__(), the code works.

What jax/jaxlib version are you using?

latest

Which accelerator(s) are you using?

No response

Additional System Info

No response

Posted by ppwwyyxx on Sep 09 '22 05:09

I'm pretty sure the issue is that jit flattens the Boxes argument and then unflattens it again (via tree_unflatten, which calls __init__) when passing the args into the checkified function, so the check in __init__ runs under jit but outside of checkify. This is pretty unintuitive, but it is what the error refers to (the check is not "functionalized").
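A minimal sketch to observe this, assuming a recent JAX release: the hypothetical `init_arg_types` list records the type of the array each time `Boxes.__init__` runs. The first entry comes from the user's own constructor call; the later entries come from `jax.jit` rebuilding `Boxes` out of tracers while it traces `area`, which is exactly where an unguarded `checkify.check` would run outside of `checkify`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

# Hypothetical instrumentation: the type name of `arr` seen by each __init__ call.
init_arg_types = []

@register_pytree_node_class
class Boxes:
  def __init__(self, arr):
    init_arg_types.append(type(arr).__name__)
    self.arr = arr

  def tree_flatten(self):
    return ((self.arr,), None)

  @classmethod
  def tree_unflatten(cls, _, children):
    # Invoked by JAX transformations, not by the user.
    return cls(children[0])

def area(b):
  return (b.arr[:, 2] - b.arr[:, 0]) * (b.arr[:, 3] - b.arr[:, 1])

b = Boxes(jnp.float32([[1, 2, 3, 4]]))  # user construction: concrete array
jax.jit(area)(b)                         # jit rebuilds Boxes while tracing
# Entries after the first were recorded during tracing and hold tracer types.
print(init_arg_types)
```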

One workaround is to skip the check when the object is rebuilt in tree_unflatten, so that it runs when a user constructs the class but not when JAX APIs unflatten it. For example, this works:

@register_pytree_node_class
class Boxes:
  def __init__(self, arr: jnp.ndarray, run_check=True):
    self.arr = arr
    if run_check:
      checkify.check((self.arr >= 0).all(), "x")
  
  ...
  
  @classmethod
  def tree_unflatten(cls, _, children):
    return cls(children[0], run_check=False) 
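Filling in the elided parts with the methods from the original report, a complete version of this workaround might look like the sketch below (assuming a recent JAX release; the names `checked`, `err`, and `out` are illustrative). The user-facing construction still runs the check eagerly, while every rebuild inside `jit` or `checkify` goes through `tree_unflatten` with `run_check=False`:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Boxes:
  def __init__(self, arr, run_check=True):
    self.arr = arr
    if run_check:
      # Runs only on user construction; JAX's unflatten path skips it.
      checkify.check((self.arr >= 0).all(), "x")

  def area(self):
    return (self.arr[:, 2] - self.arr[:, 0]) * (self.arr[:, 3] - self.arr[:, 1])

  def tree_flatten(self):
    return ((self.arr,), None)

  @classmethod
  def tree_unflatten(cls, _, children):
    # Rebuilding from leaves (e.g. inside jit) must not call an
    # unfunctionalized check, so it is skipped here.
    return cls(children[0], run_check=False)

def func(x):
  return x.area() + 3

checked = jax.jit(checkify.checkify(func))

b = Boxes(jnp.float32([[1, 2, 3, 4]]))  # eager check passes: all values >= 0
err, out = checked(b)
err.throw()   # no-op here: no check failed inside func
print(out)    # [7.]
```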

The downside is that you would no longer catch tree operations that produce invalid data, such as:

jax.tree_util.tree_map(lambda x: x - 2, Boxes(jnp.float32([[1, 2, 3, 4]])))
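To make the drawback concrete, here is a sketch under the same workaround (assuming a recent JAX release). The `validate` method is hypothetical, not part of the original code: it re-runs the same predicate inside a checkified function, which is the placement the original report found to work (a check in `area()`), so invalid data produced by `tree_map` is still reported there:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.tree_util import register_pytree_node_class, tree_map

@register_pytree_node_class
class Boxes:
  def __init__(self, arr, run_check=True):
    self.arr = arr
    if run_check:
      checkify.check((self.arr >= 0).all(), "x")

  def validate(self):
    # Hypothetical helper: the same predicate, meant to be called inside
    # a checkified function so the check is always functionalized.
    checkify.check((self.arr >= 0).all(), "negative box coordinates")

  def tree_flatten(self):
    return ((self.arr,), None)

  @classmethod
  def tree_unflatten(cls, _, children):
    return cls(children[0], run_check=False)

# tree_map rebuilds Boxes via tree_unflatten, so the constructor check is skipped.
shifted = tree_map(lambda x: x - 2, Boxes(jnp.float32([[1, 2, 3, 4]])))
print(shifted.arr)  # contains negative coordinates, yet no error was raised

def func(x):
  x.validate()
  return x.arr.sum()

err, _ = jax.jit(checkify.checkify(func))(shifted)
print(err.get())  # reports the "negative box coordinates" check
```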

Not sure if we could make this "just work" (@mattjj?), but we should probably improve the error message to make it easier to track down where the "functionalization error" happened.

Posted by LenaMartens on Sep 09 '22 17:09

The issue is that checkify wraps the computation in an Error monad. This means that an initializer like __init__ under the checkify transformation returns an Error instead of None.

I ran into the same issue in a much simpler situation. In the code below I want to unbox the monadic value, but the jit transformation does not allow it. This is unfortunate, since it forces the Error monad through the whole jitted call stack. For now, the only way to unbox a monadic value is to do it in a non-jitted function, which is presumably part of why checkify is still experimental.

from jax import jit
from jax.experimental.checkify import check, checkify

@checkify
def id(x):
    check(x > 0, 'Should be positive.')
    return x

@jit
def fn(x):
    e, r = id(x)
    e.throw()
    return r

e, _ = id(1)  # OK.
fn(1)  # Cannot abstractly evaluate a checkify.check which ...

The inability to unbox in jitted code is fairly fundamental and stems from the fact that jitted code runs on a device. There is also a related issue, #278, about out-of-bounds indexing.
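Following the point that unboxing currently has to happen outside jitted code, a minimal sketch (assuming a recent `jax.experimental.checkify`; `positive_id` and `jitted` are illustrative names) is to return the `Error` from the jitted function and call `err.throw()` on the host, where raising a Python exception is possible:

```python
import jax
from jax.experimental import checkify

def positive_id(x):
  checkify.check(x > 0, "Should be positive.")
  return x

# The jitted function only returns the status; it never raises on device.
jitted = jax.jit(checkify.checkify(positive_id))

def fn(x):
  err, r = jitted(x)
  err.throw()  # unboxing happens here, outside of jit
  return r

print(fn(1))
try:
  fn(-1)
except Exception as e:  # the exact exception class depends on the JAX version
  print(e)
```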

Posted by daskol on Oct 19 '22 16:10

In this case you can checkify again! The issue here is that you can't throw an error without functionalizing that error effect, but you can do the "unboxing" (checkify) and the "boxing" (throwing an error, i.e. adding an error effect) in a nested fashion. Think of it as switching between regular Python error handling (raising an error as a side effect) and status-passing style (functions return their status). You move between these two models by throwing errors and by checkifying functions that throw errors. jit is only compatible with the status-passing style, but hopefully it is easy to switch to that style when you need to. The example below has only two nested functions, but you can imagine much more complex functions with deeper nesting, where you checkify only when you need to (e.g. when you want to jit).

from jax import jit
from jax.experimental.checkify import check, checkify

@checkify
def id(x):
    check(x > 0, 'Should be positive.')
    return x

@jit
@checkify
def fn(x):
    e, r = id(x)
    e.throw()
    return r

e, _ = id(1)  # OK.
e, _ = fn(1)  # OK.
e, _ = fn(-1)
e.get()
# ... Should be positive
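The same idea extends to deeper nesting. In the sketch below (function names are illustrative, assuming a recent JAX release), the innermost function raises via `check`, the middle one converts that to a status with `checkify` and re-raises it with `err.throw()`, and the outermost `checkify` sits directly under `jit`, so the compiled function only ever uses status-passing style:

```python
import jax
from jax.experimental import checkify

# Innermost layer: exception style, raises via check.
def validate(x):
  checkify.check(x > 0, "Should be positive.")
  return x

# Middle layer: receives the status explicitly, then re-raises it as an
# effect so that its own caller can pick a style.
def middle(x):
  err, y = checkify.checkify(validate)(x)
  err.throw()
  return y * 2

# Outermost layer: checkify right under jit, so jit sees status-passing code.
safe = jax.jit(checkify.checkify(middle))

err_ok, out_ok = safe(3.0)
print(err_ok.get(), out_ok)

err_bad, _ = safe(-1.0)
print(err_bad.get())  # mentions "Should be positive."
```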

Posted by LenaMartens on Oct 21 '22 18:10