Marcus Chiam
Added Einsum layer. See API docs [here](https://flax--3710.org.readthedocs.build/en/3710/api_reference/flax.linen/layers.html#flax.linen.Einsum).
data:image/s3,"s3://crabby-images/1fe62/1fe62cec68c1565e72d6626aae5bbab7d42b2b7c" alt="norm" `LayerNorm` is understood as normalization the activations by reducing across all non-batch axes. Currently Flax's implementation of `LayerNorm`, the default [`reduction_axes=-1`](https://github.com/google/flax/blob/main/flax/linen/normalization.py#L430). This works for 2D inputs, but for higher...
In #3617, a user requested the ability to pass a `kw_only` argument into the `flax.struct.dataclass` decorator. #3645 allows general `kwargs` to be passed into the `flax.struct.dataclass` decorator, originally with...
Python notebooks were deleted from their corresponding guides in #3434 because of discrepancy issues. We should consider the following: - [ ] add them again and have them be up-to-date...
test segfault
Removed `with mesh` context manager from [Flax jit guide](https://flax--3303.org.readthedocs.build/en/3303/guides/flax_on_pjit.html#compile-the-train-step-and-inference), since it doesn't do anything with jax.jit
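A minimal sketch of why the context manager is unnecessary here, assuming a recent JAX with `jax.sharding.NamedSharding` and the `in_shardings`/`out_shardings` arguments of `jax.jit`. Each `NamedSharding` already carries its mesh, so the shardings passed to `jax.jit` are fully specified without an enclosing `with mesh:` block. The `loss_fn`, shapes, and axis name `'data'` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh over all available devices (works with a single CPU device too).
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

# The mesh is bound into each sharding object.
data_sharding = NamedSharding(mesh, P('data', None))  # split the batch axis
replicated = NamedSharding(mesh, P())                  # full copy on every device

def loss_fn(w, x):  # hypothetical loss
  return jnp.mean((x @ w) ** 2)

jitted = jax.jit(
    loss_fn,
    in_shardings=(replicated, data_sharding),
    out_shardings=replicated,
)

# Batch size is a multiple of the device count so the batch axis divides evenly.
batch = 8 * len(jax.devices())
x = jax.device_put(jnp.ones((batch, 4)), data_sharding)
w = jax.device_put(jnp.ones((4, 2)), replicated)

loss = jitted(w, x)  # no `with mesh:` block needed
print(loss)
```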
After the [dict migration](https://github.com/google/flax/pull/3193), Flax now returns regular dicts when calling the `.init`, `.init_with_output` and `.apply` Module methods. However, the representation of regular dicts is not as readable compared to...
I submitted a PR to Flax that rearranges import statements, and it causes a segfault that seems to originate from `sentencepiece`. The PR can be found here: https://github.com/google/flax/pull/3442 Traceback...
A user posted in the Flax discussions about an Orbax discrepancy between different zones in GCE. Do different zones have different Orbax versions? What happened: When I save...
Moved `Module.iter_*` methods to `nnx.graph` (see [`nnx.graph.iter_nodes`](https://flax--4001.org.readthedocs.build/en/4001/api_reference/flax.nnx/graph.html#flax.nnx.iter_nodes) and [`nnx.graph.iter_child_nodes`](https://flax--4001.org.readthedocs.build/en/4001/api_reference/flax.nnx/graph.html#flax.nnx.iter_child_nodes) for more detail). ~Also added a [filter warning](https://github.com/google/flax/pull/4001/files#diff-50c86b7ed8ac2cf95bd48334961bf0530cdc77b5a56f852c5c61b89d735fd711R151-R152) for a deprecation error. See more detail [here](https://github.com/tensorflow/tensorflow/issues/69981).~