vmap using in_axes doesn't handle named arguments

Open peterroelants opened this issue 3 years ago • 9 comments

Passing keyword arguments to a function vectorized with jax.vmap(in_axes=..., out_axes=...) does not seem to work and results in an AssertionError.

For example:

import jax
import jax.numpy as jnp

def f(a, b, c):
    return (2*a, 3*b + c)

print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(jnp.array([1, 2]), jnp.array([2, 4]), 0.5))  # works
# (DeviceArray([2, 4], dtype=int32), DeviceArray([ 6.5, 12.5], dtype=float32))

print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5))  # doesn't work

results in:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
    [... skipping hidden 1 frame]

~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
    166   leaves, treedef = tree_flatten(tree, is_leaf)
--> 167   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    168   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/_src/tree_util.py in <listcomp>(.0)
    166   leaves, treedef = tree_flatten(tree, is_leaf)
--> 167   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    168   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Tuple arity mismatch: 0 != 3; tuple: ().

During handling of the above exception, another exception occurred:

AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_218392/621472192.py in <module>
      8 # (DeviceArray([2, 4], dtype=int32), DeviceArray([ 6.5, 12.5], dtype=float32))
      9 
---> 10 print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5))  # doesn't work

    [... skipping hidden 2 frame]

~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/api_util.py in flatten_axes(name, treedef, axis_tree, kws)
    274       # message only to be about the positional arguments
    275       treedef, leaf = treedef_children(treedef)
--> 276       assert treedef_is_leaf(leaf)
    277       axis_tree, _ = axis_tree
    278     raise ValueError(f"{name} specification must be a tree prefix of the "

AssertionError: 

Version used: jax: 0.2.18

peterroelants, Aug 03 '21 08:08

The issue seems to be specifically related to the jax.vmap in_axes and out_axes arguments.

Note that named arguments work when in_axes and out_axes are not used; that case was fixed in https://github.com/google/jax/pull/5387 (see the closed issue https://github.com/google/jax/issues/912). For example, the following works:

import jax
import jax.numpy as jnp

def f(a, b):
    return (2*a, 3*b)

print(jax.vmap(f)(jnp.array([1, 2]), jnp.array([2, 4])))
print(jax.vmap(f)(a=jnp.array([1, 2]), b=jnp.array([2, 4])))

However, as soon as I need to rely on in_axes to avoid a certain parameter being mapped, keyword arguments start raising the AssertionError.

peterroelants, Aug 03 '21 08:08

The bad thing here, IMHO, is not that keyword arguments are not supported. I can see that it would complicate the signature of vmap.

The problem is that it's just an assertion error which is hard to understand.

bayerj, Sep 17 '21 09:09

Has there been any update on this? I can confirm the issue still exists, see below:

Vmapping a function with a keyword argument fails when using in_axes.

Minimal example:

test_fn = lambda x, y, z=None: x
vmapped_fn = jax.vmap(test_fn, in_axes=(0, None), out_axes=0)
vmapped_fn(jnp.ones(shape=(10,1)), 1, z=1)

Results in:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
    [... skipping hidden 1 frame]

IndexError: tuple index out of range

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-33-7a07ef29e97f> in <module>
      1 test_fn = lambda x, y, z=None: x
      2 vmapped_fn = jax.vmap(test_fn, in_axes=(0, None), out_axes=0)
----> 3 vmapped_fn(jnp.ones(shape=(10,1)), 1, z=1)

    [... skipping hidden 5 frame]

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

Yet this works, even though we aren't mapping the named argument:

vmapped_fn(jnp.ones(shape=(10)), 1, z=jnp.ones(10))

_get_axis_size seems to be checking the dimensionality of named arguments.

Tested with jax==0.2.25 and jaxlib==0.1.74.

UmaisZahid, Nov 29 '21 16:11

I was hit by this today. Can we have an explicit error while the behavior is not fixed perhaps?
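
Until the library itself reports this clearly, a user-side guard is one option. The sketch below is not part of JAX: strict_vmap is a hypothetical wrapper name, and it simply refuses keyword arguments with an explicit message instead of letting them reach jax.vmap.

```python
import functools
import jax
import jax.numpy as jnp

def strict_vmap(fun, in_axes=0, out_axes=0):
    """Hypothetical wrapper around jax.vmap that rejects keyword arguments."""
    mapped = jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
    name = getattr(fun, "__name__", "fun")

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        if kwargs:
            raise TypeError(
                f"{name}: keyword arguments {sorted(kwargs)} are not "
                "covered by in_axes; pass them positionally instead."
            )
        return mapped(*args)

    return wrapper

def f(a, b, c):
    return (2 * a, 3 * b + c)

g = strict_vmap(f, in_axes=(0, 0, None))
doubled, shifted = g(jnp.array([1, 2]), jnp.array([2, 4]), 0.5)  # positional: fine

try:
    g(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5)
except TypeError as err:
    message = str(err)
    print(message)
```

This trades flexibility for clarity: keyword calls stop working entirely, but the failure names the offending arguments.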

ricardoV94, Jan 11 '22 14:01

Got the same problem today. I spent half a day trying to understand what was wrong with the code until I found this thread and wrote a helper function that handles the named argument internally.

Would be cool to at least change the message.

che-shr-cat, Jan 19 '22 18:01

Got the same problem today. The empty error message makes it very difficult to understand. By chance I had a snippet that did not use named arguments, so I could compare it against my code; otherwise it would have been very hard to track down.

valentinmace, Jul 04 '22 09:07

Are there any plans to add support for named arguments in vmap when using in_axes and/or out_axes? Right now, one way or another, I have to use a non-ideal (or hacky) solution to work around this limitation:

import jax
import jax.numpy as jnp


def foo(x, w, b):
    return jnp.vdot(x, w) + b


foo_vmap = jax.vmap(foo, (0, None, None), 0)

X = jnp.arange(5 * 3).reshape(3, 5)
parameters = {"w": jnp.arange(5), "b": 1}


# this breaks!!!
foo_vmap(X, **parameters)

# works, but clutters the code especially when many parameters are involved, it
# really defeats the purpose of using dictionaries in the first place
foo_vmap(X, parameters["w"], parameters["b"])

# works, but adds another layer of higher level function transformations, that makes
# it harder to understand the code from first glance (and kinda defeats the functional
# paradigm given that the parameters become implicit inputs)
import functools
foo_vmap_partial = jax.vmap(functools.partial(foo, **parameters), 0, 0)
foo_vmap_partial(X)

# works, but we need to recreate the function signature, which adds bloat
foo_vmap_lambda = lambda x, w, b: foo_vmap(x, w, b)
print(foo_vmap_lambda(X, **parameters))

As a side note, it would be really cool to be able to pass in_axes as a dictionary that maps each keyword argument to the axis to map over (a keyword argument missing from the dictionary would be treated as None):

jax.vmap(foo, {"x": 0}, 0)(X, **parameters)
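
jax.vmap does not accept this form as written, but the idea can be approximated in user code. A minimal sketch, assuming a hypothetical helper named vmap_named: it resolves all arguments to parameter names with inspect.signature, packs them into one dict, and gives jax.vmap a matching dict of axes.

```python
import inspect
import jax
import jax.numpy as jnp

def vmap_named(fun, in_axes, out_axes=0):
    """Hypothetical helper: in_axes maps parameter names to axes (missing -> None)."""
    signature = inspect.signature(fun)

    def wrapper(*args, **kwargs):
        # Resolve positional and keyword arguments to parameter names.
        arguments = dict(signature.bind(*args, **kwargs).arguments)
        axes = {name: in_axes.get(name) for name in arguments}
        # Pass everything as one dict so that in_axes can mirror its structure.
        return jax.vmap(lambda d: fun(**d), in_axes=(axes,), out_axes=out_axes)(arguments)

    return wrapper

def foo(x, w, b):
    return jnp.vdot(x, w) + b

X = jnp.arange(5 * 3).reshape(3, 5)
parameters = {"w": jnp.arange(5), "b": 1}

out = vmap_named(foo, {"x": 0})(X, **parameters)
print(out)
```

The sketch does not handle positional-only parameters or functions that take *args or **kwargs, since every argument is forwarded by name.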

jaymody, Sep 21 '22 19:09

Definitely agree. I sometimes wrap jax.vmap with something like this:

import jax

def named_vmap(f, axes_names, **kwargs):
    # Map axis 0 of every keyword argument named in axes_names; broadcast the rest.
    in_axes = ({k: 0 if k in axes_names else None for k in kwargs},)
    return jax.vmap(lambda input_dict: f(**input_dict), in_axes=in_axes)(kwargs)

This one doesn't work with nested dicts (though it should be straightforward to extend it). Would be nice to have something like this as default behaviour!

mchagneux, Nov 28 '22 10:11

I am joining in the conversation with another minimal example:

import jax
import jax.numpy as jnp


def f(a, *, b):
    return a @ b


a = jnp.ones((2, 4))
b = jnp.ones((4, 2))
print(f(a, b=b).shape)  # (2, 2)

f_batched = jax.vmap(f, in_axes=(None, 0))
b_batched = jnp.ones((10, 4, 2))
print(f_batched(a, b=b_batched).shape)  # AssertionError

Perhaps, in the absence of user control, it would be safe to at least assume in_axes=None for named arguments?
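
For a keyword-only parameter like b here, one workaround is to expose it positionally through a thin lambda so that in_axes lines up with the arguments. This is a sketch of a workaround, not a change in vmap's behaviour:

```python
import jax
import jax.numpy as jnp

def f(a, *, b):
    return a @ b

a = jnp.ones((2, 4))
b_batched = jnp.ones((10, 4, 2))

# Expose b positionally so that in_axes=(None, 0) lines up with (a, b).
f_batched = jax.vmap(lambda a, b: f(a, b=b), in_axes=(None, 0))
out = f_batched(a, b_batched)
print(out.shape)
```

The lambda keeps f's keyword-only signature intact for direct callers while giving vmap a purely positional function to transform.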

epignatelli, Jan 03 '23 23:01