equinox
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Hi, I think the use of running_mean and running_var at training time in BatchNorm causes training instability and increased learning_rate sensitivity. With a low momentum (say 0.5) the layer works fine...
Hey there! @patrick-kidger Thank you for providing this amazing library! Highly appreciated. When comparing the speed of `eqx.nn.Conv2d` and `torch.nn.Conv2d` I was surprised to find that the jitted version...
Hello Patrick, Thanks a lot for all the work you are putting in here. I have been working on a problem that I can solve with a small-scale neural network...
```python
key = jax.random.PRNGKey(0)
# (B, C, H*W)
inputs = jnp.ones((3, 16, 64))
linear_layer = jax.vmap(eqx.nn.Linear(16, 16, use_bias=False, key=key))
outputs = linear_layer(inputs)
print(outputs.shape)
```
The above-mentioned code...
Added RoPE embeddings from [the RoFormer paper](https://arxiv.org/pdf/2104.09864.pdf). I need to add this to my transformer to perform some tests first before I can mark this as ready. Also if it's...
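For reference, the position-dependent rotation described in the RoFormer paper can be sketched in plain JAX as follows. This is a standalone sketch, not the code in this PR; the function name and the half-split pairing convention are my own assumptions:

```python
import jax.numpy as jnp

def rotary_embedding(x, base=10000.0):
    """Apply rotary position embeddings to x of shape (seq_len, dim).

    Sketch only: pairs feature i with feature i + dim//2, and assumes
    dim is even. `base` follows the RoFormer paper's default.
    """
    seq_len, dim = x.shape
    half = dim // 2
    # One frequency per rotated pair: theta_i = base^(-2i/dim).
    freqs = base ** (-jnp.arange(0, half) * 2.0 / dim)
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]  # (seq_len, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Rotate each (x1, x2) pair by its position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```

At position 0 every angle is zero, so the embedding leaves the first token untouched; attention scores between rotated queries and keys then depend only on relative position.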
I am using the following code snippet (specifically the function **load_torch_weights**), but it uses some equinox methods (**eqx.experimental.set_state** and **eqx.experimental.StateIndex**) which no longer seem to be supported in the latest version of equinox....
- Support for autoregressive attention;
- Includes support for zero-length queries, e.g. when populating the caches for the prompt.
- Causal masking available by passing mask="causal";
- Support for multi-query...
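The causal masking listed above can be illustrated with plain scaled dot-product attention. This is a minimal sketch of the masking logic, not Equinox's actual implementation:

```python
import jax
import jax.numpy as jnp

def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask.

    q, k, v each have shape (seq_len, d).
    """
    seq_len, d = q.shape
    scores = q @ k.T / jnp.sqrt(d)
    # Causal mask: position i may only attend to positions j <= i.
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v
```

Because position 0 attends only to itself, its output equals `v[0]` regardless of the later tokens.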
Hello, I'm a new user of the equinox library, so perhaps my problem has an obvious solution, but I cannot figure out how to solve it properly. Basically I'm trying...
Hello Patrick, thank you again for the nice package. I wanted to ask whether there exists a way to deserialise an Equinox-trained model (in eqx format [json+bytes]) to be used...
Dear All- I have a very simple question. I have two neural networks of type `MLP` and I want to initialize an optimizer via `optax`. When I have one neural network...