32 comments by Marcus Chiam

hi @sourabh2k15, we have a [`Bidirectional` class](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.Bidirectional). Would that work for you?

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Forward and backward RNNs, each a GRU with 5 features
module = nn.Bidirectional(nn.RNN(nn.GRUCell(5)), nn.RNN(nn.GRUCell(5)))
x = jnp.ones((7, 3))  # (time, features)
v = module.init(jax.random.PRNGKey(0), x)
out = module.apply(v, x)
```

hi @Cogitans, I'm trying to add spectral normalization into Flax and am modeling it after the [Haiku version](https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/spectral_norm.py#L66). I had some questions:

- How is this used in a typical...
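For context, both the Haiku module and typical spectral-norm layers estimate the largest singular value of a weight matrix via power iteration. This is not the Flax or Haiku implementation, just a minimal sketch of the underlying idea (the function name and step count are made up for illustration):

```python
import jax
import jax.numpy as jnp

def spectral_norm_estimate(w, n_steps=10, key=jax.random.PRNGKey(0)):
    """Estimate the largest singular value of a 2D matrix via power iteration."""
    u = jax.random.normal(key, (w.shape[0],))
    for _ in range(n_steps):
        # Alternate left/right singular-vector updates, renormalizing each time.
        v = w.T @ u
        v = v / jnp.linalg.norm(v)
        u = w @ v
        u = u / jnp.linalg.norm(u)
    # Rayleigh-quotient-style estimate of the top singular value.
    return u @ w @ v

w = jnp.diag(jnp.array([3.0, 1.0, 0.5]))
sigma = spectral_norm_estimate(w)  # converges toward 3.0, the top singular value
```

In a normalization layer, the weight would then be divided by this estimate before use.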

> Interesting. Why do we specialize 2D LayerNorm? Also, where do we specialize it?

After internal discussion, we have decided to keep the default `reduction_axes=-1`.

Can you run the pre-commit hook to fix formatting?

FYI @zaccharieramzi, I added a `dropout_arg` to `nn.MultiHeadDotProductAttention` in #3384 so you can get the same dropout mask.