
=============
Tools for JAX
=============

.. role:: bash(code)
   :language: bash

.. role:: python(code)
   :language: python

This repository implements a variety of tools for the differentiable programming library JAX <>_.

Major components
----------------

Tjax's major components are:

- A dataclass <>_ decorator :python:`dataclass` that facilitates defining structured JAX objects (so-called "pytrees"), which benefits from:

  - the ability to mark fields as static (not available in ``chex.dataclass``),
  - a mypy plugin (see mypy_plugin <>_), and
  - a display method that produces formatted text according to the tree structure.

- A fixed-point finding library, fixed_point <>_, heavily based on fax <>_. Our library:

  - supports stochastic iterated functions, and
  - uses dataclasses instead of closures to avoid leaking JAX tracers.

- A shim <>_ for the gradient transformation library optax <>_ that supports:

  - easy differentiation and vectorization of “gradient transformation” (learning rule) parameters,
  - gradient transformation objects that can be passed dynamically to jitted functions, and
  - generic type annotations.
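
To illustrate what a static field means for a pytree, here is a minimal sketch that uses only core JAX rather than tjax's decorator. The class ``Layer``, its fields, and the helpers ``_flatten``, ``_unflatten``, and ``apply`` are hypothetical. The static field is stored in the tree definition instead of among the leaves, so it remains a plain Python value under ``jax.jit``:

.. code-block:: python

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp


    # Hypothetical example class; tjax's dataclass decorator automates this
    # kind of pytree registration, but this is not its implementation.
    @dataclass(frozen=True)
    class Layer:
        weight: jax.Array  # dynamic field: a pytree leaf, traced by jit/grad/vmap
        activation: str  # static field: stored in the tree definition


    def _flatten(layer: Layer):
        # Return (leaves, auxiliary static data).
        return (layer.weight,), layer.activation


    def _unflatten(activation: str, leaves) -> Layer:
        (weight,) = leaves
        return Layer(weight, activation)


    jax.tree_util.register_pytree_node(Layer, _flatten, _unflatten)


    @jax.jit
    def apply(layer: Layer, x: jax.Array) -> jax.Array:
        y = layer.weight * x
        # Branching on a static field is ordinary Python control flow under jit.
        return jnp.tanh(y) if layer.activation == "tanh" else y


    layer = Layer(jnp.array([1.0, 2.0]), "tanh")
    print(jax.tree_util.tree_leaves(layer))  # only the weight array is a leaf
    print(apply(layer, jnp.array([0.5, 0.5])))

Changing ``activation`` changes the tree definition and therefore triggers a new compilation of ``apply``, whereas changing only ``weight`` reuses the compiled function.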
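
The closure-free style described for the fixed-point library can be sketched with core JAX alone. This is not tjax's ``fixed_point`` API: the class ``SqrtIteration`` and the function ``iterate_to_fixed_point`` are hypothetical, and the sketch omits stochastic iterated functions. The iterated function is a registered dataclass that is passed as an argument and carried through ``jax.lax.while_loop`` as part of the loop state, so its parameter ``a`` is never baked into a closure:

.. code-block:: python

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp


    @dataclass(frozen=True)
    class SqrtIteration:
        """Hypothetical iterated function whose fixed point is sqrt(a)."""

        a: jax.Array  # dynamic parameter, carried as a pytree leaf

        def __call__(self, x: jax.Array) -> jax.Array:
            # Babylonian (Newton) update; at the fixed point, x * x == a.
            return 0.5 * (x + self.a / x)


    jax.tree_util.register_pytree_node(
        SqrtIteration,
        lambda f: ((f.a,), None),
        lambda _, leaves: SqrtIteration(*leaves),
    )


    def iterate_to_fixed_point(f, x0, tol=1e-6, max_iterations=100):
        """Apply f until successive iterates differ by at most tol."""

        def keep_going(state):
            _, x, x_previous, i = state
            return (jnp.max(jnp.abs(x - x_previous)) > tol) & (i < max_iterations)

        def step(state):
            # The iterated function travels in the loop carry as a pytree.
            f, x, _, i = state
            return f, f(x), x, i + 1

        initial = (f, f(x0), x0, jnp.array(0))
        _, x, _, iterations = jax.lax.while_loop(keep_going, step, initial)
        return x, iterations


    solve = jax.jit(iterate_to_fixed_point)
    root, iterations = solve(SqrtIteration(jnp.array(2.0)), jnp.array(1.0))
    print(root, iterations)

Because ``SqrtIteration`` is a pytree, ``jax.vmap(solve, in_axes=(0, None))`` can solve a batch of problems, such as ``SqrtIteration(jnp.array([2.0, 9.0]))``, in a single call.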
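
The benefit of treating learning-rule hyperparameters as dynamic pytree leaves can be sketched with core JAX. The ``SGD`` class and the functions ``loss`` and ``loss_after_step`` below are hypothetical stand-ins, not the API of tjax's shim or of optax. Because ``learning_rate`` is a leaf, ``jax.grad`` can differentiate the loss after one update step with respect to it, and ``jax.vmap`` can sweep over several learning rates at once:

.. code-block:: python

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp


    @dataclass(frozen=True)
    class SGD:
        """Hypothetical gradient transformation with a dynamic hyperparameter."""

        learning_rate: jax.Array  # a pytree leaf, so it can be traced and differentiated

        def update(self, gradient: jax.Array) -> jax.Array:
            # Plain gradient descent: step against the gradient.
            return -self.learning_rate * gradient


    jax.tree_util.register_pytree_node(
        SGD,
        lambda t: ((t.learning_rate,), None),
        lambda _, leaves: SGD(*leaves),
    )


    def loss(w: jax.Array) -> jax.Array:
        return jnp.sum((w - 3.0) ** 2)


    def loss_after_step(transform: SGD, w: jax.Array) -> jax.Array:
        # Take one learning step, then evaluate the loss at the updated weights.
        w_new = w + transform.update(jax.grad(loss)(w))
        return loss(w_new)


    w0 = jnp.zeros(2)
    # jax.grad returns an SGD instance holding d(loss_after_step)/d(learning_rate).
    meta_gradient = jax.grad(loss_after_step)(SGD(jnp.array(0.1)), w0)
    # vmap evaluates several learning rates in one call.
    losses = jax.vmap(loss_after_step, in_axes=(0, None))(SGD(jnp.array([0.1, 0.5])), w0)
    print(meta_gradient.learning_rate, losses)

For the same reason, passing an ``SGD`` instance to a jitted function means that a new learning rate changes only a leaf value and does not force recompilation.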

Minor components
----------------

Tjax also includes:

- A pretty printer, :python:`print_generic`, for aggregate and vector types, including dataclasses. (See display <>_.)

- Versions of :python:`custom_vjp` and :python:`custom_jvp` that can be used on methods. (See shims <>_.)

- Tools for working with cotangents. (See cotangent_tools <>_.)

- JAX tree registration for NetworkX <>_ graph types. (See graph <>_.)

- Leaky integration (:python:`leaky_integrate`) and Ornstein-Uhlenbeck process iteration (:python:`diffused_leaky_integrate`). (See leaky_integral <>_.)

- An improved version of :python:`jax.tree_util.Partial`. (See partial <>_.)

- A Matplotlib trajectory plotter, :python:`PlottableTrajectory`. (See plottable_trajectory <>_.)

- A testing function, :python:`assert_tree_allclose`, that automatically produces testing code, and a related function, :python:`tree_allclose`. (See testing <>_.)

- Basic tools like :python:`divide_where`. (See tools <>_.)
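
A helper like :python:`divide_where` presumably guards a division with a mask. Guarded division has a well-known pitfall in JAX: a single ``jnp.where`` hides the value of ``x / 0`` but not its gradient, which becomes NaN. The hypothetical ``safe_divide`` below (its name and signature are assumptions, not tjax's ``divide_where``) applies the usual double-``where`` pattern:

.. code-block:: python

    import jax
    import jax.numpy as jnp


    def safe_divide(dividend, divisor, where, otherwise=0.0):
        """Hypothetical helper: dividend / divisor where `where` holds, else `otherwise`."""
        # Replace masked-out divisors with 1 before dividing, so the unused
        # branch never computes x / 0 and its gradient cannot become NaN.
        safe_divisor = jnp.where(where, divisor, 1.0)
        return jnp.where(where, dividend / safe_divisor, otherwise)


    x = jnp.array([1.0, 2.0])
    d = jnp.array([2.0, 0.0])
    mask = d != 0.0
    print(safe_divide(x, d, mask))
    # The gradient with respect to the divisor stays finite at the masked entry.
    print(jax.grad(lambda divisor: safe_divide(x, divisor, mask).sum())(d))

Dropping the inner ``jnp.where`` gives the same forward values, but the gradient at the masked entry becomes NaN because the discarded branch still differentiates ``2.0 / 0.0``.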

Also, see the documentation <>_.

Contribution guidelines
-----------------------

- Conventions: PEP 8.

- How to run tests: :bash:`pytest .`

- How to clean the source:

  - :bash:`isort tjax`
  - :bash:`pylint tjax`
  - :bash:`mypy tjax`
  - :bash:`flake8 tjax`