
Enzyme-JAX issues (112 results)

To mark which ones we consider worth doing, are doing, or need to do (cc @ivanradanov @ftynse): - [x] iota reshape (becomes single iota) ``` %195 = stablehlo.iota dim =...
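Folding a reshape of an iota into a single iota is only valid for some reshapes. The snippet above is truncated, so the exact conditions of the StableHLO pattern aren't shown. As a hedged illustration in JAX (not the Enzyme-JAX rewrite itself), a reshape that keeps the iota dimension intact and only splits other dimensions is still one iota. A reshape that splits the iota dimension itself is not:

```python
import jax.numpy as jnp
from jax import lax

# iota along dim 0 of shape (4, 6): the value at [i, j] is i.
a = lax.broadcasted_iota(jnp.int32, (4, 6), 0)

# Splitting the non-iota dim 6 -> (2, 3) leaves a single iota along dim 0.
folded = lax.broadcasted_iota(jnp.int32, (4, 2, 3), 0)
print(bool(jnp.array_equal(a.reshape(4, 2, 3), folded)))  # True

# Splitting the iota dim itself does not: iota(6).reshape(2, 3)[i, j] == 3*i + j,
# which is not an iota along either dimension.
b = lax.iota(jnp.int32, 6).reshape(2, 3)
print(bool(jnp.array_equal(b, lax.broadcasted_iota(jnp.int32, (2, 3), 0))))  # False
```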

ImportError: Python version mismatch: module was compiled for Python 3.10, but the interpreter version is incompatible: 3.11.3 (main, Apr 19 2023, 18:51:09) [Clang 14.0.6 ].

The pip package installs v0.0.4, which is a problem because the new tests do not work with that older version. Should be a quick fix?

Extending the tests in https://github.com/EnzymeAD/Enzyme-JAX/blob/d9e2ae086aaac98a0e052a6d8e6d2ed163c00e6a/test/test.py#L68 with ``` > x = jax.jacrev(add_one)(jnp.array([1., 2., 3.]), jnp.array([1., 2., 3.])) NotImplementedError: Batching rule for 'enzyme_rev' not implemented ``` ``` > x = jax.jacfwd(add_one)(jnp.array([1., 2.,...
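The error is consistent with how JAX builds Jacobians. `jax.jacrev` vmaps a VJP over output basis vectors, and `jax.jacfwd` vmaps a JVP over input basis vectors. So every primitive in the function, including `enzyme_rev`, needs a batching rule.

Below is a minimal pure-JAX sketch of the reverse-mode case. It assumes a hypothetical two-argument `add_one`; the actual test function may differ. It only reproduces the vmap-over-VJP structure and does not involve Enzyme-JAX.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-JAX stand-in for the test's add_one. The real test wraps
# its function with Enzyme-JAX, which is where the enzyme_rev primitive comes from.
def add_one(x, y):
    return x + 1 + y

def jacrev_via_vmap(f, x, y):
    """Reverse-mode Jacobian of f w.r.t. x: one VJP per output basis vector,
    batched with jax.vmap. The vmap is the step that requires a batching rule
    for every primitive in the backward pass."""
    out, pullback = jax.vjp(lambda x_: f(x_, y), x)
    basis = jnp.eye(out.size, dtype=out.dtype)  # one cotangent per output
    (rows,) = jax.vmap(pullback)(basis)         # pullback returns a 1-tuple
    return rows

x = jnp.array([1., 2., 3.])
y = jnp.array([1., 2., 3.])
print(jacrev_via_vmap(add_one, x, y))  # 3x3 identity matrix
```

For `enzyme_rev`, the fix would presumably be registering a batching rule for the primitive (in JAX, via `jax.interpreters.batching.primitive_batchers`). This sketch does not attempt that.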

- fix rmsnorm to compute what it should
- use complex numbers instead of rotation matrix
- don't append
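Two reference points for the items above, as a sketch under assumptions. The first is the standard RMSNorm definition, which we assume is what "compute what it should" refers to. The second shows that rotating (a, b) pairs with a 2x2 rotation matrix gives the same result as multiplying a + ib by e^{iθ}. We assume this is the rotation meant, as in rotary-style position embeddings. The function names here are illustrative, not the repository's.

```python
import jax
import jax.numpy as jnp

def rmsnorm(x, weight, eps=1e-6):
    # RMSNorm: scale by 1/sqrt(mean(x^2) + eps) over the last axis, then by a
    # learned weight. Unlike LayerNorm there is no mean subtraction or bias.
    ms = jnp.mean(x * x, axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(ms + eps) * weight

def rotate_matrix(pairs, theta):
    # Rotate each (a, b) pair by theta using explicit 2x2 rotation matrices.
    c, s = jnp.cos(theta), jnp.sin(theta)
    rot = jnp.stack([jnp.stack([c, -s], -1), jnp.stack([s, c], -1)], -2)
    return jnp.einsum("...ij,...j->...i", rot, pairs)

def rotate_complex(pairs, theta):
    # Same rotation as a single complex multiply: (a + ib) * e^{i*theta}.
    z = jax.lax.complex(pairs[..., 0], pairs[..., 1])
    z = z * jnp.exp(1j * theta)
    return jnp.stack([z.real, z.imag], -1)

pairs = jnp.array([[1.0, 0.0], [0.5, 2.0]])
theta = jnp.array([0.3, 1.2])
print(jnp.allclose(rotate_matrix(pairs, theta), rotate_complex(pairs, theta)))  # True
```

The complex form avoids materializing a 2x2 matrix per position and replaces the einsum with one elementwise multiply.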

NOTE: Strikethrough ops are deliberately not annotated.

- [ ] StableHLO
- [x] AbsOp
- [x] AddOp
- [ ] ~AfterAllOp~
- [ ] AllGatherOp
- [ ] AllReduceOp
- ...

Labels: good first issue, help wanted

Ran into the following error when trying to use the enzyme_jax export functionality. Note: we found this bug when running the pip-installed version of jax, so v0.0.8. ``` WARNING:...

The JAX profiling docs (https://jax.readthedocs.io/en/latest/profiling.html) lead to https://github.com/jax-ml/jax/blob/0093ba29d898e8f9829e5365485c1f0094f6b0c5/jax/_src/profiler.py#L128, which is implemented in https://github.com/openxla/xla/blob/c77c7c2c5922b83f219a650a24570f91ffe19f15/xla/python/profiler.cc#L170. Connecting APIs to this should probably fix issues for both the Reactant and the egraph side, if you guys have...
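For reference, this is the user-facing side of the JAX profiler that the linked docs describe: a trace context manager plus optional named annotations. This is only a sketch of existing `jax.profiler` calls. It does not show how Reactant or the egraph side would hook into the XLA profiler, and the log directory is an arbitrary temporary path.

```python
import tempfile

import jax
import jax.numpy as jnp

# Arbitrary output directory for this example; the trace files land under it.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

# Everything executed inside this context is recorded by the XLA profiler
# and written under log_dir (viewable with TensorBoard's profile plugin or Perfetto).
with jax.profiler.trace(log_dir):
    x = jnp.arange(8.0)
    # TraceAnnotation adds a named, user-defined range to the trace.
    with jax.profiler.TraceAnnotation("square_sum"):
        result = jnp.sum(x * x).block_until_ready()

print(float(result), log_dir)
```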