16 issues by Marcus Chiam

Added `self.param` to `nnx.compat`, and added `nnx.compat` to the API reference. Things to consider:

* to best mirror `self.param` from Linen, we need to call `self.rngs.params()` implicitly in the `self.param` method,...
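To illustrate the consideration above, here is a minimal, hypothetical sketch in plain JAX of a `param` method that draws a fresh key from an RNG stream on each call, so the caller never passes a key. `ParamStream` and `CompatModule` are illustrative names, not part of the `nnx.compat` API, and deriving keys with `jax.random.fold_in` plus a call counter is an assumption.

```python
import jax


class ParamStream:
    """Hypothetical stand-in for an RNG stream such as `rngs.params`."""

    def __init__(self, seed: int):
        self.key = jax.random.PRNGKey(seed)
        self.count = 0

    def __call__(self):
        # Derive a new key from the base key and a per-call counter.
        new_key = jax.random.fold_in(self.key, self.count)
        self.count += 1
        return new_key


class CompatModule:
    """Hypothetical module whose `param` pulls its RNG key implicitly."""

    def __init__(self, seed: int):
        self.params_stream = ParamStream(seed)
        self.variables = {}

    def param(self, name, init_fn, *init_args):
        # Linen-style behavior: initialize once, then return the stored value.
        if name not in self.variables:
            key = self.params_stream()  # implicit key draw, no key argument
            self.variables[name] = init_fn(key, *init_args)
        return self.variables[name]


module = CompatModule(seed=0)
kernel = module.param("kernel", jax.nn.initializers.lecun_normal(), (2, 3))
bias = module.param("bias", jax.nn.initializers.zeros, (3,))
```

In Linen, `self.param` passes a key from the `params` RNG collection to the initializer without the caller supplying one; the trade-off in this sketch is that every first-time `param` call advances the stream's state.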

promote nnx to root-level import

pull ready

Resolves #3842. Removes the `tree_map` deprecation warning. This was originally added because CLU was causing a [deprecation warning](https://github.com/google/flax/actions/runs/8546807944/job/23417866279?pr=3823#step:9:468) at HEAD, but after [fixing it](https://github.com/google/CommonLoopUtils/pull/342) and pushing a new release, this is...

Context:

* As of [JAX 0.4.26](https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-26-april-3-2024), `jax.tree_map` is deprecated.
* #3823 renames all `jax.tree_map` usages to `jax.tree_util.tree_map` in Flax, but we get an [error](https://github.com/google/flax/actions/runs/8562140920/job/23464811629?pr=3797#step:9:467) in CI because of a CLU...
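The rename is mechanical: each call keeps the same arguments and only the module path changes. A minimal sketch, assuming a JAX version where `jax.tree_util.tree_map` is available (it predates the deprecation); the `params` pytree is illustrative.

```python
import jax
import jax.numpy as jnp

# Illustrative nested pytree of parameters.
params = {"dense": {"kernel": jnp.ones((2, 3)), "bias": jnp.zeros((3,))}}

# Deprecated as of JAX 0.4.26: jax.tree_map(lambda x: x * 2, params)
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)

# The same function maps any leaf-wise transformation, e.g. collecting shapes.
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
```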

Priority: P2 - no schedule

Ensure doctest runs on NNX and fix NNX docstrings. Fixed some NNX examples and toy examples, but will [exclude](https://github.com/google/flax/pull/3797/files#diff-e39ea64563ca0b82835282545d8fc6e01b54cbdec1a7ad9f9072e16b45f6abdbR88-R89) testing on those until they are all fixed in a future...

Follow-up from #3735. Partial fix in #3772. Minimum repro:

```
import jax, jax.numpy as jnp
from flax import linen as nn
model = nn.WeightNorm(nn.Dense(3))
x = jnp.ones((1, 2))
key =...
```
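For reference, a minimal plain-JAX sketch of the weight-normalization reparameterization `w = g * v / ||v||`, using the same shapes as the repro (`Dense(3)` applied to inputs of shape `(1, 2)`). It computes one norm per output column; this is the generic technique, not Flax's `nn.WeightNorm` implementation, and `weight_norm` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def weight_norm(v, g):
    """Hypothetical helper: w = g * v / ||v||, one norm per output column.

    v: unnormalized kernel of shape (in_features, out_features)
    g: per-output scale of shape (out_features,)
    """
    # Norm over the input axis, keeping one value per output feature.
    norm = jnp.linalg.norm(v, axis=0, keepdims=True)
    return g * v / norm


v = jax.random.normal(jax.random.PRNGKey(0), (2, 3))
g = jnp.ones((3,))
w = weight_norm(v, g)

x = jnp.ones((1, 2))
y = x @ w  # same output shape as Dense(3) on x
```

With `g` set to ones, every column of `w` has unit L2 norm, so `g` alone controls each output feature's scale.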

Priority: P1 - soon