
How to merge equinox pretty-printing and jax debug printing?

Open LouisDesdoigts opened this issue 5 months ago • 2 comments

So I'm trying to do two things here:

  1. Raise an error through the JIT boundary using eqx.error_if.
  2. Print a pytree using the jax.debug.print function.

I'm looking to do this because some NaNs are arising during my training that are difficult to isolate. I can't use the usual JAX debug_nans flag, as some of my data naturally contains NaNs.

Presently I am creating a boolean pytree that flags whether each leaf contains any NaNs. Ideally I would be able to print the actual boolean values rather than the usual Traced<ShapedArray(bool[])>, which is what jax.debug.print would allow.
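For reference, a minimal sketch of the printing half on its own, assuming plain jax.jit and a hypothetical helper nan_flags: jax.debug.print accepts pytrees as format arguments and substitutes their runtime values, so the per-leaf NaN flags print as concrete True/False rather than as tracers.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nan_flags(tree):
    # One boolean scalar per leaf: True if that leaf contains any NaN.
    flags = jax.tree_util.tree_map(lambda leaf: jnp.isnan(leaf).any(), tree)
    # jax.debug.print formats the runtime values, and pytree arguments are
    # allowed, so this shows True/False instead of Traced<ShapedArray(bool[])>.
    jax.debug.print("NaN flags: {flags}", flags=flags)
    return flags


flags = nan_flags((jnp.zeros(3), jnp.zeros(5).at[1].set(jnp.nan)))
jax.effects_barrier()  # wait until the debug print has been flushed
```

This prints the flags but does not raise; the error half still needs eqx.error_if.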

Extending your error_if example:


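import equinox as eqx
import jax
import jax.numpy as np  # `np` here is jax.numpy (needed for .at[...].set)

@eqx.filter_jit
def f(x):
    # One boolean scalar per leaf: True if that leaf contains any NaN.
    bool_tree = jax.tree_util.tree_map(lambda leaf: np.isnan(leaf).any(), x)
    vals = np.stack(jax.tree_util.tree_leaves(bool_tree))
    # msg is built at trace time, so the leaves are formatted as tracers.
    msg = "NaN found in tree:\n" + eqx.tree_pformat(bool_tree, short_arrays=False)
    x = eqx.error_if(x, vals.any(), msg)
    return x

pytree = (np.zeros(3), np.zeros(5))
_ = f(pytree)  # no NaNs: runs normally

nan_pytree = eqx.tree_at(lambda pytree: pytree[0], pytree, np.zeros(3).at[2].set(np.nan))
_ = f(nan_pytree)  # raises, but the message shows tracers rather than True/False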

However, this doesn't print the actual array values, hence the desire for some interface with jax.debug.print. Is there a simple way to achieve this?

Thanks in advance!

LouisDesdoigts, Jan 22 '24 10:01