Results 16 issues of Marcus Chiam

Added Einsum layer. See API docs [here](https://flax--3710.org.readthedocs.build/en/3710/api_reference/flax.linen/layers.html#flax.linen.Einsum).

![norm](https://github.com/google/flax/assets/19753743/17bebc7b-c78c-4288-b101-258ea6ef7dbf) `LayerNorm` is generally understood as normalizing the activations by reducing across all non-batch axes. Currently, Flax's implementation of `LayerNorm` defaults to [`reduction_axes=-1`](https://github.com/google/flax/blob/main/flax/linen/normalization.py#L430). This works for 2D inputs, but for higher...
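To make the distinction concrete, here is a minimal NumPy sketch (not Flax's implementation) that compares reducing over the last axis with reducing over all non-batch axes. The `layer_norm` helper, the `eps` value, and the input shapes are assumptions chosen for illustration, and the learned scale and bias are left out.

```python
import numpy as np


def layer_norm(x, axes, eps=1e-6):
    """Normalize x to zero mean and unit variance over the given axes.

    Hypothetical helper for illustration only; eps is an assumed value.
    """
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


# Assumed input: a batch of 2 examples with two non-batch axes (3, 4).
x = np.arange(24, dtype=np.float64).reshape(2, 3, 4)

# Reduce over the last axis only (analogous to reduction_axes=-1).
last_axis = layer_norm(x, axes=(-1,))

# Reduce over every non-batch axis.
all_non_batch = layer_norm(x, axes=(1, 2))

# For a 2D input there is only one non-batch axis, so both choices agree.
x2d = x.reshape(2, 12)
same_for_2d = np.allclose(layer_norm(x2d, axes=(-1,)), layer_norm(x2d, axes=(1,)))

print(same_for_2d)                            # True
print(np.allclose(last_axis, all_non_batch))  # False
```

With a 3D input the two results differ: the last-axis version standardizes each row of length 4 independently, while the all-non-batch version standardizes each whole `(3, 4)` example.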

In #3617, a user requested that a `kw_only` argument be accepted by the `flax.struct.dataclass` decorator. #3645 allows general `kwargs` to be passed into the `flax.struct.dataclass` decorator, originally with...

pull ready

Python notebooks were deleted from their corresponding guides in #3434 because of discrepancy issues. We should consider the following:
- [ ] add them again and have them be up-to-date...

test segfault

Removed the `with mesh` context manager from the [Flax jit guide](https://flax--3303.org.readthedocs.build/en/3303/guides/flax_on_pjit.html#compile-the-train-step-and-inference), since it has no effect with `jax.jit`.

After the [dict migration](https://github.com/google/flax/pull/3193), Flax now returns regular dicts when calling the `.init`, `.init_with_output` and `.apply` Module methods. However, the representation of regular dicts is not as readable compared to...

I submitted a PR to Flax that rearranges import statements, and it causes a segfault that seems to originate from `sentencepiece`. The PR can be found here: https://github.com/google/flax/pull/3442 Traceback...

A user posted in the Flax discussions about an orbax discrepancy between different zones in GCE. Do different zones have different orbax versions?

# what happened

When I save...

Moved `Module.iter_*` methods to `nnx.graph` (see [`nnx.graph.iter_nodes`](https://flax--4001.org.readthedocs.build/en/4001/api_reference/flax.nnx/graph.html#flax.nnx.iter_nodes) and [`nnx.graph.iter_child_nodes`](https://flax--4001.org.readthedocs.build/en/4001/api_reference/flax.nnx/graph.html#flax.nnx.iter_child_nodes) for more detail). ~Also added a [filter warning](https://github.com/google/flax/pull/4001/files#diff-50c86b7ed8ac2cf95bd48334961bf0530cdc77b5a56f852c5c61b89d735fd711R151-R152) for a deprecation error. See more detail [here](https://github.com/tensorflow/tensorflow/issues/69981).~