
Improve error messages for pmap

Open marcvanzee opened this issue 2 years ago • 3 comments

(Note: this was originally a request from @patrickvonplaten after the Flax hackathon.)

Suppose you have the following code:

import jax.numpy as jnp
import jax

def test_fn(params):
  return jnp.mean(params)

jax.pmap(test_fn)({'x': jnp.arange(1), 'y': jnp.arange(2)})

This gives the error:

ValueError: pmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
({'x': 1, 'y': 2},)

This doesn't seem like a great error message. For an inexperienced user, it's not obvious that the parameter dict has to be replicated across all devices before being passed to pmap, and the error message doesn't mention any mismatch between the number of devices and the input. I think the error message should state the expected size somewhere.

marcvanzee • Mar 17 '22

Man, you saved my day :) I am that one "inexperienced user".

In my forked code, there was a decorator @partial(jax.pmap, axis_name="batch") that led me to this problem.

SolomidHero • Oct 18 '22
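For the decorator form mentioned in the previous comment, a minimal sketch (the function name mean_across_devices and the data are hypothetical): the decorator does not change the input requirement, so every argument still needs a leading axis no larger than the number of local devices. axis_name="batch" is what lets collectives such as jax.lax.pmean refer to that mapped axis.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

@partial(jax.pmap, axis_name="batch")
def mean_across_devices(x):
    # x is this device's slice (the leading axis has been mapped away).
    local_mean = jnp.mean(x)
    # pmean averages local_mean over all devices along the "batch" axis.
    return jax.lax.pmean(local_mean, axis_name="batch")

# Hypothetical data: the leading axis is the mapped axis, so its size must
# not exceed the number of devices. Here it equals n_devices.
batch = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
out = mean_across_devices(batch)
# Every device returns the same global mean, so out has shape (n_devices,).
print(out)
```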

Could you share the solution? Thanks!

meiling-fdu • Nov 07 '23

Hi @marcvanzee

I executed the mentioned code on Colab using JAX version 0.4.23. The error message reads as follows:

ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 1: axis 0 of argument params['x'] of type int32[1];
  * one axis had size 2: axis 0 of argument params['y'] of type int32[2]

Kindly find the gist for reference.

Thank you

rajasekharporeddy • Feb 22 '24
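The newer message from JAX 0.4.23 names each mismatched leaf but still does not state the device count that the original request asked for. To illustrate what such a message could include, here is a hypothetical pre-flight check (check_pmap_args is not a JAX API) that reports both before calling pmap. It assumes pmap's default in_axes=0, so every leaf should share one leading axis size no larger than jax.local_device_count().

```python
import jax
import jax.numpy as jnp

def check_pmap_args(tree, n_devices=None):
    """Hypothetical pre-flight check, not part of JAX.

    Verifies that every leaf of `tree` has the same leading axis size and
    that this size does not exceed the number of local devices. Returns
    that size, or raises ValueError naming each leaf and the device count.
    """
    if n_devices is None:
        n_devices = jax.local_device_count()
    path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    # Map a readable path such as "['x']" to its leading axis size
    # (None for scalars, which have no axis to map over).
    sizes = {
        jax.tree_util.keystr(path): (jnp.shape(leaf)[0] if jnp.ndim(leaf) > 0 else None)
        for path, leaf in path_leaves
    }
    distinct = set(sizes.values())
    if len(distinct) != 1 or None in distinct:
        raise ValueError(
            f"pmap maps over the leading axis of every argument, so all of them "
            f"must share one size; got {sizes}. With {n_devices} local device(s), "
            f"a common choice is to give every leaf a leading axis of size "
            f"{n_devices}, e.g. by replicating it."
        )
    (size,) = distinct
    if size > n_devices:
        raise ValueError(
            f"Leading axis size {size} exceeds the {n_devices} local device(s)."
        )
    return size

# Usage with the input from the original report: this raises and prints a
# message listing both leaves and the local device count.
params = {'x': jnp.arange(1), 'y': jnp.arange(2)}
try:
    check_pmap_args(params)
except ValueError as err:
    print(err)
```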