jax
[pallas] Improve some error messages and add API tests.
We make the following improvements:
- Pytree structural disequality messages attempt to localize the mismatch.
- We check that the rank of the `block_shape` matches the rank of
  the overall array. Without this we used to get a `safe_zip` error.
  We also carry the pytree paths to localize the error.
- We check that the kernel function returns None. Without this
  we used to get `body_fun output and input must have same type structure`
  in the interpreter, `assert len(jaxpr.outvars) == 0` on GPU, and
  `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
  on TPU.
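
As a rough illustration of the last two checks, here is a minimal pure-Python sketch. It is not the code from this change: `check_block_rank`, `check_kernel_returns_none`, the error wording, and the use of plain lists in place of Pallas refs are all assumptions made for the example. The real check on the kernel's return value happens while tracing, not by calling the kernel like this.

```python
def check_block_rank(path, block_shape, array_shape):
    """Hypothetical helper: fail early when the block rank differs from the array rank."""
    if len(block_shape) != len(array_shape):
        # `path` stands in for the pytree path of the offending argument.
        raise ValueError(
            f"Block shape for {path} has rank {len(block_shape)}, "
            f"but the array has rank {len(array_shape)} "
            f"(block_shape={block_shape}, array shape={array_shape})."
        )


def check_kernel_returns_none(kernel, *refs):
    """Hypothetical helper: a kernel writes to its output refs and returns nothing."""
    out = kernel(*refs)
    if out is not None:
        raise ValueError(
            f"The kernel function should return None, got {type(out).__name__}."
        )


def good_kernel(x_ref, o_ref):
    o_ref[0] = x_ref[0] + 1  # writes through the output ref, returns None


def bad_kernel(x_ref, o_ref):
    return x_ref  # returning a value is the mistake being diagnosed


# A rank-2 block for a rank-3 array is reported with the argument's path.
try:
    check_block_rank("args[0]", (8, 128), (4, 8, 128))
    rank_error = None
except ValueError as e:
    rank_error = str(e)

# The well-behaved kernel passes; the one that returns a value is rejected.
check_kernel_returns_none(good_kernel, [1], [0])
try:
    check_kernel_returns_none(bad_kernel, [1], [0])
    kernel_error = None
except ValueError as e:
    kernel_error = str(e)
```

The point of both checks is the same: report the problem in terms the user wrote (an argument path, a return value) before a lower layer fails with an unrelated-looking message.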
To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and
`stages.py`.
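
To make "localize the mismatch" concrete, the following is a simplified, pure-Python stand-in for the idea. It is not the JAX implementation and does not operate on `PyTreeDef` objects; `structure_mismatches` and its message format are hypothetical, and only dicts, lists, and tuples are treated as containers.

```python
def structure_mismatches(a, b, path=""):
    """Yield (path, explanation) for each place two nested structures differ.

    Anything that is not a dict, list, or tuple is treated as a leaf,
    and leaf values are never compared, only the surrounding structure.
    """
    containers = (dict, list, tuple)
    a_is_node = isinstance(a, containers)
    b_is_node = isinstance(b, containers)
    where = path or "<root>"
    if not a_is_node and not b_is_node:
        return  # two leaves: structurally equal
    if type(a) is not type(b):
        yield where, f"type {type(a).__name__} vs {type(b).__name__}"
    elif isinstance(a, dict):
        if a.keys() != b.keys():
            yield where, (
                f"dict keys {sorted(a, key=repr)} vs {sorted(b, key=repr)}"
            )
        else:
            for k in a:
                yield from structure_mismatches(a[k], b[k], f"{path}[{k!r}]")
    else:  # list or tuple
        if len(a) != len(b):
            yield where, f"{type(a).__name__} length {len(a)} vs {len(b)}"
        else:
            for i, (x, y) in enumerate(zip(a, b)):
                yield from structure_mismatches(x, y, f"{path}[{i}]")


expected = {"x": [1, 2], "y": (3, {"z": 4})}
actual = {"x": [1, 2, 3], "y": (3, {"w": 4})}
errors = list(structure_mismatches(expected, actual))
for where, why in errors:
    print(f"at {where}: {why}")
```

Each reported entry names the deepest path at which the two structures still agree, so the reader sees where the trees diverge instead of a single "structures differ" message.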