
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

159 equinox issues, sorted by recently updated.

As a minimal example, let's say we have something like this (the real situation involves custom modules that contain other modules, hence the flatten). If I look at `flat_model` obviously...

question

I am enjoying Equinox a great deal, and it seems like a huge ergonomic improvement over raw JAX in many cases. Thank you for building it! I am using `Equinox` to...

question

I am really enjoying using `filter_vmap` and `filter_pmap` to very easily transform functions that act on PyTrees; truly a game-changer! Playing around on a Google Cloud TPU, I have...

question

Hello Patrick, currently dataclasses allow fields to be either class or instance variables, for example:

```python
import dataclasses as dc

@dc.dataclass
class Test:
    a: int = 1

    def __init__(self) -> None:
        pass

Test().__dict__ #...
```
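The class-versus-instance behaviour described in this issue can be reproduced with the standard library alone. This is a minimal sketch assuming Python 3.7+; the `Generated` class is a hypothetical comparison case that does not appear in the original issue.

```python
import dataclasses as dc


@dc.dataclass
class Test:
    a: int = 1

    # @dataclass does not overwrite an __init__ defined in the class body,
    # so the generated initializer (which would set self.a) is never created.
    def __init__(self) -> None:
        pass


@dc.dataclass
class Generated:
    # Hypothetical comparison: same field, but with the generated __init__.
    a: int = 1


t = Test()
print(t.__dict__)            # {}  (no instance attribute was set)
print(t.a)                   # 1   (falls back to the class attribute)
print(Test.a)                # 1   (the default is stored on the class)
print(Generated().__dict__)  # {'a': 1}  (generated __init__ sets self.a)
```

With the generated `__init__`, `a` becomes an instance variable; with a hand-written `__init__` that never assigns it, `a` exists only as a class variable.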

- [ ] Have it support classical treeverse, for improved results when using `equinox.internal.scan(..., kind="checkpointed")`.
- [ ] Have it support saving the residuals of each step, rather than just...

feature

Consider the following code:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from functools import partial

class Model(eqx.Module):
    def f(self):
        return

model = Model()
# We get...
```

question

It would be nice if there were a mechanism to state which `eqx.Module` members are dynamic and a filter transformation that treats only elements marked dynamic as dynamic. This would...

question

I am trying to figure out whether it is possible to efficiently use something like the frozen-parameters approach described here https://docs.kidger.site/equinox/examples/frozen_layer/ to set up a training loop such...

question

Hi, this week I started using Equinox (great project!) and am slowly figuring out the best practices. I'm trying to compute a simple L2 norm of the weights, which requires...

question

I was surprised by the runtime of a posterior estimation I was coding. Here is an MWE to reproduce the issue:

```python
import jax.numpy as jnp
import equinox as eqx
...
```

question