flax
Cannot assign arrays to dataclass fields in `nnx`
When instantiating an `nnx.dataclass` module, if the input to a `param_field` (or, in fact, any `variable_field`) is a JAX array, a `ValueError` is raised, because the value is assigned to the module without first being wrapped in the `nnx.Param` class.
```python
import jax
import jax.numpy as jnp
from flax.experimental import nnx


@nnx.dataclass
class Foo(nnx.Module):
    x: jax.Array = nnx.param_field()


foo = Foo(jnp.array(0.2))  # raises ValueError (see the traceback below)
```
I would expect the input to be wrapped in `nnx.Param` before being assigned to the module, the same way it works for, e.g., integers or floats.
Logs, error messages, etc.:

```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/users/fzanetta/pyprojects/GPJax/_debug/variables.ipynb Cell 2 line 5
[1] @nnx.dataclass
[2] class Foo(nnx.Module):
[3] x: jax.Array = nnx.param_field()
Foo(jnp.array(0.2))
File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:150, in ModuleMeta.__call__(self, *args, **kwargs)
149 def __call__(self, *args: Any, **kwargs: Any) -> Any:
--> 150 return self._meta_call(*args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:155, in ModuleMeta._meta_call(cls, *args, **kwargs)
153 module = cls.__new__(cls, *args, **kwargs)
154 vars(module)['_module__state'] = ModuleState()
--> 155 module.__init__(*args, **kwargs)
157 if dataclasses.is_dataclass(module):
158 if isinstance(module, _HasSetup):
File <string>:3, in __init__(self, x)
File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:207, in Module.__setattr__(self, name, value)
206 def __setattr__(self, name: str, value: Any) -> None:
--> 207 self._setattr(name, value)
File ~/.cache/pypoetry/virtualenvs/gpjax-giodCE1Q-py3.10/lib/python3.10/site-packages/flax/experimental/nnx/nnx/module.py:232, in Module._setattr(self, name, value)
230 else:
231 if isinstance(value, (jax.Array, np.ndarray, State)):
--> 232 raise ValueError(
233 f"Trying to assign a '{type(value).__name__}' to the Module"
234 f" attribute '{name}'. This is not supported. Non-hashable "
235 'objects are not valid static state in JAX. Please wrap '
236 'the value in a Variable type instead.'
237 )
238 vars_dict[name] = value
ValueError: Trying to assign a 'ArrayImpl' to the Module attribute 'x'. This is not supported. Non-hashable objects are not valid static state in JAX. Please wrap the value in a Variable type instead.
```
This seems to be intentional behavior, since there's a line of code that checks whether the input is a `jax.Array`. Any thoughts, @cgarciae?
It's intentional that a `jax.Array` cannot be assigned directly, but I thought the point of using `nnx.param_field` was that the `jax.Array` is first wrapped in `nnx.Param` before being assigned to the module. The same goes for `nnx.variable_field`, where the array would be wrapped in the specified variable type.
If I understand correctly, the arguments are first assigned directly to the module (https://github.com/google/flax/blob/3cd34b6b3dca48d197fbd2a9c8de2371b10a3cb2/flax/experimental/nnx/nnx/module.py#L155), and only in a second step (when using dataclasses) are they wrapped in the given variable container (https://github.com/google/flax/blob/3cd34b6b3dca48d197fbd2a9c8de2371b10a3cb2/flax/experimental/nnx/nnx/module.py#L157-L176). So integers, floats, etc. cause no problems during the first step, but an array argument raises an error. If this is intentional, I wonder why.
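The two-step construction described above can be reproduced with a small pure-Python toy model. This is a sketch, not nnx's implementation: `FakeArray`, `Param`, `Module`, `param_field`, `construct_assign_then_wrap`, and `construct_wrap_then_assign` are all hypothetical stand-ins, and only the shape of the array check is taken from the traceback. It shows why a float survives step 1 and is wrapped in step 2, while an array is rejected in step 1 before any wrapping can happen. Wrapping before assignment (the behavior expected in this issue) never trips the check.

```python
import dataclasses
from typing import Any


class FakeArray:
    """Hypothetical stand-in for jax.Array (unhashable, like real arrays)."""

    __hash__ = None

    def __init__(self, value: float) -> None:
        self.value = value


class Param:
    """Hypothetical stand-in for nnx.Param: a container around a value."""

    def __init__(self, value: Any) -> None:
        self.value = value


class Module:
    def __setattr__(self, name: str, value: Any) -> None:
        # Toy version of the check in the traceback: raw arrays are rejected.
        if isinstance(value, FakeArray):
            raise ValueError(
                f"Trying to assign a '{type(value).__name__}' to the Module"
                f" attribute '{name}'."
            )
        object.__setattr__(self, name, value)


def param_field() -> Any:
    """Hypothetical field helper: records the container type in metadata."""
    return dataclasses.field(metadata={"variable_type": Param})


def construct_assign_then_wrap(cls, *args):
    """Order described in the thread: assign raw values, then wrap them."""
    module = cls.__new__(cls)
    module.__init__(*args)  # step 1: raw values pass through Module.__setattr__
    for f in dataclasses.fields(module):  # step 2: wrap into the container type
        wrap = f.metadata.get("variable_type")
        if wrap is not None:
            setattr(module, f.name, wrap(getattr(module, f.name)))
    return module


def construct_wrap_then_assign(cls, *args):
    """Order expected in the issue: wrap each value before it is assigned."""
    wrapped = [
        f.metadata["variable_type"](arg) if "variable_type" in f.metadata else arg
        for f, arg in zip(dataclasses.fields(cls), args)
    ]
    return cls(*wrapped)


@dataclasses.dataclass
class Foo(Module):
    x: Any = param_field()


foo = construct_assign_then_wrap(Foo, 0.2)
print(type(foo.x).__name__)  # Param: the float passes step 1, then is wrapped

try:
    construct_assign_then_wrap(Foo, FakeArray(0.2))
except ValueError as err:
    print(err)  # step 1 rejects the array before step 2 can wrap it

foo = construct_wrap_then_assign(Foo, FakeArray(0.2))
print(type(foo.x).__name__)  # Param: wrapping first never trips the check
```

The toy suggests the error comes purely from ordering: the same field definition works or fails depending on whether the container is applied before or after the first `__setattr__`.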
Since #3720, you should pass the `Param` directly. `nnx` dataclasses will be removed soon.
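A minimal sketch of the suggested workaround, using a pure-Python toy model: `FakeArray`, `Param`, `Module`, and this plain-class `Foo` are hypothetical stand-ins rather than the real nnx classes. The point is only that a value which is already wrapped when it reaches `__setattr__` never hits the array check. With the actual library this presumably corresponds to something like `Foo(nnx.Param(jnp.array(0.2)))` on a module whose `__init__` assigns `self.x = x`, but the exact post-#3720 API is an assumption here.

```python
from typing import Any


class FakeArray:
    """Hypothetical stand-in for jax.Array."""

    def __init__(self, value: float) -> None:
        self.value = value


class Param:
    """Hypothetical stand-in for nnx.Param: a container around a value."""

    def __init__(self, value: Any) -> None:
        self.value = value


class Module:
    def __setattr__(self, name: str, value: Any) -> None:
        # Toy version of the check in the traceback: raw arrays are rejected,
        # anything already wrapped in a container is accepted.
        if isinstance(value, FakeArray):
            raise ValueError(
                f"Trying to assign a '{type(value).__name__}' to the Module"
                f" attribute '{name}'."
            )
        object.__setattr__(self, name, value)


class Foo(Module):
    def __init__(self, x: Param) -> None:
        self.x = x  # the caller is responsible for wrapping the array


foo = Foo(Param(FakeArray(0.2)))  # wrapped by the caller: accepted
print(type(foo.x).__name__, foo.x.value.value)  # Param 0.2

try:
    Foo(FakeArray(0.2))  # unwrapped: rejected, as in the original report
except ValueError as err:
    print(err)
```

Moving the wrapping to the call site keeps the module's own constructor trivial, at the cost of making every caller responsible for choosing the variable type.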