Allow for nested chex.chexify

Open • nicow-elia opened this issue 1 year ago • 1 comment

Hello, I have a dilemma with `chex.chexify`. Consider the following code:

```python
import chex
import jax
import jax.numpy as jnp
import pytest


# Keeping this @chex.chexify makes test_combo_safe fail (nested chexify);
# removing it makes test_log_safe fail (value assertion outside chexify).
@chex.chexify
@jax.jit
def log_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x > 0, jnp.ones_like(x, dtype=bool))
    return jnp.log(x)


@chex.chexify
@jax.jit
def combo_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x != 1, jnp.ones_like(x, dtype=bool))
    return log_safe(x) / (x - 1)


def test_log_safe() -> None:
    # Contains -1.0, so the x > 0 assertion should fail.
    x = jnp.array([1.0, 2.0, 3.0, -1.0])
    with pytest.raises(Exception):
        log_safe(x)
        log_safe.wait_checks()

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    assert jnp.array_equal(log_safe(x), jnp.log(x))
    log_safe.wait_checks()


def test_combo_safe() -> None:
    # Contains 1.0, so the x != 1 assertion should fail.
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    with pytest.raises(Exception):
        combo_safe(x)
        combo_safe.wait_checks()

    x = jnp.array([2.0, 3.0, 4.0, 5.0])
    assert jnp.array_equal(combo_safe(x), jnp.log(x) / (x - 1))
    combo_safe.wait_checks()
```

If I comment out the first `@chex.chexify`, `test_log_safe` fails with `RuntimeError: Value assertions can only be called from functions wrapped with @chex.chexify. See the docs.`, which makes sense to me. However, once I add the decorator back, `test_combo_safe` fails with `RuntimeError: Nested @chexify wrapping is disallowed. Make sure that you only wrap the function at the outermost level.`

A hack in this simple scenario would be to keep two versions of the function: an undecorated `log_safe` and a `log_safe_test = chex.chexify(log_safe)`, and to call only `log_safe_test` in my tests. However, that solution is pretty clumsy, especially when there are many such cases: in a codebase that is JAX end to end, every function except the outermost one would need this hack. Would it be possible to allow nested `chex.chexify`, where inner applications of the decorator simply do nothing or just emit a warning?

nicow-elia • Sep 13 '23 15:09

I'm also finding the recent-ish change that disallows nested chexify quite difficult to work with, since my graph of functions is not a tree with a single root, yet I still want chexify in this situation: f() calls g(), which calls h(); and g() calls h().

Edgeworth • Dec 02 '23 14:12