
[pallas] Improve some error messages and add API tests.

Open. gnecula opened this issue 7 months ago (0 comments).

We make the following improvements:

  • Error messages for pytree structural mismatches now attempt to localize the mismatch.
  • We check that the rank of the block_shape matches the rank of the overall array. Without this check we used to get a safe_zip error. We also carry the pytree paths through so that the error can be localized.
  • We check that the kernel function returns None. Without this check, a kernel that returned a value failed with a different error on each backend: "body_fun output and input must have same type structure" in the interpreter, "assert len(jaxpr.outvars) == 0" on GPU, and "INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0" on TPU.
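The two new checks can be illustrated with a minimal pure-Python sketch. The helper names block_shape_rank_error and kernel_return_error are hypothetical, the message texts are not the ones Pallas emits, and plain Python lists stand in for Pallas refs. The point is only that validating up front, with the pytree path at hand, yields one clear message instead of a safe_zip length mismatch or a backend-specific failure.

```python
from typing import Callable, Optional, Sequence


def block_shape_rank_error(
    path: str,
    array_shape: Sequence[int],
    block_shape: Optional[Sequence[Optional[int]]],
) -> Optional[str]:
    """Hypothetical check: return an error message if the ranks differ, else None."""
    if block_shape is None:
        # No block_shape given: the whole array is one block, nothing to check.
        return None
    if len(block_shape) != len(array_shape):
        # Report the pytree path so the user knows which operand is at fault.
        return (
            f"Block shape for {path} (= {tuple(block_shape)}) must have the same "
            f"number of dimensions as the array shape {tuple(array_shape)}."
        )
    return None


def kernel_return_error(kernel: Callable[..., object], *refs: object) -> Optional[str]:
    """Hypothetical check: run the kernel body on placeholder refs and flag a non-None result."""
    result = kernel(*refs)
    if result is not None:
        return f"The kernel function should return None, but it returned {result!r}."
    return None


# Usage: kernels communicate results by writing to output refs, not by returning them.
def good_kernel(x_ref, o_ref):
    o_ref[0] = x_ref[0] + 1


def bad_kernel(x_ref, o_ref):
    o_ref[0] = x_ref[0] + 1
    return o_ref  # returning a value is the mistake the new check catches
```

In this sketch a caller would raise ValueError with the returned message; keeping the helpers side-effect free just makes the checks easy to test in isolation.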

To simplify the generation of the error messages, we added a helper function tree_util.equality_errors_pytreedef, which is like tree_util.equality_errors but takes PyTreeDef inputs rather than PyTrees. We then also used this new helper function in pjit.py and stages.py.
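What "localizing the mismatch" means can be sketched in pure Python. The function structure_errors below is hypothetical and only loosely modeled on the idea behind tree_util.equality_errors: it walks two nested structures of lists, tuples, and dicts and yields a (path, explanation) pair at the first point of divergence along each branch. It is not JAX's implementation and does not operate on PyTreeDefs.

```python
from typing import Any, Iterator, Tuple

CONTAINERS = (list, tuple, dict)


def structure_errors(a: Any, b: Any, path: str = "") -> Iterator[Tuple[str, str]]:
    """Yield (path, explanation) for each place where the container structure differs.

    Anything that is not a list, tuple, or dict is treated as a leaf, and
    leaves always match, since only the structure is compared.
    """
    where = path or "<root>"
    if type(a) is not type(b) and (isinstance(a, CONTAINERS) or isinstance(b, CONTAINERS)):
        yield where, f"{type(a).__name__} vs {type(b).__name__}"
        return
    if isinstance(a, dict):
        if set(a) != set(b):
            yield where, f"dict keys {sorted(a, key=str)} vs {sorted(b, key=str)}"
            return
        for k in sorted(a, key=str):
            yield from structure_errors(a[k], b[k], f"{path}[{k!r}]")
    elif isinstance(a, (list, tuple)):
        if len(a) != len(b):
            yield where, f"{type(a).__name__} of length {len(a)} vs {len(b)}"
            return
        for i, (x, y) in enumerate(zip(a, b)):
            yield from structure_errors(x, y, f"{path}[{i}]")
```

For example, comparing {"x": (1, 2), "y": [3]} with {"x": (1, 2, 3), "y": [3]} yields the single pair ("['x']", "tuple of length 2 vs 3"), which points directly at the offending subtree rather than printing both full structures.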

gnecula, Jun 28 '24 12:06