
nnx.Optimizer doesn't respect extra sharding axis added by nnx.scan/nnx.vmap

Open · qGentry opened this issue 1 month ago · 4 comments

Hi folks, me again.

I keep playing around with nnx, and it seems like nnx.Optimizer, when creating optimizer state for models intended to be used with scan, uses the sharding information of the original, non-stacked tensor, without taking into account the extra dimension added by vmap.

Repro script:

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


mesh1 = jax.make_mesh((2, 4), ("a", "b"))
rules1 = (("A", "a"), ("B", "b"))


class Model(nnx.Module):
    def __init__(self, num_layers, rngs: nnx.Rngs):
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_linear(rngs: nnx.Rngs):
            return nnx.Param(
                jnp.ones((16, 16)), 
                sharding=("A", "B"), 
                mesh=mesh1,
                sharding_rules=rules1,
            )
        self.linears = create_linear(rngs=rngs)


@nnx.jit
def init():
    model = Model(num_layers=1, rngs=nnx.Rngs(params=0))
    optimizer = nnx.Optimizer(
        model,
        optax.adam(learning_rate=0.001),
        wrt=nnx.Param,
    )
    return model, optimizer

model, optimizer = init()

Output:

Traceback (most recent call last):
  File "/papyrax/test_scan_axis.py", line 37, in <module>
    model, optimizer = init()
                       ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 474, in __call__
    pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
                                               ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/transforms/compilation.py", line 135, in __call__
    out = self.f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/papyrax/test_scan_axis.py", line 30, in init
    optimizer = nnx.Optimizer(
                ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 400, in __call__
    return _graph_node_meta_call(cls, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 412, in _graph_node_meta_call
    cls._pytree_meta_construct(node, *args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/pytreelib.py", line 403, in _pytree_meta_construct
    self.__init__(*args, **kwargs)
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 88, in _check_wrt_wrapper
    return f(*args, wrt=wrt, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 160, in __init__
    to_opt_state(tx.init(nnx.state(model, wrt)))
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 57, in to_opt_state
    tree = jax.tree.map(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/jax/_src/tree.py", line 155, in map
    return tree_util.tree_map(f, tree, *rest, is_leaf=is_leaf)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/training/optimizer.py", line 52, in _to_opt_state
    opt_state = OptVariable(x.get_value(), **x.get_metadata())  # type: ignore
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 904, in __call__
    return cls._variable_meta_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 907, in _variable_meta_call
    variable = super().__call__(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/nnx/variablelib.py", line 1108, in __init__
    value = core_spmd.shard_value(
            ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 49, in shard_value
    return _apply_sharding(value, NamedSharding(mesh, pspec))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/flax/core/spmd.py", line 37, in _apply_sharding
    return jax.jit(lambda x: x, out_shardings=sharding)(value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs with pytree key path result was given the sharding of NamedSharding(mesh=Mesh('a': 2, 'b': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('a', 'b'), memory_kind=device), which implies that the global size of its dimension 0 should be divisible by 2, but it is equal to 1 (full shape: (1, 16, 16))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

qGentry · Nov 24 '25 16:11

The issue seems to be with the following function in flax:

https://github.com/google/flax/blob/697f4e5cda4b110decf862a6f8ab71a0345d0412/flax/nnx/training/optimizer.py#L49

It tries to apply the sharding of the model's tensor, which still contains the sharding metadata of the original, non-vmapped Variable.

Generally, I think that for vmapped arrays we should automatically add an additional axis to their sharding metadata, shouldn't we? Right now there is a discrepancy between the sharding metadata in VariableMetadata and the actual sharding of the array, which leads to issues like this one.
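
To make the discrepancy concrete, here is a minimal sketch of the failing reshard, outside of nnx (this assumes 8 devices, matching the 2x4 mesh in the repro): the stale metadata ("A", "B") maps to PartitionSpec('a', 'b') via rules1, which is one entry short for the stacked (1, 16, 16) value.

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2, 4), ("a", "b"))  # same 2x4 mesh as in the repro

stacked = jnp.ones((1, 16, 16))  # value after nnx.vmap adds the leading axis

# Spec derived from the stale metadata ("A", "B") -> P("a", "b"):
# dim 0 (size 1) would have to be split over mesh axis "a" (size 2).
stale_spec = P("a", "b")

# Spec the metadata should describe once the new axis is recorded as unsharded:
fixed_spec = P(None, "a", "b")

jax.device_put(stacked, NamedSharding(mesh, fixed_spec))    # works, dim 0 is replicated
# jax.device_put(stacked, NamedSharding(mesh, stale_spec))  # ValueError: size 1 not divisible by 2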

qGentry · Nov 24 '25 16:11

Thanks for stress testing flax in these cases, @qGentry! Let me see what happens inside.

vfdev-5 · Nov 24 '25 16:11

It seems like extending the nnx.vmap arguments with transform_metadata solves this issue:

        @nnx.vmap(
            in_axes=(0,), 
            out_axes=0,
            transform_metadata={
                nnx.spmd.PARTITION_NAME: None,
            }
        )

I wonder if this should happen automatically, or at least be mentioned in the documentation here: https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers
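
For reference, here is the workaround spelled out against the repro above (same imports, mesh1, and rules1 assumed). If I'm reading the transform metadata handling right, nnx.vmap should then record the new leading axis as unsharded in the Param's metadata:

class Model(nnx.Module):
    def __init__(self, num_layers, rngs: nnx.Rngs):
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(
            in_axes=(0,),
            out_axes=0,
            # Mark the axis added by vmap as unsharded in the Variable metadata.
            transform_metadata={nnx.spmd.PARTITION_NAME: None},
        )
        def create_linear(rngs: nnx.Rngs):
            return nnx.Param(
                jnp.ones((16, 16)),
                sharding=("A", "B"),
                mesh=mesh1,
                sharding_rules=rules1,
            )
        self.linears = create_linear(rngs=rngs)

# Expected (under the assumption above): the stacked Param carries sharding
# metadata (None, "A", "B"), matching its (1, 16, 16) value, so nnx.Optimizer's
# eager resharding of the optimizer state no longer fails.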

qGentry · Nov 24 '25 17:11

I was able to reproduce the issue, and I think the problem comes from how sharding metadata is handled when NNX transforms add new axes. I think vmap, scan, and pmap should always insert a default sharding metadata entry ({PARTITION_NAME: None}) so the extra axis they introduce is explicitly marked as unsharded.

Also, the sharding add/remove helper functions should probably be idempotent and tolerant of missing axes, so that sharding_names stays aligned with the actual array rank after a lift. Then, when nnx.Optimizer rewraps parameters and reapplies eager sharding, it would no longer try to shard the new leading axis, which should avoid the dimension-0 divisibility error.
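
For the sake of discussion, here is a rough, hypothetical sketch of what "idempotent and tolerant of missing axes" could mean for the sharding-name bookkeeping (these helpers are illustrative only, not Flax's actual add_axis/remove_axis implementations):

from typing import Optional, Tuple

Names = Tuple[Optional[str], ...]

def add_axis(names: Names, index: int, lifted_ndim: int,
             axis_name: Optional[str] = None) -> Names:
    # No-op if the names already cover the lifted value's rank, so applying
    # the lift (or re-wrapping the Variable) twice cannot add a second entry.
    if len(names) >= lifted_ndim:
        return names
    out = list(names)
    out.insert(index, axis_name)  # default: mark the new axis as unsharded
    return tuple(out)

def remove_axis(names: Names, index: int, unlifted_ndim: int) -> Names:
    # Only strip an entry if the names actually have more entries than the
    # un-lifted value has dimensions; otherwise leave them untouched.
    if len(names) <= unlifted_ndim:
        return names
    out = list(names)
    del out[index]
    return tuple(out)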

@vfdev-5 Any thoughts or suggestions?

mohsinm-dev · Dec 12 '25 02:12

To summarize what's been said so far, the core issue is that when we nnx.vmap over a function that returns a Param, we currently get back a Param that has the same sharding metadata rather than one with an extra None in it. You have to explicitly add the nnx.PARTITION_NAME key in transform_metadata if you want the returned Param to have the proper sharding. This is already documented in the nnx guides (https://flax.readthedocs.io/en/stable/guides/transforms.html#axis-metadata), but it is pretty unintuitive.

One way to get around this is to use JAX's explicit sharding instead:

import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax.sharding import AxisType, PartitionSpec as P

mesh = jax.make_mesh((2, 4), ("a", "b"), axis_types=(AxisType.Explicit, AxisType.Explicit))
jax.set_mesh(mesh)

class Model(nnx.Module):
    def __init__(self, num_layers, rngs: nnx.Rngs):
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_linear(rngs: nnx.Rngs):
            return nnx.Param(
                jnp.ones((16, 16), out_sharding=P("a", "b"))
            )
        self.linears = create_linear(rngs=rngs)

This works just fine with the Optimizer case above.

samanklesaria · Dec 15 '25 18:12

@samanklesaria Good point on the metadata gap. I’ve put together a fix that makes vmap/scan/pmap insert a default {PARTITION_NAME: None} so the sharding metadata tracks the extra axis automatically, and made the axis add/remove helpers idempotent. This avoids having to set transform_metadata manually or use explicit sharding for the Optimizer case. I’ll open a PR; could you take a look?

mohsinm-dev · Dec 15 '25 18:12

@mohsinm-dev that seems reasonable, but I'll have to check with @cgarciae , who will be on break until next week.

samanklesaria · Dec 16 '25 16:12

@samanklesaria, cool. Let me know once you've discussed it with @cgarciae, and then we can work on it.

mohsinm-dev · Dec 16 '25 18:12