Cristian Garcia

210 comments by Cristian Garcia

@jheek points out that this can be fixed by adding `intermediates` to `variable_axes`:

```python
import flax.linen as nn

Cell = nn.scan(
    nn.LSTMCell,
    variable_broadcast="params",
    split_rngs={"params": False},
    variable_axes={"intermediates": 1},
    in_axes=1,
    out_axes=1,
)
```

However, the internal...

@8bitmp3 thanks for the feedback! I made most of the changes, except that I kept "make forward pass" because it sounded better in my head (feel free to push back).

It's a good point. Maybe we need a `RNNCellBase.get_stochastic_mask` API and have cells optionally accept a `stochastic_mask` argument.

Some updates: creating a mask API for RNNs resulted in more complex code. Leveraging RNG collections with transformations like `nn.scan` results in very understandable code and effectively uses Flax's existing...

@Chuxiaof thanks for the feedback, will include your suggestions :) @andsteing Maybe "create submodules inline" or "compact submodules"?

How about:

```python
class MLP(nn.Module):  # dataclass Modules
    out_dims: int

    @nn.compact
    def __call__(self, x):
        x = jnp.reshape(x, (x.shape[0], -1))  # good ol' numpy api
        x = nn.Dense(128)(x)  # create submodules...
```

Should we just remove the comment on `reshape` about the NumPy API?

Hey @banda-larga, thanks for doing this! You are missing the `pre-commit` hooks that sync the Markdown files and Jupyter notebooks, as detailed in the [How to Contribute](https://flax.readthedocs.io/en/latest/advanced_topics/contributing.html) guide. Please run: ```bash...

Thanks @banda-larga! I was wondering if you could fix an additional one? After

> The linear approximation of f at point $x$ reads:

in `jax_for_the_impatient.md`, we are also missing a...
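For context on the passage quoted above: the linear approximation of $f$ at $x$ is $f(x + v) \approx f(x) + \partial f(x)\, v$, and `jax.jvp` returns both pieces, $f(x)$ and $\partial f(x)\, v$, in a single call. A small sketch, with `jnp.sin` as an arbitrary example function and a small perturbation `v` chosen by us:

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x)  # arbitrary example function


x = jnp.array(1.0)
v = jnp.array(1e-3)  # small perturbation

# jax.jvp evaluates f at x and the directional derivative ∂f(x)·v together.
fx, df_v = jax.jvp(f, (x,), (v,))
linear = fx + df_v  # f(x) + ∂f(x) v
exact = f(x + v)
print(float(linear), float(exact), float(jnp.abs(linear - exact)))
```

The gap between `linear` and `exact` is second order in `v`, so it shrinks roughly quadratically as `v` gets smaller.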