Cristian Garcia

Results: 210 comments by Cristian Garcia

Hey @PhilipVinc, thanks for looking into this! Edit: LGTM, definitely much better :)

Hey @aaarrti, `nn.jit` currently only supports positional arguments, which is why you are getting this error. However, after you pass `training` as a positional argument you get another error, which is that...

That said, it's pretty rare to use `nn.jit`. Usually you just apply `jax.jit` to the `train_step`, as shown in the [Quick Start](https://flax.readthedocs.io/en/latest/getting_started.html).

Hey @epignatelli, thanks for your interest! Historically, Flax has tried to reduce the use of combinators; even `Sequential` was added very late. Not sure this proposed direction aligns well with...

Hey @hilanzy, the cell's call signature is `(carry, inputs)`, so it's compatible with `nn.scan` if you need to roll out the cell yourself; meanwhile (as you point out) `RNN`'s input carry is...

Hey! I ran this code locally and it works:

```python
import jax
import numpy as np
import optax
import orbax.checkpoint
from jax import numpy as jnp
from jax import random
...
```

I remember that you previously had to monkey-patch asyncio for orbax to work in Jupyter/Colab; maybe something similar is happening here? As a note, this code runs fine in Colab.

You should post an issue on the [orbax](https://github.com/google/orbax) repo.

@kvablack, that's a very good point. I don't know what the original reason for this design was; I'll ask the team and get back to you.

Since #3720, you should pass the `Param` directly. `nnx.dataclasses` will be removed soon.