How to merge equinox pretty-printing and jax debug printing?
So I'm trying to do two things here:
- Raise an error through the JIT boundary using `eqx.error_if`.
- Print a pytree using the `jax.debug.print` function.
I'm looking to do this because some NaNs are arising during my training that are difficult to isolate. I can't use the usual `jax_debug_nans` flag, as some of my data naturally contains NaNs.
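For context on why the global flag is unsuitable here: with `jax_debug_nans` enabled, JAX raises `FloatingPointError` whenever a computation produces a NaN, so it also fires on NaNs that are expected in the data. A minimal sketch, where the `data` array and the `scale` function are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp

# Hypothetical data with an expected NaN (e.g. a missing measurement).
data = jnp.array([1.0, jnp.nan, 3.0])


@jax.jit
def scale(x):
    return 2.0 * x


jax.config.update("jax_debug_nans", True)
try:
    scale(data)  # the output contains the NaN, so the flag raises
    tripped = False
except FloatingPointError:
    tripped = True
finally:
    # Restore the default so later code is unaffected.
    jax.config.update("jax_debug_nans", False)
```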
Presently I am creating a boolean pytree that flags whether each leaf contains any NaNs, and ideally I would print the actual boolean values rather than the usual `Traced<ShapedArray(bool[])>` placeholder, which is what `jax.debug.print` would let me do.
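On its own, `jax.debug.print` does show concrete values from inside a jitted function, because it runs at execution time rather than at trace time, and it accepts pytree arguments. A minimal sketch (the `nan_flags` function and the example `tree` are hypothetical):

```python
import jax
import jax.numpy as jnp


@jax.jit
def nan_flags(tree):
    # One boolean per leaf: does this leaf contain any NaN?
    flags = jax.tree_util.tree_map(lambda leaf: jnp.isnan(leaf).any(), tree)
    # jax.debug.print is executed at runtime with concrete values, so it
    # prints True/False rather than Traced<ShapedArray(bool[])>.
    jax.debug.print("NaN flags: {flags}", flags=flags)
    return flags


tree = (jnp.array([0.0, jnp.nan, 1.0]), jnp.zeros(5))
flags = nan_flags(tree)
```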
Extending your `error_if` example:
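```python
import equinox as eqx
import jax
import jax.numpy as jnp

@eqx.filter_jit
def f(x):
    # One boolean per leaf: True if that leaf contains any NaN.
    bool_tree = jax.tree_util.tree_map(lambda leaf: jnp.isnan(leaf).any(), x)
    vals = jnp.array(jax.tree_util.tree_leaves(bool_tree))
    # The message is built at trace time, so the leaves appear as tracers.
    msg = "NaN found in tree:\n" + eqx.tree_pformat(bool_tree, short_arrays=False)
    x = eqx.error_if(x, vals.any(), msg)
    return x

pytree = (jnp.zeros(3), jnp.zeros(5))
_ = f(pytree)  # no NaNs: runs normally
nan_pytree = eqx.tree_at(lambda t: t[0], pytree, jnp.zeros(3).at[2].set(jnp.nan))
_ = f(nan_pytree)  # raises, but the message shows tracers, not values
```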
However, this doesn't print the actual array values, hence the desire to combine it with `jax.debug.print`. Is there a simple way to achieve this?
Thanks in advance!