
Metadata in `nnx.param_field` is not passed to the Param class

Open · frazane opened this issue 1 year ago · 0 comments

When I create a new parameter, I can pass extra keyword arguments as metadata:

```python
import dataclasses
from flax.experimental import nnx

foo = nnx.Param(
    1.0,
    domain="positive_real",
)

print(foo)
```

```
Param(
  value=1.0,
  domain='positive_real'
)
```

I would expect that, when I pass a dictionary of the corresponding kwargs to `nnx.param_field` via its `metadata` argument, they would be forwarded to the `Param` class, but they are not:

```python
import jax
from flax.experimental import nnx

@nnx.dataclass
class Bar(nnx.Module):
    foo: jax.Array = nnx.param_field(1.0, metadata={"domain": "positive_real"})

bar = Bar()
print(bar.variables)
```

```
VariablesMapping{
  foo: Param(
    value=1.0
  )
}
```

However, they can be found in the dataclass field metadata: `dataclasses.fields(bar)[0].metadata["domain"]`.
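For comparison, this is how field-level metadata behaves with the standard library alone: it is attached to the `dataclasses.Field` object, not to the value stored on the instance. The sketch below uses plain `dataclasses` (no nnx) and only illustrates why the metadata shows up in `dataclasses.fields(...)` rather than on the stored value.

```python
import dataclasses

@dataclasses.dataclass
class Bar:
    # Plain dataclass field: the metadata mapping lives on the Field object.
    foo: float = dataclasses.field(default=1.0, metadata={"domain": "positive_real"})

bar = Bar()
print(bar.foo)                                        # the bare value, no metadata
print(dataclasses.fields(bar)[0].metadata["domain"])  # metadata is found on the field
```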

I am wondering whether this is expected behavior or a bug. Either way, I think it would make sense to support this. I have a quick fix (change the signature of `nnx_variable_constructor` and actually pass the metadata through in `ModuleMeta`), and I'd be happy to open a PR if you welcome this change :)
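The idea behind the fix can be sketched without flax: store the variable constructor next to the user metadata in the field, then, when the module is built, call the constructor with the value *and* the remaining metadata. Everything here (`Param`, `param_field`, `Module`, the `variable_constructor` key) is a hypothetical stand-in written for illustration, not the actual `nnx` internals or the `ModuleMeta` change itself.

```python
import dataclasses
from typing import Any, Optional


class Param:
    """Hypothetical stand-in for nnx.Param: a value plus arbitrary metadata."""

    def __init__(self, value: Any, **metadata: Any):
        self.value = value
        self.metadata = metadata

    def __repr__(self) -> str:
        extras = "".join(f", {k}={v!r}" for k, v in self.metadata.items())
        return f"Param(value={self.value!r}{extras})"


_CTOR_KEY = "variable_constructor"  # hypothetical metadata key


def param_field(default: Any, metadata: Optional[dict] = None) -> Any:
    # Keep the user metadata and the constructor side by side on the field.
    meta = dict(metadata or {})
    meta[_CTOR_KEY] = Param
    return dataclasses.field(default=default, metadata=meta)


class Module:
    def __post_init__(self) -> None:
        # Wrap every param_field value, forwarding its metadata as kwargs.
        for f in dataclasses.fields(self):
            ctor = f.metadata.get(_CTOR_KEY)
            if ctor is None:
                continue
            extra = {k: v for k, v in f.metadata.items() if k != _CTOR_KEY}
            setattr(self, f.name, ctor(getattr(self, f.name), **extra))


@dataclasses.dataclass
class Bar(Module):
    foo: float = param_field(1.0, metadata={"domain": "positive_real"})


bar = Bar()
print(bar.foo)
```

Here the forwarding happens in `__post_init__` only to keep the sketch short; in nnx the equivalent step would live wherever the module machinery turns field defaults into variables.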

P.S.: Another option would be to slightly change the `nnx.param_field` signature so that `metadata` becomes `**metadata`. That way one could write something like:

```python
import jax
from flax.experimental import nnx

@nnx.dataclass
class Bar(nnx.Module):
    foo: jax.Array = nnx.param_field(1.0, domain="positive_real")
```
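A minimal sketch of that `**metadata` signature, again using a hypothetical `param_field` built on plain `dataclasses` rather than the real nnx function. Making `default` positional-only (the `/`) means a metadata key that happens to be named `default` still lands in the metadata instead of clashing with the parameter.

```python
import dataclasses
from typing import Any


def param_field(default: Any, /, **metadata: Any) -> Any:
    # Hypothetical variant: keyword arguments become the field metadata,
    # mirroring how nnx.Param(value, **metadata) accepts them.
    return dataclasses.field(default=default, metadata=metadata)


@dataclasses.dataclass
class Bar:
    foo: float = param_field(1.0, domain="positive_real")


print(dataclasses.fields(Bar)[0].metadata["domain"])
```

The trade-off is that every keyword now means "metadata", so any future option of `param_field` itself would need to be positional-only or risk colliding with user keys.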

frazane, Dec 20 '23 22:12