Issue with static fields in `in_axes` pytree structure in `jax.vmap`
Description
Hi,
I am encountering an issue with the `in_axes` parameter when using `jax.vmap`. Specifically, I am passing `in_axes` as a pytree that mirrors the structure of the arguments of the function I am applying `vmap` to. However, due to certain application-specific requirements, the static fields (i.e., data elements that are bypassed by `tree_map`) within the two pytrees do not match. This discrepancy leads to an error indicating a mismatch between the structures of the two pytrees, even though their relevant structures are identical. The error is resolved when I ensure that all static fields are identical across both pytrees.
It appears that the root cause of this issue might be that the `treedefs.node_data()` of `in_axes` and of the corresponding values do not match.
Below is a minimal working example (MWE):
```python
import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class Bar:
    b: jax.typing.ArrayLike
    static_field: object = struct.field(pytree_node=False, default=None)

@struct.dataclass
class Foo:
    a: jax.typing.ArrayLike
    bar: Bar
    static_field: int = struct.field(pytree_node=False, default=None)

val = Foo(jnp.array([1]), Bar(jnp.array([1]), static_field=1), static_field=2)
in_axes = jax.tree_util.tree_map(lambda x: 0, val)  # copies val's static fields
in_axes = in_axes.replace(static_field=99)  # Comment out this line and the error disappears.
jax.vmap(lambda x: x, in_axes=(in_axes,))(val)
```
Results in:
```
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (Foo(a=0, bar=Bar(b=0, static_field=1), static_field=99),) for value tree PyTreeDef((CustomNode(Foo[(2,)], [*, CustomNode(Bar[(1,)], [*])]),)).
```
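The same failure can be reproduced without Flax. The sketch below is a stand-in, not the reporter's code: `Box` is a hypothetical class registered with `jax.tree_util.register_pytree_node_class`, and its `tag` attribute plays the role of a `pytree_node=False` field by being returned as aux data from `tree_flatten`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Box:
    """Hypothetical stand-in for a struct.dataclass with one static field."""

    def __init__(self, value, tag=None):
        self.value = value  # pytree child: visited by tree_map / vmap
        self.tag = tag      # aux data: static, never visited

    def tree_flatten(self):
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        return cls(*children, tag=tag)


val = Box(jnp.arange(3.0), tag=2)

# Spec derived from `val`: it inherits tag=2, so vmap accepts it.
good_axes = tree_util.tree_map(lambda _: 0, val)
out = jax.vmap(lambda x: x, in_axes=(good_axes,))(val)

# Same axis leaf, different static data: no longer a tree prefix of `val`.
bad_axes = Box(0, tag=99)
error_message = None
try:
    jax.vmap(lambda x: x, in_axes=(bad_axes,))(val)
except ValueError as err:
    error_message = str(err)

print(out.tag, out.value.shape)  # 2 (3,)
print(error_message is not None)  # True
```

Only the aux data differs between `good_axes` and `bad_axes`; the single axis leaf (`0`) is the same in both.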
What jax/jaxlib version are you using?
jax v0.4.23, jaxlib v0.4.23+cuda12.cudnn89
Which accelerator(s) are you using?
CPU/GPU
Additional system info?
Ubuntu 20.04
NVIDIA GPU info
```
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3070 ... Off | 00000000:01:00.0 On | N/A |
| N/A 48C P8 24W / 125W | 5178MiB / 8192MiB | 2% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 1510 G /usr/lib/xorg/Xorg 139MiB |
| 0 N/A N/A 2027 G /usr/lib/xorg/Xorg 305MiB |
| 0 N/A N/A 2160 G /usr/bin/gnome-shell 70MiB |
| 0 N/A N/A 5702 G ...rdCache,DocumentPictureInPictureAPI 41MiB |
| 0 N/A N/A 5704 G ...83534335,6158250049560409502,131072 64MiB |
| 0 N/A N/A 7332 G /usr/lib/firefox/firefox 243MiB |
| 0 N/A N/A 22124 C ...s/rex-lib-fAzIlxw_-py3.9/bin/python 138MiB |
| 0 N/A N/A 23265 C ...s/rex-lib-fAzIlxw_-py3.9/bin/python 4136MiB |
```
Hi - this is more-or-less working as expected, because pytree equivalence includes equivalence of static elements.
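To make "pytree equivalence includes static elements" concrete: two treedefs compare equal only if their aux data also compare equal, even when their leaf counts agree. A minimal sketch, using a hypothetical `Box` class registered via `register_pytree_node_class` with `tag` as its aux data:

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Box:
    """Hypothetical pytree: `value` is a child, `tag` is static aux data."""

    def __init__(self, value, tag=None):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        return cls(*children, tag=tag)


val_def = tree_util.tree_structure(Box(jnp.ones(3), tag=2))
spec_def = tree_util.tree_structure(Box(0, tag=99))   # e.g. an in_axes spec
synced_def = tree_util.tree_structure(Box(0, tag=2))  # same tag as the value

same_leaf_count = val_def.num_leaves == spec_def.num_leaves  # both have one leaf
print(same_leaf_count, val_def == spec_def, val_def == synced_def)  # True False True
```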
Can you say more about why it's not possible to specify `in_axes` with a pytree equivalent to the input?
In our application, the construction of the `in_axes` and `val` pytrees occurs at distinct stages due to the nature of the codebase. Initially, when the `in_axes` pytree is assembled, some static data necessary for its completion is not yet available, so we fill these fields with `None` as placeholders. Only at a later stage, when we construct the `val` pytree, do we have access to all the requisite static data, allowing us to populate it accordingly. This sequential process results in a mismatch between the static elements of the `in_axes` and `val` pytrees.
Adding to the complexity, the structure of the `in_axes` pytree does not mirror the `val` pytree exactly. Specifically, the `in_axes` pytree includes `None` leaves in places where the `val` pytree contains pytree nodes. This setup is designed so that each `None` value in `in_axes` applies to all leaves under the corresponding node in the `val` pytree.
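For reference, this is the prefix behavior being relied on: a `None` (or an `int`) at a node of `in_axes` applies to every leaf below the matching node of the argument. A small sketch with plain dicts, whose keys loosely stand in for the reporter's fields (the names are illustrative only):

```python
import jax
import jax.numpy as jnp

# `bar` is a whole subtree in the value but a single None in the spec.
val = {"a": jnp.arange(3.0), "bar": {"b": jnp.ones(2), "c": jnp.zeros(2)}}
in_axes = {"a": 0, "bar": None}  # map over "a"; leave all of "bar" unmapped


def f(tree):
    # Inside vmap, tree["a"] is a scalar; the "bar" arrays keep shape (2,).
    return tree["a"] * 2.0 + tree["bar"]["b"].sum()


out = jax.vmap(f, in_axes=(in_axes,))(val)
print(out)  # [2. 4. 6.]
```

Plain dicts carry no static data beyond their keys, so this prefix spec matches as long as the keys agree.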
I did not anticipate static data being a factor in pytree equivalence, as I expected this to concern only the "non-static part". Moreover, this equivalence is not strictly enforced, since `None` values for leaves where the `val` pytree has a deeper pytree structure are generally allowed. This expectation was partly due to my interpretation of the documentation on the `in_axes` parameter, which does not explicitly state that the static data must be identical.
I don't entirely follow: if `None` can be used in place of full subtrees, how could you possibly expect a generic `in_axes` specification to match a runtime pytree that may have more leaves than the specification?
Overall, my recommendation would be to specify `in_axes` at a point where you actually have the data, and therefore know what the `in_axes` specification should be.
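One way to follow that recommendation when the static data is only known at call time is to defer building the spec until the arguments exist. The sketch below uses a hypothetical `vmap_with_lazy_axes` wrapper and a hypothetical `Box` pytree (with `tag` as aux data); neither is a JAX API. `make_in_axes` copies the static `tag` from the runtime value, so only the leaves and the `None` prefix differ from the argument.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Box:
    """Hypothetical pytree: `value` is a child, `tag` is static aux data."""

    def __init__(self, value, tag=None):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        return cls(*children, tag=tag)


def vmap_with_lazy_axes(fn, make_in_axes):
    """Hypothetical wrapper: derive in_axes from the actual arguments per call."""

    def wrapped(*args):
        return jax.vmap(fn, in_axes=make_in_axes(*args))(*args)

    return wrapped


def make_in_axes(box):
    # Map over "a"; leave the whole "bar" subtree unmapped (a prefix spec).
    # The static tag is read from the runtime value, so the aux data agree.
    return (Box({"a": 0, "bar": None}, tag=box.tag),)


def f(box):
    return box.value["a"] + box.value["bar"]["b"]


g = vmap_with_lazy_axes(f, make_in_axes)
val = Box({"a": jnp.arange(3.0), "bar": {"b": jnp.ones(())}}, tag="known-at-runtime")
out = g(val)
print(out)  # [1. 2. 3.]
```

The placeholder problem disappears because the spec never holds its own copy of the static data; it always borrows the value's.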
Using `None` or any `int` axis specification as leaves in `in_axes` where `val` contains full subtrees appears to work correctly (refer to the minimal working example below). Looking at the internals, it seems this functionality is supported through this line, where the (partially incomplete) `in_axes` pytree is extended with the missing subtrees from `val`. This is also where the static data error comes from. This behavior is also discussed here.
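The expansion step can be illustrated with a simplified sketch. This is not JAX's actual implementation, only a hypothetical `broadcast_prefix` helper built from public `jax.tree_util` calls: `tree_map` flattens the spec (treating `None` as a leaf via `is_leaf`) and then requires the value tree to match that structure up to those leaves, which includes comparing custom-node aux data. That comparison is the step that rejects mismatched static fields.

```python
import jax.numpy as jnp
from jax import tree_util


def broadcast_prefix(prefix, full):
    """Hypothetical helper: repeat each leaf of `prefix` once per leaf of the
    matching subtree of `full`. Raises ValueError if `prefix` is not a tree
    prefix of `full` (including mismatched aux data)."""
    axes = []

    def add(axis, subtree):
        axes.extend([axis] * len(tree_util.tree_leaves(subtree)))

    # is_leaf makes None a leaf of `prefix`, so it can stand for a subtree.
    tree_util.tree_map(add, prefix, full, is_leaf=lambda x: x is None)
    return axes


@tree_util.register_pytree_node_class
class Box:
    """Hypothetical pytree: `value` is a child, `tag` is static aux data."""

    def __init__(self, value, tag=None):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, tag, children):
        return cls(*children, tag=tag)


value = {"a": jnp.ones(3), "bar": {"b": jnp.ones(3), "c": jnp.ones(3)}}
expanded = broadcast_prefix({"a": 0, "bar": None}, value)
print(expanded)  # [0, None, None]

mismatch = None
try:
    broadcast_prefix(Box(0, tag=99), Box(jnp.ones(3), tag=2))
except ValueError as err:
    mismatch = err
print(mismatch is not None)  # True
```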
Initially, it was not apparent that the `in_axes` and `val` pytrees should mirror each other structurally, especially since "prefix" pytrees can define the `in_axes` parameter for whole subtrees. The error message, which lacks any detail about the need for the static data to match, was confusing and led to some time spent troubleshooting. While I still don't really understand why the static data of `in_axes` should exactly match that of `val`, a more informative error message stating this would already be helpful!
In the example below, I set `in_axes.bar` to `None` so that `in_axes` has a `None` value in a place where `val` has a subtree. Internally, this line extends the `None` value to all leaves of the subtree.
```python
import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class Bar:
    a: jax.typing.ArrayLike
    b: jax.typing.ArrayLike
    static_field: object = struct.field(pytree_node=False, default=None)

@struct.dataclass
class Foo:
    a: jax.typing.ArrayLike
    bar: Bar
    static_field: int = struct.field(pytree_node=False, default=None)

val = Foo(jnp.array([1]), Bar(jnp.array([1]), jnp.array([1]), static_field=1), static_field=2)
in_axes = jax.tree_util.tree_map(lambda x: 0, val)
in_axes = in_axes.replace(bar=None)  # prefix pytree: every leaf of `bar` is unmapped
jax.vmap(lambda x: x, in_axes=(in_axes,))(val)
```