
Issue with static fields in `in_axes` pytree structure in `jax.vmap`

Open bheijden opened this issue 1 year ago • 4 comments

Description

Hi,

I am encountering an issue with the in_axes parameter when using jax.vmap. Specifically, I am passing in_axes as a pytree that mirrors the structure of the arguments for the function I am applying vmap to. However, due to certain application-specific requirements, the static fields (i.e., data elements that are bypassed by tree_map) within both pytrees do not match. This discrepancy leads to an error indicating a mismatch in the structures of the two pytrees, despite their relevant structures being identical. The error is resolved when I ensure that all static fields are identical across both pytrees.

It appears that the root cause of this issue might be related to the treedefs.node_data() for in_axes and the corresponding values not matching.

Below is an MWE:

import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class Bar:
    b: jax.typing.ArrayLike
    static_field: object = struct.field(pytree_node=False, default=None)


@struct.dataclass
class Foo:
    a: jax.typing.ArrayLike
    bar: Bar
    static_field: int = struct.field(pytree_node=False, default=None)

val = Foo(jnp.array([1]), Bar(jnp.array([1]), static_field=1), static_field=2)
in_axes = jax.tree_util.tree_map(lambda x: 0, val)
in_axes = in_axes.replace(static_field=99)   # Comment out this line and the error disappears.

jax.vmap(lambda x: x, in_axes=(in_axes,))(val)

Results in:

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (Foo(a=0, bar=Bar(b=0, static_field=1), static_field=99),) for value tree PyTreeDef((CustomNode(Foo[(2,)], [*, CustomNode(Bar[(1,)], [*])]),)).

What jax/jaxlib version are you using?

jax v0.4.23, jaxlib v0.4.23+cuda12.cudnn89

Which accelerator(s) are you using?

CPU/GPU

Additional system info?

Ubuntu 20.04

NVIDIA GPU info

| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3070 ...    Off | 00000000:01:00.0  On |                  N/A |
| N/A   48C    P8              24W / 125W |   5178MiB /  8192MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1510      G   /usr/lib/xorg/Xorg                          139MiB |
|    0   N/A  N/A      2027      G   /usr/lib/xorg/Xorg                          305MiB |
|    0   N/A  N/A      2160      G   /usr/bin/gnome-shell                         70MiB |
|    0   N/A  N/A      5702      G   ...rdCache,DocumentPictureInPictureAPI       41MiB |
|    0   N/A  N/A      5704      G   ...83534335,6158250049560409502,131072       64MiB |
|    0   N/A  N/A      7332      G   /usr/lib/firefox/firefox                    243MiB |
|    0   N/A  N/A     22124      C   ...s/rex-lib-fAzIlxw_-py3.9/bin/python      138MiB |
|    0   N/A  N/A     23265      C   ...s/rex-lib-fAzIlxw_-py3.9/bin/python     4136MiB |

bheijden commented on Feb 09 '24 10:02

Hi - this is more-or-less working as expected, because pytree equivalence includes equivalence of static elements.

Can you say more about why it's not possible to specify in_axes with an equivalent pytree to the input?
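The point about static elements can be checked directly. The following is a minimal sketch that avoids Flax and instead uses a hypothetical `Box` class registered with `jax.tree_util.register_pytree_node_class`; its `tag` attribute plays the role of a static field because it is returned as auxiliary data from `tree_flatten`. The names `Box`, `value`, and `tag` are assumptions made for illustration and do not come from the issue.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Box:
    """Hypothetical container: `value` is a pytree child, `tag` is static aux data."""

    def __init__(self, value, tag=None):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        # Children first, then auxiliary (static) data.
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], tag=aux_data)


val = Box(jnp.arange(3.0), tag=2)
same_tag = Box(0, tag=2)    # same static data as val
other_tag = Box(0, tag=99)  # same leaves, different static data

# Treedef equality takes the auxiliary data into account.
structures_match = tree_util.tree_structure(same_tag) == tree_util.tree_structure(val)
structures_differ = tree_util.tree_structure(other_tag) != tree_util.tree_structure(val)

# With matching static data, the specification is accepted.
out = jax.vmap(lambda x: x.value * 2, in_axes=(same_tag,))(val)

# With mismatched static data, vmap rejects the specification.
try:
    jax.vmap(lambda x: x.value * 2, in_axes=(other_tag,))(val)
    mismatch_raises = False
except ValueError:
    mismatch_raises = True
```

Two specifications with identical leaves compare unequal as soon as their auxiliary data differ, which is the condition the `in_axes` check trips over.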

jakevdp commented on Feb 09 '24 16:02

In our application, the construction of the in_axes and val pytrees occurs at distinct stages due to the nature of the codebase. Initially, when the in_axes pytree is assembled, some static data necessary for its completion is not yet available, leading us to fill these fields with None as placeholders. Only at a later stage, when we construct the val pytree, do we have access to all the requisite static data, allowing us to populate it accordingly. This sequential process results in a mismatch between the static elements of the in_axes and val pytrees.

Adding to the complexity, the structure of the in_axes pytree does not mirror the val pytree exactly. Specifically, the in_axes pytree includes None leaves in places where the val pytree contains pytree nodes. This setup is designed so that the None values in in_axes correspond to all leaves under those nodes in the val pytree.

I did not anticipate that static data would factor into pytree equivalence, as I expected equivalence to concern only the "non-static part". Moreover, this equivalence is not strictly enforced, since None values at leaves where the val pytree has a deeper pytree structure are generally allowed. This expectation was partly due to my interpretation of the documentation on the in_axes parameter, which did not explicitly state the necessity for identical static data.
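One possible workaround for the staged construction described above is to keep only the axis leaves of the early in_axes tree and re-attach them to the structure of val once val exists. The sketch below assumes that the early in_axes tree carries exactly one integer axis per leaf of val, in the same flattening order (so it does not cover None prefixes, which flatten to no leaves). `rebuild_in_axes` and `Box` are hypothetical names for this example, not JAX APIs.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Box:
    """Hypothetical container: `value` is a pytree child, `tag` is static aux data."""

    def __init__(self, value, tag=None):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], tag=aux_data)


def rebuild_in_axes(in_axes, val):
    """Hypothetical helper: put the axis leaves of `in_axes` onto the treedef of `val`.

    The result inherits the static data of `val`, so the two structures match.
    """
    axis_leaves = tree_util.tree_leaves(in_axes)
    val_treedef = tree_util.tree_structure(val)
    if len(axis_leaves) != val_treedef.num_leaves:
        raise ValueError("in_axes must provide exactly one axis per leaf of val")
    return tree_util.tree_unflatten(val_treedef, axis_leaves)


# Built early, before the static data is known (placeholder tag=None).
early_in_axes = Box(0, tag=None)

# Built later, with the real static data filled in.
val = Box(jnp.arange(4.0), tag="known-later")

in_axes = rebuild_in_axes(early_in_axes, val)
out = jax.vmap(lambda x: x.value + 1, in_axes=(in_axes,))(val)
```

The design choice here is to treat val as the single source of truth for structure and static data, and the early tree only as a carrier for axis integers.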

bheijden commented on Feb 12 '24 08:02

I don't entirely follow: if None can be used in place of full subtrees, how could you possibly expect a generic in_axes to match a runtime pytree that may have more leaves than the specification?

Overall, my recommendation would be to specify in_axes at a point where you actually have the data, and therefore know what the in_axes specification should be.
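As a rough sketch of this recommendation, the specification can be derived from val right before the vmap call, so its static data matches by construction. The example assumes every leaf is batched along axis 0; `Box` and `batch_over_all_leaves` are hypothetical names introduced for illustration.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Box:
    """Hypothetical container: `value` is a pytree child, `tag` is static aux data."""

    def __init__(self, value, tag=None):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], tag=aux_data)


def batch_over_all_leaves(val):
    """Hypothetical helper: an in_axes tree with axis 0 for every leaf of `val`.

    tree_map rebuilds the result on val's own treedef, so static data is copied.
    """
    return tree_util.tree_map(lambda _leaf: 0, val)


val = Box({"w": jnp.ones((5, 2)), "b": jnp.zeros(5)}, tag=2)
in_axes = batch_over_all_leaves(val)

# Per example: w has shape (2,), b is a scalar.
out = jax.vmap(lambda x: x.value["w"].sum() + x.value["b"], in_axes=(in_axes,))(val)
```

Because the specification is produced from val itself, there is no separate copy of the static data that could drift out of sync.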

jakevdp commented on Feb 12 '24 17:02

Using None or any int axis specification as leaves in in_axes where val contains full subtrees appears to function correctly (refer to the provided minimal working example). Looking at the internals, it seems this functionality is supported through this line, where the (partially incomplete) in_axes pytree is extended with the missing subtrees from val. This is also where the static data error comes from. This behavior is also discussed here.

Initially, it was not apparent that in_axes and val pytrees should mirror each other structurally, especially since "prefix" pytrees can define the in_axes parameter for whole subtrees. The error message, which lacks details about the necessity for static data alignment, was confusing and led to some time spent troubleshooting. While I still don't really understand why the static data of in_axes should exactly match that of val, a more informative error message that states this would already be helpful!

In the example below I set in_axes.bar to None, so that in_axes has a None value in a place where val has a subtree. Internally, this line extends the None value to all leaves of the subtree.

import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class Bar:
    a: jax.typing.ArrayLike
    b: jax.typing.ArrayLike
    static_field: object = struct.field(pytree_node=False, default=None)


@struct.dataclass
class Foo:
    a: jax.typing.ArrayLike
    bar: Bar
    static_field: int = struct.field(pytree_node=False, default=None)


val = Foo(jnp.array([1]), Bar(jnp.array([1]), jnp.array([1]), static_field=1), static_field=2)
in_axes = jax.tree_util.tree_map(lambda x: 0, val)
in_axes = in_axes.replace(bar=None)   # prefix pytree

jax.vmap(lambda x: x, in_axes=(in_axes,))(val)
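The same prefix behavior can be seen without any static fields. The following sketch uses plain nested dicts (an assumption made to keep the example free of Flax); the None under "bar" applies to every leaf beneath it, so those leaves are passed to the function unbatched.

```python
import jax
import jax.numpy as jnp

val = {
    "a": jnp.arange(3.0),                           # batched along axis 0
    "bar": {"b": jnp.ones(2), "c": jnp.zeros(2)},   # whole subtree left unbatched
}

# Prefix specification: the None at "bar" covers both "b" and "c".
in_axes = {"a": 0, "bar": None}


def f(x):
    # Inside vmap, x["a"] is a scalar and the "bar" leaves keep shape (2,).
    return x["a"] + x["bar"]["b"].sum() + x["bar"]["c"].sum()


out = jax.vmap(f, in_axes=(in_axes,))(val)
```

Here the specification has fewer nodes than val, yet it is accepted because it is a tree prefix of val and dicts carry no static data beyond their keys.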

bheijden commented on Feb 13 '24 08:02