jax
vmap using in_axes doesn't handle named arguments
Passing keyword arguments to a function vectorized with jax.vmap(in_axes=..., out_axes=...) does not seem to work and results in an AssertionError.
For example:
import jax
import jax.numpy as jnp
def f(a, b, c):
    return (2*a, 3*b + c)
print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(jnp.array([1, 2]), jnp.array([2, 4]), 0.5)) # works
# (DeviceArray([2, 4], dtype=int32), DeviceArray([ 6.5, 12.5], dtype=float32))
print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5)) # doesn't work
This results in:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[... skipping hidden 1 frame]
~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
166 leaves, treedef = tree_flatten(tree, is_leaf)
--> 167 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
168 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/_src/tree_util.py in <listcomp>(.0)
166 leaves, treedef = tree_flatten(tree, is_leaf)
--> 167 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
168 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
ValueError: Tuple arity mismatch: 0 != 3; tuple: ().
During handling of the above exception, another exception occurred:
AssertionError Traceback (most recent call last)
/tmp/ipykernel_218392/621472192.py in <module>
8 # (DeviceArray([2, 4], dtype=int32), DeviceArray([ 6.5, 12.5], dtype=float32))
9
---> 10 print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5)) # doesn't work
[... skipping hidden 2 frame]
~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/api_util.py in flatten_axes(name, treedef, axis_tree, kws)
274 # message only to be about the positional arguments
275 treedef, leaf = treedef_children(treedef)
--> 276 assert treedef_is_leaf(leaf)
277 axis_tree, _ = axis_tree
278 raise ValueError(f"{name} specification must be a tree prefix of the "
AssertionError:
Version used: jax: 0.2.18
The issue seems to be specifically related to the in_axes and out_axes arguments of jax.vmap.
Note that named arguments do work when in_axes and out_axes are not needed (and not used); this was fixed in https://github.com/google/jax/pull/5387, and https://github.com/google/jax/issues/912 was closed.
For example, the following works:
import jax
import jax.numpy as jnp
def f(a, b):
    return (2*a, 3*b)
print(jax.vmap(f)(jnp.array([1, 2]), jnp.array([2, 4])))
print(jax.vmap(f)(a=jnp.array([1, 2]), b=jnp.array([2, 4])))
However, as soon as I need to rely on in_axes to keep a certain parameter from being mapped, keyword arguments start raising the AssertionError.
The bad thing here, IMHO, is not that keyword arguments are unsupported; I can see that supporting them would complicate the signature of vmap.
The problem is that the failure surfaces as a bare AssertionError with no message, which is hard to understand.
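Until keyword arguments are handled by in_axes, one possible workaround is to normalize the call before jax.vmap sees it. A minimal sketch, assuming only the standard-library inspect module: inspect.signature(f).bind(...) resolves keyword arguments to their positional slots, so a positional in_axes tuple lines up with them again. The helper name vmap_allow_kwargs is hypothetical (not a JAX API), and it assumes f has no keyword-only, *args, or **kwargs parameters.

```python
import inspect

import jax
import jax.numpy as jnp


def vmap_allow_kwargs(f, in_axes, out_axes=0):
    """Hypothetical helper: vmap `f` with a positional `in_axes` tuple while
    still letting callers pass arguments by keyword.

    Assumes every parameter of `f` is positional-or-keyword."""
    sig = inspect.signature(f)
    vmapped = jax.vmap(f, in_axes=in_axes, out_axes=out_axes)

    def wrapper(*args, **kwargs):
        # Resolve keywords to positional slots so in_axes[i] matches the
        # i-th parameter no matter how the caller passed it.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        return vmapped(*bound.args)

    return wrapper


def f(a, b, c):
    return (2 * a, 3 * b + c)


g = vmap_allow_kwargs(f, in_axes=(0, 0, None))
print(g(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5))
```

Because in_axes is still a positional tuple, it needs one entry per parameter of f, including any parameter that has a default value.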
Has there been any update on this? I can confirm the issue still exists:
Vmapping a function with a keyword argument fails when using in_axes.
Minimal example:
test_fn = lambda x, y, z=None: x
vmapped_fn = jax.vmap(test_fn, in_axes=(0, None), out_axes=0)
vmapped_fn(jnp.ones(shape=(10,1)), 1, z=1)
Results in:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
[... skipping hidden 1 frame]
IndexError: tuple index out of range
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-33-7a07ef29e97f> in <module>
1 test_fn = lambda x, y, z=None: x
2 vmapped_fn = jax.vmap(test_fn, in_axes=(0, None), out_axes=0)
----> 3 vmapped_fn(jnp.ones(shape=(10,1)), 1, z=1)
[... skipping hidden 5 frame]
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
Yet this works, even though we never asked for the named argument to be mapped:
vmapped_fn(jnp.ones(shape=(10)), 1, z=jnp.ones(10))
_get_axis_size seems to be checking the dimensionality of named arguments.
tested with jax==0.2.25 and jaxlib==0.1.74
I was hit by this today. Could we perhaps have an explicit error message until the behavior is fixed?
Got the same problem today. Spent half a day trying to understand what was wrong with the code until I found this thread, then wrote a helper function that handles the named argument internally.
It would be nice to at least change the message.
Got the same problem today. The empty error message makes the problem very difficult to understand; by chance I had a snippet without named arguments that I could compare my code against, otherwise it would have been very hard to track down.
Are there any plans to add support for named arguments to vmap when using in_axes and/or out_axes? Right now, one way or another, I have to use a suboptimal (or hacky) workaround to get around this limitation:
import jax
import jax.numpy as jnp
def foo(x, w, b):
    return jnp.vdot(x, w) + b
foo_vmap = jax.vmap(foo, (0, None, None), 0)
X = jnp.arange(5 * 3).reshape(3, 5)
parameters = {"w": jnp.arange(5), "b": 1}
# this breaks!!!
foo_vmap(X, **parameters)
# works, but clutters the code especially when many parameters are involved, it
# really defeats the purpose of using dictionaries in the first place
foo_vmap(X, parameters["w"], parameters["b"])
# works, but adds another layer of higher level function transformations, that makes
# it harder to understand the code from first glance (and kinda defeats the functional
# paradigm given that the parameters become implicit inputs)
import functools
foo_vmap_partial = jax.vmap(functools.partial(foo, **parameters), (0), 0)
foo_vmap_partial(X)
# works, but we need to recreate the function signature, which adds bloat
foo_vmap_lambda = lambda x, w, b: foo_vmap(x, w, b)
print(foo_vmap_lambda(X, **parameters))
As a side note, it would be really cool to be able to pass in_axes as a dictionary mapping each keyword argument to the axis to map over (a keyword argument missing from the dictionary would be assumed to have an axis of None):
jax.vmap(foo, {"x": 0}, 0)(X, **parameters)
Definitely agree. I sometimes wrap jax.vmap with something like this:
def named_vmap(f, axes_names, **kwargs):
    in_axes = ({k: 0 if k in axes_names else None for k in kwargs.keys()},)
    return jax.vmap(lambda input_dict: f(**input_dict), in_axes=in_axes)(kwargs)
This one doesn't work with nested dicts (though it should be straightforward to extend it). It would be nice to have something like this as the default behaviour!
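The dict-style in_axes proposed in this thread can be approximated in user code. A minimal sketch, assuming only inspect from the standard library: bind the call to parameter names, then pass the resulting dict through jax.vmap as a single pytree argument, with a matching dict of axes as its in_axes prefix. vmap_by_name is a hypothetical name, not a JAX API; it accepts positional and keyword arguments alike, and parameters not listed in axes_by_name are left unmapped (axis None).

```python
import inspect

import jax
import jax.numpy as jnp


def vmap_by_name(f, axes_by_name, out_axes=0):
    """Hypothetical wrapper (not a JAX API): vmap `f` with in_axes given as a
    dict from parameter name to axis; unlisted parameters get axis None.

    Assumes `f` has no positional-only, *args or **kwargs parameters."""
    sig = inspect.signature(f)

    def wrapper(*args, **kwargs):
        # Resolve positional and keyword arguments to parameter names.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        named = dict(bound.arguments)
        # One axis entry per name; this dict is a pytree prefix of `named`.
        axes = {name: axes_by_name.get(name) for name in named}
        return jax.vmap(lambda d: f(**d), in_axes=(axes,), out_axes=out_axes)(named)

    return wrapper


def foo(x, w, b):
    return jnp.vdot(x, w) + b


X = jnp.arange(5 * 3).reshape(3, 5)
parameters = {"w": jnp.arange(5), "b": 1}
out = vmap_by_name(foo, {"x": 0})(X, **parameters)
print(out)  # expected values: 31, 81, 131
```

Since f is called as f(**d), this variant also works for keyword-only parameters, which a purely positional in_axes tuple cannot describe.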
I am joining in the conversation with another minimal example:
import jax
import jax.numpy as jnp
def f(a, *, b):
    return a @ b
a = jnp.ones((2, 4))
b = jnp.ones((4, 2))
print(f(a, b=b).shape)  # (2, 2)
f_batched = jax.vmap(f, in_axes=(None, 0))
b_batched = jnp.ones((10, 4, 2))
print(f_batched(a, b=b_batched).shape)  # AssertionError
Perhaps, in the absence of user control, it would be safe to assume at least in_axes=None for named arguments?
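For comparison, a sketch assuming a more recent JAX release than the ones discussed here: in those versions the jax.vmap docstring states that keyword arguments are always mapped over their leading axis, so in_axes only needs one entry per positional argument. That is the opposite default from the in_axes=None suggested above; under that assumption the keyword-only example would be written with a one-entry in_axes.

```python
import jax
import jax.numpy as jnp


def f(a, *, b):
    return a @ b


a = jnp.ones((2, 4))
b_batched = jnp.ones((10, 4, 2))

# in_axes has one entry per *positional* argument: a stays unmapped.
# The keyword argument b is mapped over its leading axis (size 10).
f_batched = jax.vmap(f, in_axes=(None,))
result = f_batched(a, b=b_batched)
print(result.shape)  # (10, 2, 2)
```

An unmapped keyword argument would still need to be bound beforehand, for example with functools.partial.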