=============
Tools for JAX
=============

.. role:: bash(code)
   :language: bash

.. role:: python(code)
   :language: python

This repository implements a variety of tools for the differentiable programming library `JAX <https://github.com/google/jax>`_.

Major components
================

Tjax's major components are:

- A `dataclass <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/dataclasses>`_ decorator :python:`dataclass` that facilitates defining structured JAX objects (so-called "pytrees"), which benefits from:

  - the ability to mark fields as static (not available in :python:`chex.dataclass`),
  - a MyPy plugin (`mypy_plugin <https://github.com/NeilGirdhar/tjax/blob/master/tjax/mypy_plugin.py>`_), and
  - a display method that produces formatted text according to the tree structure.

- A `fixed_point <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/fixed_point>`_ finding library heavily based on `fax <https://github.com/gehring/fax>`_. Our library:

  - supports stochastic iterated functions, and
  - uses dataclasses instead of closures to avoid leaking JAX tracers.

- A `shim <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/gradient>`_ for the gradient transformation library `optax <https://github.com/deepmind/optax>`_ that supports:

  - easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
  - gradient transformation objects that can be passed dynamically to jitted functions, and
  - generic type annotations.
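
To illustrate what marking a field as static means, here is a minimal sketch in plain JAX rather than tjax. It assumes only :python:`jax.tree_util.register_pytree_node`; the :python:`Layer` class, its fields, and :python:`apply` are hypothetical, and tjax's decorator performs the registration itself and provides its own way of marking fields static. A static field lives in the tree definition, so it is never traced and can drive ordinary Python control flow inside :python:`jax.jit`.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Layer:
    """Hypothetical pytree with one dynamic field and one static field."""

    weights: jax.Array  # dynamic: a pytree leaf that is traced under jit
    activation: str  # static: kept in the tree definition, never traced


def _flatten(layer: Layer) -> tuple[tuple[jax.Array], str]:
    # Return (children, auxiliary data); only the children become leaves.
    return (layer.weights,), layer.activation


def _unflatten(activation: str, children) -> Layer:
    (weights,) = children
    return Layer(weights=weights, activation=activation)


jax.tree_util.register_pytree_node(Layer, _flatten, _unflatten)


@jax.jit
def apply(layer: Layer, x: jax.Array) -> jax.Array:
    y = layer.weights @ x
    # Branching on a static field is fine: it is a Python str, not a tracer.
    return jnp.tanh(y) if layer.activation == "tanh" else y


layer = Layer(weights=jnp.eye(2), activation="tanh")
print(jax.tree_util.tree_leaves(layer))  # only the weights array is a leaf
print(apply(layer, jnp.array([1.0, 0.0])))
```

Changing :python:`activation` changes the tree definition, so :python:`jax.jit` recompiles instead of tracing the string; static values therefore need to be hashable.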
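
The core loop of a fixed-point finder can be sketched with :python:`jax.lax.while_loop`. The :python:`fixed_point` function below is hypothetical, not tjax's or fax's API, and it omits both the stochastic iteration and the dataclass-based design of the library described above; it only applies :python:`f` until successive iterates stop changing. For simplicity it assumes scalar or array float32 iterates.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def fixed_point(
    f: Callable[[jax.Array], jax.Array],
    x0,
    *,
    atol: float = 1e-6,
    max_iterations: int = 1000,
) -> tuple[jax.Array, jax.Array]:
    """Iterate x <- f(x) until successive iterates differ by less than atol.

    Returns the final iterate and the number of applications of f.
    """
    # Fix the dtype so that the loop carry has the same type on every iteration.
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def keep_going(state):
        x, x_previous, iteration = state
        not_converged = jnp.max(jnp.abs(x - x_previous)) >= atol
        return jnp.logical_and(not_converged, iteration < max_iterations)

    def step(state):
        x, _, iteration = state
        return f(x), x, iteration + 1

    initial_state = (f(x0), x0, jnp.asarray(1, dtype=jnp.int32))
    x, _, iterations = jax.lax.while_loop(keep_going, step, initial_state)
    return x, iterations


# Solve x = cos(x), whose fixed point is approximately 0.739085.
root, iterations = fixed_point(jnp.cos, 1.0)
```

Because :python:`jax.lax.while_loop` is used instead of a Python loop, the same function also works inside :python:`jax.jit`.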
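
Why it helps to treat a gradient transformation as a pytree can be sketched without optax or tjax. The :python:`SGD` class, :python:`loss`, and :python:`loss_after_step` below are hypothetical: because the learning rate is a pytree leaf, the same object can be passed dynamically into a jitted function, differentiated with respect to its learning rate, and vectorized over several learning rates with :python:`jax.vmap`.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class SGD:
    """Hypothetical learning rule whose parameter is a differentiable pytree leaf."""

    learning_rate: jax.Array

    def update(self, gradient):
        # Scale every leaf of the gradient pytree by -learning_rate.
        return jax.tree_util.tree_map(lambda g: -self.learning_rate * g, gradient)


jax.tree_util.register_pytree_node(
    SGD,
    lambda sgd: ((sgd.learning_rate,), None),  # one leaf, no static data
    lambda _, children: SGD(*children),
)


def loss(params: jax.Array) -> jax.Array:
    return jnp.sum(params**2)


@jax.jit
def loss_after_step(transformation: SGD, params: jax.Array) -> jax.Array:
    """Loss after one update; the transformation is a dynamic jit argument."""
    updates = transformation.update(jax.grad(loss)(params))
    return loss(params + updates)


params = jnp.array([1.0, -2.0])
# Derivative of the post-step loss with respect to the learning rate.
meta_gradient = jax.grad(loss_after_step)(SGD(jnp.asarray(0.1)), params)
# Post-step loss for two learning rates at once.
losses = jax.vmap(loss_after_step, in_axes=(0, None))(SGD(jnp.array([0.1, 0.2])), params)
```

Here the post-step loss is :math:`5(1 - 2\eta)^2`, so its derivative at :math:`\eta = 0.1` is :math:`-16`, and the two vectorized losses are 3.2 and 1.8.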

Minor components
================

Tjax also includes:


- A pretty printer :python:`print_generic` for aggregate and vector types, including dataclasses. (See `display <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/display.py>`_.)

- Versions of :python:`custom_vjp` and :python:`custom_jvp` that support being used on methods. (See `shims <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/shims.py>`_.)

- Tools for working with cotangents. (See `cotangent_tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/cotangent_tools.py>`_.)

- JAX tree registration for `NetworkX <https://networkx.github.io/>`_ graph types. (See `graph <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/graph.py>`_.)

- Leaky integration :python:`leaky_integrate` and Ornstein-Uhlenbeck process iteration :python:`diffused_leaky_integrate`. (See `leaky_integral <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/leaky_integral.py>`_.)

- An improved version of :python:`jax.tree_util.Partial`. (See `partial <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/partial.py>`_.)

- A Matplotlib trajectory plotter :python:`PlottableTrajectory`. (See `plottable_trajectory <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/plottable_trajectory.py>`_.)

- A testing function :python:`assert_tree_allclose` that automatically produces testing code, and a related function :python:`tree_allclose`. (See `testing <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/testing.py>`_.)

- Basic tools like :python:`divide_where`. (See `tools <https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/tools.py>`_.)
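
For intuition about leaky integration, here is a textbook exact discretization that assumes the drive is constant over each step. The hypothetical :python:`leaky_integrate_step` and :python:`ornstein_uhlenbeck_step` are not tjax's :python:`leaky_integrate` and :python:`diffused_leaky_integrate`, whose signatures and parametrization may differ. The first solves :math:`\dot{x} = -\lambda x + u` over a step of length :math:`\Delta t`; the second adds Gaussian noise so that each step is the exact transition of :math:`dx = (-\lambda x + u)\,dt + \sigma\,dW`.

```python
import jax
import jax.numpy as jnp


def leaky_integrate_step(value, drive, decay, dt):
    """Exact step of dx/dt = -decay * x + drive with drive held constant over dt.

    Assumes decay > 0.
    """
    retention = jnp.exp(-decay * dt)  # fraction of the old value that survives
    return value * retention + drive * (1.0 - retention) / decay


def ornstein_uhlenbeck_step(key, value, drive, decay, sigma, dt):
    """Exact step of dx = (-decay * x + drive) dt + sigma dW, assuming decay > 0."""
    retention = jnp.exp(-decay * dt)
    mean = leaky_integrate_step(value, drive, decay, dt)
    # Standard deviation of the noise accumulated over one step of length dt.
    std = sigma * jnp.sqrt((1.0 - retention**2) / (2.0 * decay))
    return mean + std * jax.random.normal(key, jnp.shape(mean))
```

With a constant drive, repeated steps converge to :math:`u / \lambda`; for example, a drive of 3 with decay 2 settles at 1.5.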
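
A tree-level tolerance check in the spirit of :python:`tree_allclose` can be sketched by comparing tree structures and then leaves pairwise. The :python:`pytrees_allclose` function below is a hypothetical re-implementation for illustration; tjax's function may take different arguments, and this sketch does not reproduce :python:`assert_tree_allclose`'s generation of testing code.

```python
import jax
import jax.numpy as jnp


def pytrees_allclose(actual, desired, rtol=1e-5, atol=1e-8):
    """Return True if two pytrees have the same structure and close leaves."""
    if jax.tree_util.tree_structure(actual) != jax.tree_util.tree_structure(desired):
        return False
    return all(
        # Require matching shapes so that broadcasting cannot hide a mismatch.
        jnp.shape(a) == jnp.shape(d) and bool(jnp.allclose(a, d, rtol=rtol, atol=atol))
        for a, d in zip(jax.tree_util.tree_leaves(actual), jax.tree_util.tree_leaves(desired))
    )
```

Comparing structures first means that, for example, a tuple and a list with equal elements are reported as different trees.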
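
A common reason for a helper like :python:`divide_where` is the "double where" pattern: dividing only where the denominator is nonzero without producing NaN gradients. The hypothetical :python:`safe_divide` below is a minimal sketch of that pattern, assuming a fallback value is returned wherever the denominator is zero; tjax's :python:`divide_where` may have different arguments and semantics.

```python
import jax
import jax.numpy as jnp


def safe_divide(numerator, denominator, otherwise=0.0):
    """Return numerator / denominator where denominator != 0, else `otherwise`."""
    nonzero = denominator != 0
    # Replace zero denominators *before* dividing so that neither the forward
    # pass nor the gradient ever evaluates 1 / 0.
    safe_denominator = jnp.where(nonzero, denominator, 1.0)
    return jnp.where(nonzero, numerator / safe_denominator, otherwise)


values = safe_divide(jnp.array([1.0, 2.0]), jnp.array([2.0, 0.0]))
gradient = jax.grad(lambda d: safe_divide(1.0, d).sum())(jnp.array([2.0, 0.0]))
```

A single :python:`jnp.where(nonzero, numerator / denominator, otherwise)` gives the right forward values, but its gradient is NaN at zero denominators, because the discarded branch still has an infinite derivative that gets multiplied by a zero cotangent.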

Also, see the `documentation <https://neilgirdhar.github.io/tjax/tjax/index.html>`_.

Contribution guidelines
=======================

- Conventions: PEP8.

- How to run tests: :bash:`pytest .`

- How to clean the source:

  - :bash:`isort tjax`
  - :bash:`pylint tjax`
  - :bash:`mypy tjax`
  - :bash:`flake8 tjax`
  - :bash: