Jonathan Brodrick

Results 21 issues of Jonathan Brodrick

WIP to fix #145. Tridiagonal solve uses `jax.lax.linalg.tridiagonal_solve` when on GPU and the `lineax` implementation with `unroll=1` on CPU using JAX primitives. I think my `abstract_eval` is not general enough...

The current implementation of tridiagonal solve in `lineax` uses a [manually implemented version](lineax/_solver/tridiagonal.py) of the Thomas algorithm with `jax.lax.scan` and a hard-coded `unroll` of 32. There is now a [`jax.lax.linalg.tridiagonal_solve`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.linalg.tridiagonal_solve.html)...
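For readers unfamiliar with the Thomas algorithm mentioned above, here is a minimal plain-Python sketch of its forward sweep and back substitution. It is not the `lineax` code: the function name `thomas_solve` and its argument names are hypothetical, and it assumes the convention that `lower[0]` and `upper[-1]` are unused padding entries. A scan-based version would express each of the two loops as a `jax.lax.scan`.

```python
def thomas_solve(lower, diag, upper, rhs):
    """Solve a tridiagonal system A x = rhs with the Thomas algorithm.

    Assumed (hypothetical) layout: all four inputs have length n;
    lower[0] and upper[-1] are padding and never affect the result.
    No pivoting is done, so A should be e.g. diagonally dominant.
    """
    n = len(diag)
    c_prime = [0.0] * n  # modified super-diagonal
    d_prime = [0.0] * n  # modified right-hand side

    # Forward sweep: eliminate the sub-diagonal row by row.
    c_prime[0] = upper[0] / diag[0]
    d_prime[0] = rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - lower[i] * c_prime[i - 1]
        c_prime[i] = upper[i] / denom
        d_prime[i] = (rhs[i] - lower[i] * d_prime[i - 1]) / denom

    # Back substitution: recover x from the last row upwards.
    x = [0.0] * n
    x[-1] = d_prime[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d_prime[i] - c_prime[i] * x[i + 1]
    return x


# Example: A = [[2, 1, 0], [1, 2, 1], [0, 1, 2]], rhs = A @ [1, 2, 3]
solution = thomas_solve(
    lower=[0.0, 1.0, 1.0],
    diag=[2.0, 2.0, 2.0],
    upper=[1.0, 1.0, 0.0],
    rhs=[4.0, 8.0, 8.0],
)
```

Both loops carry a dependency from one row to the next, which is why the algorithm maps onto a sequential scan and why the unroll factor is a tuning choice rather than a correctness one.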

It would be great to use progress meters such as `tqdm` via a callback, as in `diffrax` or `scipy.optimize.minimize`.

feature

I'm only just trying to understand these now, so not sure how to check this is correct. Feel free to edit/suggest changes. My main aim here is to iteratively build...

ErgoExo now takes `args` as an additional argument to `__call__` and `value_and_grad`, enabling another way to "train" as opposed to `modules`. Abstracted `__call__` by adding a `_mlflow_func` abstraction, which could work for...

Currently, runs of the same test in different trials (e.g. `0001`/`0002`) differ only in their metrics in the `pytest-benchmark compare --csv` file. I'm assuming if I use `--sort fullname` I can...

Oftentimes we compute auxiliary variables at each timestep within `vector_field` that we would like to save as part of processing for future reference/sanity checking. We would like to avoid having...

As far as I understand, there are two main ways of inspecting the progress of a `diffeqsolve`:
* `progress_meter`: this is called every timestep with the ability to use jax...

question

We've run into a strange error when using `jnp.interp` and debugging `nan`s with `jit` disabled. I'm aware that these issues can sometimes arise and enabling only one of these flags...

question

I am using an `equinox.Module` to define a function that allows differentiation with respect to its scalar parameters. This function is much more efficiently written in terms of _many_ scalar...

question