
Cannot assign arrays to dataclass fields in `nnx`

Open · frazane opened this issue 1 year ago · 3 comments

When instantiating an nnx.dataclass module, if the input to a param_field (or, in fact, any variable_field) is a jax.Array, a ValueError is raised because the value is assigned to the module without first being wrapped in the nnx.Param class.

```python
import jax
import jax.numpy as jnp
from flax.experimental import nnx

@nnx.dataclass
class Foo(nnx.Module):
    x: jax.Array = nnx.param_field()

foo = Foo(jnp.array(0.2))
```

I would expect the input to be wrapped in nnx.Param before being assigned to the module, which is how it already works for, e.g., integers or floats.

Logs, error messages, etc.:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/users/fzanetta/pyprojects/GPJax/_debug/variables.ipynb Cell 2 line 5
      [1] @nnx.dataclass
      [2] class Foo(nnx.Module):
      [3]     x: jax.Array = nnx.param_field()
----> [5] Foo(jnp.array(0.2))

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:150, in ModuleMeta.__call__(self, *args, **kwargs)
    149 def __call__(self, *args: Any, **kwargs: Any) -> Any:
--> 150   return self._meta_call(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:155, in ModuleMeta._meta_call(cls, *args, **kwargs)
    153 module = cls.__new__(cls, *args, **kwargs)
    154 vars(module)['_module__state'] = ModuleState()
--> 155 module.__init__(*args, **kwargs)
    157 if dataclasses.is_dataclass(module):
    158   if isinstance(module, _HasSetup):

File <string>:3, in __init__(self, x)

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:207, in Module.__setattr__(self, name, value)
    206 def __setattr__(self, name: str, value: Any) -> None:
--> 207   self._setattr(name, value)

File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:232, in Module._setattr(self, name, value)
    230 else:
    231   if isinstance(value, (jax.Array, np.ndarray, State)):
--> 232     raise ValueError(
    233       f"Trying to assign a '{type(value).__name__}' to the Module"
    234       f" attribute '{name}'. This is not supported. Non-hashable "
    235       'objects are not valid static state in JAX. Please wrap '
    236       'the value in a Variable type instead.'
    237     )
    238   vars_dict[name] = value

ValueError: Trying to assign a 'ArrayImpl' to the Module attribute 'x'. This is not supported. Non-hashable objects are not valid static state in JAX. Please wrap the value in a Variable type instead.
```

frazane, Jan 17 '24 13:01

Seems like this is intentional behavior, as there's a line of code that checks whether the input is a jax.Array. Any thoughts @cgarciae?

chiamp, Jan 25 '24 02:01

It's intentional that a jax.Array cannot be assigned directly, but I thought the point of using nnx.param_field was that the jax.Array is first wrapped in nnx.Param before being assigned to the module. The same goes for nnx.variable_field, where the array would be wrapped in the specified Variable type.

If I understand correctly, the arguments are first assigned directly to the module when __init__ runs:

https://github.com/google/flax/blob/3cd34b6b3dca48d197fbd2a9c8de2371b10a3cb2/flax/experimental/nnx/nnx/module.py#L155

and only in a second step (when using dataclasses) are they wrapped in the given variable container:

https://github.com/google/flax/blob/3cd34b6b3dca48d197fbd2a9c8de2371b10a3cb2/flax/experimental/nnx/nnx/module.py#L157-L176

So integers, floats, etc. pass the first step without problems, but an array argument raises an error there before it can be wrapped. If this ordering is intentional, what is the reason for it?
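The ordering problem can be illustrated with a small pure-Python sketch. It does not use Flax: FakeArray, ToyParam, ToyModule, and wrap_fields are hypothetical stand-ins for jax.Array, nnx.Param, nnx.Module, and the post-__init__ wrapping step linked above, assuming the guard behaves like the _setattr check shown in the traceback.

```python
import dataclasses


class FakeArray:
    """Hypothetical stand-in for jax.Array: a raw value the module refuses to store."""

    def __init__(self, value):
        self.value = value


class ToyParam:
    """Hypothetical stand-in for nnx.Param: a container that is allowed to hold arrays."""

    def __init__(self, value):
        self.value = value


class ToyModule:
    def __setattr__(self, name, value):
        # Step 1 guard, modeled on Module._setattr: bare arrays are rejected.
        if isinstance(value, FakeArray):
            raise ValueError(
                f"Trying to assign a '{type(value).__name__}' "
                f"to the Module attribute '{name}'."
            )
        object.__setattr__(self, name, value)


def wrap_fields(module):
    # Step 2, modeled on the post-__init__ pass: wrap every dataclass field.
    for field in dataclasses.fields(module):
        value = getattr(module, field.name)
        if not isinstance(value, ToyParam):
            object.__setattr__(module, field.name, ToyParam(value))
    return module


@dataclasses.dataclass
class Foo(ToyModule):
    x: object


# A float survives step 1 and is wrapped in step 2.
foo = wrap_fields(Foo(0.2))
print(type(foo.x).__name__)  # ToyParam

# An array is rejected in step 1, so step 2 never gets to wrap it.
try:
    wrap_fields(Foo(FakeArray(0.2)))
except ValueError as err:
    print(err)  # Trying to assign a 'FakeArray' to the Module attribute 'x'.
```

Running the wrapping step before the guard, rather than after __init__, would let the array through, which is the behavior the issue expected from param_field.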

frazane, Jan 25 '24 09:01

Since #3720, you should pass the Param directly. nnx.dataclass will be removed soon.
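Passing the container instead of the raw array sidesteps the guard, because the check only rejects bare arrays. Below is a minimal pure-Python sketch of that idea using hypothetical stand-ins (FakeArray for jax.Array, ToyParam for nnx.Param, ToyModule for nnx.Module) rather than the real Flax API. With NNX itself, the analogous call would presumably be Foo(nnx.Param(jnp.array(0.2))); that is an assumption based on the comment above, not something verified against #3720.

```python
class FakeArray:
    """Hypothetical stand-in for jax.Array."""

    def __init__(self, value):
        self.value = value


class ToyParam:
    """Hypothetical stand-in for nnx.Param."""

    def __init__(self, value):
        self.value = value


class ToyModule:
    def __setattr__(self, name, value):
        # Same kind of guard as in the traceback: bare arrays are rejected,
        # anything wrapped in a container is accepted.
        if isinstance(value, FakeArray):
            raise ValueError(f"Cannot assign a bare array to '{name}'; wrap it first.")
        object.__setattr__(self, name, value)


class Foo(ToyModule):
    def __init__(self, x):
        # Plain __init__ instead of a dataclass, since nnx.dataclass is slated for removal.
        self.x = x


foo = Foo(ToyParam(FakeArray(0.2)))  # wrapped by the caller, so the guard passes
print(type(foo.x).__name__, foo.x.value.value)  # ToyParam 0.2
```

The design choice here is that the caller, not the module, decides which Variable type an array belongs to, so the module never has to inspect or convert raw arrays.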

cgarciae, Mar 07 '24 19:03