Metadata in `nnx.param_field` is not passed to the Param class
When I create a new parameter, I can pass extra keyword arguments as metadata:
```python
import dataclasses
from flax.experimental import nnx

foo = nnx.Param(
    1.0,
    domain="positive_real",
)
print(foo)
```
```
Param(
  value=1.0,
  domain='positive_real'
)
```
I would expect that, when passing a dictionary of the corresponding kwargs to `nnx.param_field`, they would be passed down to the `Param` class, but they are not:
```python
import jax
from flax.experimental import nnx

@nnx.dataclass
class Bar(nnx.Module):
    foo: jax.Array = nnx.param_field(1.0, metadata={"domain": "positive_real"})

bar = Bar()
print(bar.variables)
```
```
VariablesMapping{
  foo: Param(
    value=1.0
  )
}
```
However, they can be found in `dataclasses.fields(bar)[0].metadata["domain"]`.
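For comparison, the same placement can be reproduced with the standard library alone. This is a minimal sketch, assuming `nnx.param_field` forwards its `metadata` argument to `dataclasses.field` (consistent with where the values end up above): the mapping is stored on the `Field` object, while the attribute value itself carries none of it.

```python
import dataclasses

@dataclasses.dataclass
class Bar:
    # Assumption: this mirrors what nnx.param_field does with `metadata`.
    # The mapping lives on the Field object, not on the stored value.
    foo: float = dataclasses.field(default=1.0, metadata={"domain": "positive_real"})

bar = Bar()
print(bar.foo)                                        # the plain value
print(dataclasses.fields(bar)[0].metadata["domain"])  # the metadata, via the Field
```

Note that `Field.metadata` is a read-only mapping, so nothing downstream sees it unless some code explicitly reads it from the field and forwards it.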
I am just wondering whether this is expected or a bug. Either way, I believe it would make sense to support this. I have a quick fix (change the signature of `nnx_variable_constructor` and actually pass the metadata in `ModuleMeta`), and I'd be happy to open a PR if you welcome this change :)
P.S.: Another possibility would be to slightly change the signature of `nnx.param_field` so that `metadata` becomes `**metadata`. This way, one could do something like:
```python
@nnx.dataclass
class Bar(nnx.Module):
    foo: jax.Array = nnx.param_field(1.0, domain="positive_real")
```
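To make the proposed behavior concrete, here is a framework-free sketch that uses no Flax code at all. `Param`, `param_field`, and `wrap_params` are hypothetical stand-ins, not the real `nnx` API: `param_field` takes `**metadata` and stores it on the dataclass field, and `wrap_params` plays the role the constructor logic in `ModuleMeta` would play, wrapping each field value in a `Param` and forwarding the field's metadata as keyword arguments.

```python
import dataclasses
from typing import Any

class Param:
    """Hypothetical stand-in for nnx.Param: a value plus arbitrary metadata."""

    def __init__(self, value: Any, **metadata: Any):
        self.value = value
        self.metadata = dict(metadata)

    def __repr__(self) -> str:
        items = ", ".join(f"{k}={v!r}" for k, v in self.metadata.items())
        return f"Param(value={self.value!r}{', ' if items else ''}{items})"

def param_field(default: Any, **metadata: Any) -> Any:
    """Hypothetical param_field with the proposed **metadata signature."""
    # Keep the metadata on the dataclass Field so it can be forwarded later.
    return dataclasses.field(default=default, metadata=metadata)

def wrap_params(obj: Any) -> None:
    """Hypothetical stand-in for the constructor step: wrap each field
    value in a Param, passing the field's metadata as keyword arguments."""
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, Param(getattr(obj, f.name), **f.metadata))

@dataclasses.dataclass
class Bar:
    foo: Any = param_field(1.0, domain="positive_real")

    def __post_init__(self) -> None:
        wrap_params(self)

bar = Bar()
print(bar.foo)  # Param(value=1.0, domain='positive_real')
```

With this design the metadata stays visible through `dataclasses.fields(...)` as before and also reaches the `Param` instance, which is the behavior the issue asks for.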