BUG: checkify transformation doesn't work if one of the function arguments is not a valid JAX type
Description
The most obvious way to run into this would be to apply `checkify` to a method of a class that is not a valid JAX type.
```python
import pytest
from jax.experimental import checkify

def test_checkify_on_class_method():
    class C:
        def f(self, x):
            checkify.check(x > 0, 'x must be positive')
            return x

    c_instance = C()
    checked_f = checkify.checkify(C.f)
    # Fails here: c_instance is passed as an argument and is not a valid JAX type.
    err, val = checked_f(c_instance, 1)
    err.throw()
    assert val == 1
    with pytest.raises(ValueError):
        err, val = checked_f(c_instance, -1)
        err.throw()
```
The test above fails with:
```
Value <test_checkify.test_checkify_on_class_method.<locals>.C object at 0x16a338250> with type <class 'test_checkify.test_checkify_on_class_method.<locals>.C'> is not a valid JAX type
```
We can currently jit-compile class methods by marking `self` as a static argument. Perhaps `checkify`'s constraints should be loosened, if it's meant to be used in combination with `jit`.
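For reference, the `jit` mechanism mentioned above looks roughly like this. This is a minimal sketch with a hypothetical method `double`; `static_argnums=0` tells `jit` to treat `self` as a hashable compile-time constant instead of tracing it, so it does not need to be a JAX type.

```python
import jax

class C:
    def double(self, x):
        return 2 * x

c_instance = C()
# Argument 0 (`self`) is static: it is hashed and used as part of the
# compilation cache key rather than traced, so it need not be a JAX type.
# Plain Python objects are hashable by identity, which is enough here.
jitted_double = jax.jit(C.double, static_argnums=0)
print(jitted_double(c_instance, 1.0))
```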
What jax/jaxlib version are you using?
0.3.16
Which accelerator(s) are you using?
CPU
Additional System Info
Mac
@LenaMartens Hi, perhaps you can comment on that? Thanks!
Yes please, I want to checkify object methods!
Is there any way to avoid this using partials or something similar, as is the case with `jax.jit`?
Whoops, I seem to have missed this!
Yes, you can use a `partial` or a `lambda` to get around this or, in the case of methods, transform the bound method in `__init__`. This is similar to other transforms without static arguments (like `vmap`). The reason `jit`/`pmap` need static arguments is that they cache compilation based on function identity, which means you can't use the lambda/partial trick there. `checkify` has no such cache.
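The `functools.partial` variant can be sketched as follows. This is a minimal example reusing the class from the original report; it assumes `Error.get()` returns the failure message as a string (or `None` when no check failed), as in `jax.experimental.checkify`.

```python
import functools
from jax.experimental import checkify

class C:
    def f(self, x):
        checkify.check(x > 0, 'x must be positive')
        return x

c_instance = C()
# Bind `self` up front so checkify only ever sees the JAX-typed argument `x`.
checked_f = checkify.checkify(functools.partial(C.f, c_instance))

err, val = checked_f(1.)
err.throw()  # no check failed, so this does nothing

err, val = checked_f(-1.)
# The value is still computed; the failed check is recorded in `err`.
print(err.get())
```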
Some examples of how to make this work:
```python
from jax.experimental import checkify

class C:
    def __init__(self):
        self.foo = checkify.checkify(self.foo)  # works

    def foo(self, x):
        checkify.check(x > 0, 'x must be positive')
        return x

    def bar(self, x):
        checkify.check(x > 0, 'x must be positive')
        return x

c_instance = C()
err, val = c_instance.foo(1.)
err.throw()

checked_bar = checkify.checkify(lambda x: c_instance.bar(x))  # works (also works with functools.partial)
err, val = checked_bar(-1.)
err.throw()  # raises, since x is negative
```
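For comparison, the same closure trick applies to `vmap`, which also has no static arguments. A minimal sketch with a hypothetical `scale` method: the non-JAX object is captured by the lambda instead of being passed as an argument.

```python
import jax
import jax.numpy as jnp

class C:
    def scale(self, x):
        return 2 * x

c_instance = C()
# vmap only maps over the lambda's argument `x`; `c_instance` is closed over,
# so vmap never has to treat it as a JAX value.
batched = jax.vmap(lambda x: c_instance.scale(x))
result = batched(jnp.arange(3.))
print(result)
```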