jax
Improve error messages for pmap
(Note: this was originally a request from @patrickvonplaten after the Flax hackathon.)
Suppose you have the following code:
import jax.numpy as jnp
import jax
def test_fn(params):
    return jnp.mean(params)
jax.pmap(test_fn)({'x': jnp.arange(1), 'y': jnp.arange(2)})
This gives the error:
ValueError: pmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
({'x': 1, 'y': 2},)
This doesn't seem like a great error message. For an inexperienced user it's not obvious that the parameter dict has to be replicated across all devices before being passed to pmap, and the error message doesn't mention any mismatch between the number of devices and the input. I think the error message should mention the expected size somewhere.
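A minimal sketch of the replication described above, assuming every device should receive an identical copy of the parameters. Each leaf gets a new leading axis of size `jax.local_device_count()`, which is the axis pmap maps over. Because `jnp.mean` cannot take a dict directly, the sketch swaps `test_fn` for a hypothetical `mean_of_leaves` that averages each leaf separately; `n_devices` and `replicated` are likewise illustrative names.

```python
import jax
import jax.numpy as jnp

def mean_of_leaves(params):
    # Hypothetical stand-in for test_fn: average each leaf of the dict separately.
    return jax.tree_util.tree_map(jnp.mean, params)

n_devices = jax.local_device_count()
params = {'x': jnp.arange(1), 'y': jnp.arange(2)}

# Give every leaf a leading axis of size n_devices, so all mapped axes agree
# with each other and with the number of devices.
replicated = jax.tree_util.tree_map(
    lambda a: jnp.broadcast_to(a, (n_devices,) + a.shape), params
)

# Inside pmap, each device sees the original (unreplicated) leaf shapes.
out = jax.pmap(mean_of_leaves)(replicated)
```

Broadcasting keeps the original leaf shapes intact per device, so the function body does not need to know about the replication.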
Man, you saved my day :) I am that one "inexperienced user".
In my forked code there was the decorator @partial(jax.pmap, axis_name="batch"), which brought me to this problem.
Can you share the solution? Thanks!
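For the decorator form, a sketch of one common pattern, assuming the parameters are shared and only the data batch should be split across devices: passing `in_axes=(None, 0)` tells pmap to broadcast the first argument unchanged and map the second along axis 0, so the parameters need no leading device axis at all. The function `loss_fn` and the keys `'w'` and `'b'` are hypothetical, chosen only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# in_axes=None for params: every device gets the same, unmapped pytree.
# in_axes=0 for batch: axis 0 must equal the number of devices.
@partial(jax.pmap, axis_name="batch", in_axes=(None, 0))
def loss_fn(params, batch):
    pred = batch * params['w'] + params['b']
    local = jnp.mean(pred)
    # Average the per-device values across the "batch" axis.
    return jax.lax.pmean(local, axis_name="batch")

params = {'w': jnp.float32(2.0), 'b': jnp.float32(1.0)}
batch = jnp.ones((n_devices, 4), dtype=jnp.float32)
out = loss_fn(params, batch)
```

Compared with replicating the parameters by hand, this keeps a single copy in user code and leaves the device placement to pmap.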
Hi @marcvanzee
I executed the mentioned code on Colab using JAX version 0.4.23. The error message reads as follows:
ValueError: pmap got inconsistent sizes for array axes to be mapped:
* one axis had size 1: axis 0 of argument params['x'] of type int32[1];
* one axis had size 2: axis 0 of argument params['y'] of type int32[2]
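The newer message reports the mapped axis size of each leaf. A similar check can be run before calling pmap; the sketch below is a hypothetical helper (not part of JAX) that walks a pytree with `jax.tree_util.tree_flatten_with_path` and lists every leaf whose axis 0 differs from an expected size, defaulting to `jax.local_device_count()`.

```python
import jax
import jax.numpy as jnp

def check_pmap_leading_axes(tree, expected=None, name="params"):
    """Hypothetical helper: describe leaves whose axis 0 != expected."""
    if expected is None:
        expected = jax.local_device_count()
    problems = []
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in leaves_with_paths:
        shape = jnp.shape(leaf)
        # Scalars have no axis 0, so they can never be mapped.
        size = shape[0] if shape else None
        if size != expected:
            problems.append(
                f"{name}{jax.tree_util.keystr(path)}: axis 0 has size {size}, "
                f"expected {expected}"
            )
    return problems
```

With the dict from the original report and `expected=2`, only `params['x']` would be flagged.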
Kindly find the gist for reference.
Thank you