
BUG: checkify transformation doesn't work if one of the function arguments is not a valid JAX type

Open: hr0nix opened this issue 1 year ago

Description

The most obvious way to run into this is to apply checkify to a method of a class whose instances are not valid JAX types, since the instance is passed to the checkified function as `self`.

```python
import pytest
from jax.experimental import checkify


def test_checkify_on_class_method():
    class C:
        def f(self, x):
            checkify.check(x > 0, 'x must be positive')
            return x
    c_instance = C()
    checked_f = checkify.checkify(C.f)
    # Fails here: c_instance is passed as an argument and is not a valid JAX type.
    err, val = checked_f(c_instance, 1)
    err.throw()
    assert val == 1

    with pytest.raises(ValueError):
        err, val = checked_f(c_instance, -1)
        err.throw()
```

The test above fails with:

```
Value <test_checkify.test_checkify_on_class_method.<locals>.C object at 0x16a338250> with type <class 'test_checkify.test_checkify_on_class_method.<locals>.C'> is not a valid JAX type
```

We can already jit-compile class methods by marking `self` as a static argument, so perhaps checkify's constraints should be loosened as well, given that it is meant to be used in combination with jit.
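For reference, here is a minimal sketch of the jit approach mentioned above: `self` is passed through `static_argnums=0`, so jit hashes it and bakes it into the trace instead of converting it to an array. The class and method body are illustrative only, not taken from the original report.

```python
import jax
import jax.numpy as jnp


class C:
  def f(self, x):
    return x * 2


c_instance = C()
# Argument 0 (self) is static: it must be hashable (plain objects hash by id)
# and is not converted to a JAX array.
jitted_f = jax.jit(C.f, static_argnums=0)
out = jitted_f(c_instance, jnp.float32(3.0))
```

checkify has no equivalent of `static_argnums`, which is what this issue is about.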

What jax/jaxlib version are you using?

0.3.16

Which accelerator(s) are you using?

CPU

Additional System Info

Mac

hr0nix, Aug 13 '22 15:08

@LenaMartens Hi, perhaps you can comment on that? Thanks!

hr0nix, Sep 13 '22 11:09

Yes please, I want to checkify object methods!

ingmarschuster, Oct 11 '22 10:10

Is there any way to avoid this using partials or something similar, as there is with jax.jit?

ingmarschuster, Oct 11 '22 10:10

Whoops, I seem to have missed this!

Yes, you can use a partial or lambda to get around this, or, in the case of methods, transform the bound method in `__init__`. This is the same as for other transforms that have no static arguments (like vmap). jit/pmap need static arguments (e.g. jit's `static_argnums`) because they cache compilation based on function identity, so wrapping a fresh lambda/partial around the call each time would miss the cache and recompile. checkify has no such cache.
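To make the caching point concrete, here is a minimal sketch assuming current `jax.jit` behaviour. The `trace_count` counter is a hypothetical helper for illustration: it is a Python side effect, so it only increments while JAX traces the function, which shows when a cached trace is reused and when a new lambda forces a retrace.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def f(x):
  global trace_count
  trace_count += 1  # Python side effect: only runs while JAX traces f
  return x + 1


# One jitted function object reused across calls: traced once, then cached.
jitted = jax.jit(f)
jitted(jnp.float32(1.0))
jitted(jnp.float32(2.0))
traces_reused = trace_count

# A fresh lambda per call is a new function identity, so every call retraces.
for v in (1.0, 2.0):
  jax.jit(lambda x: f(x))(jnp.float32(v))
traces_total = trace_count
```

Here `traces_reused` is 1 and `traces_total` is 3. checkify does not keep a cache like this, so wrapping a lambda or partial costs nothing extra.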

Some examples of how to make this work:

```python
from jax.experimental import checkify


class C:
  def __init__(self):
    # works: self.foo is already a bound method, so self is not an argument
    self.foo = checkify.checkify(self.foo)

  def foo(self, x):
    checkify.check(x > 0, 'x must be positive')
    return x

  def bar(self, x):
    checkify.check(x > 0, 'x must be positive')
    return x


c_instance = C()
err, val = c_instance.foo(1.)
err.throw()  # no error: x is positive
checked_bar = checkify.checkify(lambda x: c_instance.bar(x))  # works (also works with functools.partial)
err, val = checked_bar(-1.)
err.throw()  # raises: x must be positive
```
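The `functools.partial` variant, applied to the method from the original test, might look like the sketch below. It assumes `Error.get()` returns the failure message as a string (or `None` when no check failed), and inspects the error instead of calling `throw()`.

```python
import functools

from jax.experimental import checkify


class C:
  def f(self, x):
    checkify.check(x > 0, 'x must be positive')
    return x


c_instance = C()
# Bind self up front so checkify only ever sees the array argument x.
checked_f = checkify.checkify(functools.partial(C.f, c_instance))

err, val = checked_f(1.)
err.throw()  # no error: x is positive

err_neg, _ = checked_f(-1.)
msg = err_neg.get()  # failure message as a string (None if no check failed)
```

Because the instance is captured by the partial rather than passed as an argument, checkify never tries to convert it to a JAX type.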

LenaMartens, Oct 14 '22 18:10