Patrick Kidger

Results 107 issues of Patrick Kidger

Dev branch for the next release.

- [x] jaxtyping
- [x] remove `jax.dtypes.canonicalize_dtype`
- [x] sympy2jax
- [x] ruff
- [x] private modules
- [x] `import foo as foo`
- [x] abstract vars (+no more init=False)...

next

There are a few things we could do to tidy up their implementation even further. Most of these are primarily considered in the FIRK case, so it might need some thinking...

refactor

This relies on JAX implementing a way to detect symbolic zero tangents to a `custom_vjp`.

refactor
next

These don't really work properly at the moment. For example

```python
class MySolver(AbstractWrappedSolver):
    ...

diffeqsolve(solver=MySolver(Kvaerno3()), stepsize_controller=PIDController(...), ...)
```

does not work because `MySolver` doesn't have a `nonlinear_solver` attribute for the...

refactor
next

- Simplified `clear_caches`

- Support for autoregressive attention;
  - Includes support for zero-length queries, e.g. when populating the caches for the prompt.
- Causal masking available by passing `mask="causal"`;
- Support for multi-query...
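
As a rough illustration of what a causal mask means when queries sit on top of cached keys, here is a minimal pure-Python sketch. It is not the library's implementation: the `causal_mask` helper and the assumption that the queries are the last `query_len` positions of the key sequence are both hypothetical, chosen only to show that a zero-length query can produce an empty mask rather than an error.

```python
def causal_mask(query_len, key_len):
    """Return a query_len x key_len boolean mask; True means "may attend".

    Assumption (not from the original text): the queries are the last
    `query_len` positions of a sequence whose keys cover 0..key_len-1,
    as happens when earlier positions live in a key/value cache.
    """
    if query_len > key_len:
        raise ValueError("query_len must not exceed key_len")
    # Number of cached positions that come before the first query.
    offset = key_len - query_len
    # Query q sits at absolute position offset + q and may attend to
    # every key at or before that position.
    return [[k <= offset + q for k in range(key_len)] for q in range(query_len)]


# Three new queries on top of one cached position.
mask = causal_mask(3, 4)

# A zero-length query yields an empty mask instead of failing.
empty = causal_mask(0, 4)
```

Expressing the mask in terms of an offset keeps the zero-length case free of special handling: the outer comprehension simply runs zero times.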

Right now the information about `EQX_ON_ERROR` is added when the error is caught and re-raised by `eqx.filter_jit`. But for compatibility with `jax.jit`, I think we could probably just append...
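
The catch-and-re-raise pattern described above can be sketched in plain Python. This is only an illustration of the general mechanism, not Equinox's code: the `with_hint` decorator, the `failing` function, and the `HINT` text are all hypothetical stand-ins.

```python
import functools

# Hypothetical hint text; the wording the library actually uses is not shown here.
HINT = "(hypothetical hint) see the EQX_ON_ERROR environment variable"


def with_hint(fn):
    """Wrap `fn` so that any RuntimeError it raises is re-raised with HINT appended."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except RuntimeError as e:
            # Re-raise with the extra information, keeping the original
            # exception attached as the cause.
            raise RuntimeError(f"{e}\n{HINT}") from e

    return wrapper


@with_hint
def failing():
    # Stand-in for a computation that raises a runtime error.
    raise RuntimeError("something went wrong")
```

Because the hint is attached only in the wrapper, a caller that bypasses the wrapper never sees it, which is the compatibility gap the text above points at.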

next
refactor