
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

159 equinox issues

Hey Patrick, I have been using Equinox for one of my projects and up until now it has helped immensely in using JAX effectively and seamlessly. For this project, I...

documentation

Hello! I would like to specify the padding in `Conv` using strings such as `'SAME'` or `'VALID'`. This is compatible with the JAX API, since it is already implemented in `jax.lax.conv_general_dilated`. Furthermore, I also...

feature
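For reference, `jax.lax.conv_general_dilated` already accepts `padding="SAME"` and `padding="VALID"`, as the request notes. The sketch below assumes a recent JAX. `same_padding` is a hypothetical helper and is not part of Equinox or JAX. The sketch compares both modes on a 5x5 input with stride 2. It then checks that an explicit `(lo, hi)` padding, computed with the usual SAME rule, reproduces the string form. That rule sets the output length to `ceil(n / stride)` and puts any odd leftover padding on the high side. A layer could use this kind of conversion to turn strings into the integer padding it already supports.

```python
import jax
import jax.numpy as jnp


def same_padding(size: int, kernel: int, stride: int) -> tuple[int, int]:
    """Hypothetical helper: explicit (lo, hi) padding equivalent to 'SAME'."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    return (total // 2, total - total // 2)  # extra element goes on the high side


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 1, 5, 5))  # NCHW: batch, channels, height, width
w = jax.random.normal(key_w, (1, 1, 3, 3))  # OIHW: out ch, in ch, kernel h, kernel w

same = jax.lax.conv_general_dilated(x, w, window_strides=(2, 2), padding="SAME")
valid = jax.lax.conv_general_dilated(x, w, window_strides=(2, 2), padding="VALID")

# The same convolution with SAME padding spelled out explicitly per spatial dim.
pads = [same_padding(5, 3, 2), same_padding(5, 3, 2)]
explicit = jax.lax.conv_general_dilated(x, w, window_strides=(2, 2), padding=pads)

print(same.shape, valid.shape)
```

With stride 2, SAME keeps `ceil(5 / 2) = 3` outputs per spatial dimension, while VALID only keeps windows that fit entirely inside the input.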

Hi there, I have just translated the Mamba layer from [here](https://github.com/johnma2006/mamba-minimal) to [Equinox](https://github.com/Artur-Galstyan/kira/blob/mamba/kira/model/mamba.py). Would you accept a PR for this? PS: To get the most out of Mamba, we'd need...

Right now, the information about `EQX_ON_ERROR` is added only when the error is caught and re-raised by `eqx.filter_jit`. For compatibility with `jax.jit`, I think we could probably just append...

next
refactor

Does it run any of the public GPT models, or are the data structures fundamentally incompatible?

question

While checking PR #568, I noticed that the "process_heads" part actually shouldn't be part of the RoPE embeddings PR, since it's a separate concern. In theory, you could process...
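As background for the RoPE discussion, here is a minimal sketch of rotary positional embeddings in plain JAX. It assumes the "rotate half" convention, in which feature `i` is paired with feature `i + dim // 2`. Other implementations pair adjacent features instead. `rotary_embedding` is a hypothetical helper and is not the code from PR #568. It deliberately leaves out any per-head processing, which the issue argues belongs in a separate change.

```python
import jax
import jax.numpy as jnp


def rotary_embedding(x: jax.Array, base: float = 10000.0) -> jax.Array:
    """Hypothetical helper: apply RoPE to x of shape (seq_len, dim), dim even."""
    seq_len, dim = x.shape
    half = dim // 2
    # theta_i = base^(-2i / dim) for i = 0, ..., half - 1
    inv_freq = base ** (-jnp.arange(half) / half)
    # Rotation angle for position m and frequency i is m * theta_i.
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Rotate each (x1[i], x2[i]) pair by its position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
q_rot = rotary_embedding(q)
```

Because each pair is rotated, vector norms are preserved and position 0 is left unchanged. The dot product between two rotated positions depends only on their offset, which is the property that makes RoPE a relative positional encoding.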

Hello, I have been reading through and trying to understand the abstract/final design pattern that Equinox recommends: https://docs.kidger.site/equinox/pattern/. There is one part on which I was hoping you could provide further...

question

The following is not allowed in the abstract/final pattern:

```python
import abc
import equinox as eqx

some_function = lambda arg1: ...

class AbstractModule(eqx.Module, strict=True):
    @abc.abstractmethod
    def __init__(arg1, arg2):
        raise NotImplementedError
...
```

refactor

The following code seems to fail:

```python
import numpy as np
from torch import nn
import torch
import jax
import equinox as eqx

print(torch.__version__, eqx.__version__)
random_input = jax.random.normal(jax.random.PRNGKey(1), (10, 90,...
```

So I'm trying to do two things here: 1. Raise an error through the jit boundary using `eqx.error_if`. 2. Print a pytree using the `jax.debug.print` function. I'm looking to do...

feature
question