equinox
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Hello, thank you for publishing a fantastic library! Is it possible to use named-axis operations (xmap) within __init__ and __call__ of nn.Module and then xmap over all created...
Hello! Are there any examples of doing half or mixed precision? If I'm just using nn.Linear layers, I'm guessing the most straightforward thing to do would be to make a...
I'd like to use equinox for some fairly large-scale training runs, but the state for those models is often too large to fit on a single accelerator, so gathering all...
When checking if ConvTranspose is actually computing the transpose operation, it seems to be failing. I've tried different weight matrix shapes, but I'm uncertain as to why this is failing:...
Hello, Is it possible to define custom jvp/vjp rules for a Module? For example, suppose I have the following module that scales inputs before passing into another module: ```python class...
Hi Patrick, I've incorporated the changes. It's not 100% done yet, but I'd especially like you to double check the error handling regarding the shapes in the MHA file.
In [`eqx.internal.scan`](https://github.com/patrick-kidger/equinox/blob/62304f5c5d1aaf0e45fcc9c27ba6a54f0e358b4e/equinox/internal/_loop/loop.py#L131), is there no way to `unroll` the `scan`? I'm not sure what the constraint on the "checkpointed" `scan` is, since I'm not familiar with the algorithms...
Since [jax 0.4.27](https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-27-may-7-2024), several tests fail with: ``` args = (_ClosureConvert( jaxpr={ lambda ; a:f32[47] b:f32[] c:i32[] d:bool[] e:bool[] f:i32[] g:f32[] h:f32[47] i:bool[...t 0x7fff4c435a90>, _makes_false_steps=False ), Tracedwith)))) kwds = {}...
Sorry to post an issue, but I didn't see a discussion section on the repo. What are your thoughts on https://flax.readthedocs.io/en/latest/experimental/nnx/index.html and its scope relative to equinox? Do you see it as a...
I'm working on a little project that can ease PyTorch model conversion to your own JAX model (shameless advertisement), and after setting all the weights, biases, and states, I...