pytreeclass
Visualize, create, and operate on pytrees in the most intuitive way possible.
Installation | Description | Quick Example | Stateful Computations | Benchmarks | Acknowledgements
Installation
pip install pytreeclass
Install development version
pip install git+https://github.com/ASEM000/pytreeclass
Description
pytreeclass is a JAX-compatible class builder to create and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in numpy, dataclasses, and others.
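For context, the sketch below shows the boilerplate needed to make a class a JAX PyTree by hand with jax.tree_util alone. The Linear class and its fields are hypothetical, and this is not how pytreeclass works internally; it only illustrates what a PyTree class builder is meant to automate.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Linear:
    # Hypothetical layer used only for illustration.
    weight: jax.Array
    bias: jax.Array

    def __call__(self, x):
        return x @ self.weight + self.bias

    def tree_flatten(self):
        # The arrays are the leaves; no static auxiliary data is needed.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


layer = Linear(weight=jnp.ones((3, 2)), bias=jnp.zeros(2))

# Tree utilities now see the instance's leaves.
leaves = jax.tree_util.tree_leaves(layer)
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, layer)


@jax.jit
def forward(model, x):
    # The whole instance is a valid argument to a jitted function.
    return model(x)


y = forward(layer, jnp.ones((1, 3)))
```

Once registered, the instance can be passed straight into jitted functions and tree utilities such as jax.tree_util.tree_map; a class builder aims to give the same behavior without the explicit tree_flatten and tree_unflatten methods.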
See the documentation and Common recipes to check whether this library is a good fit for your work. If you find the package useful, consider giving it a star.
Quick Example
Stateful computations
Under jax.jit, JAX requires state to be explicit: for an ordinary class instance, the variables need to be separated from the class and passed explicitly. With TreeClass there is no need to separate the instance variables; instead, the whole instance is passed as the state.
With this pattern, state can be updated functionally under jax.jit.
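A minimal sketch of the "whole instance as state" idea using only JAX, assuming a hand-registered PyTree class (the Counter class and the update function are hypothetical and not part of pytreeclass). The jitted function receives the instance as its state and returns a new instance instead of mutating the old one:

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Counter:
    # Hypothetical stateful object; `calls` is its only state.
    calls: jax.Array

    def tree_flatten(self):
        return (self.calls,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def update(counter: Counter) -> Counter:
    # The instance goes in as the state and a new instance comes out;
    # nothing is mutated in place.
    return dataclasses.replace(counter, calls=counter.calls + 1)


counter = Counter(calls=jnp.array(0))
for _ in range(10):
    counter = update(counter)
```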
In the update function, the increment method mutates the internal state, so the state has to be updated functionally using .at. To do this, call .at[method_name].__call__(*args, **kwargs); this functional call returns the value of the method call together with a new model instance holding the updated state.
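As a rough mental model of a functional method call, the sketch below runs a mutating method on a shallow copy and returns the method's value together with the updated copy, leaving the original instance unchanged. The call_functionally helper and the Counter class are hypothetical and written with plain JAX; this is not pytreeclass's implementation of .at[method_name].

```python
import copy
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Counter:
    # Hypothetical class whose `increment` method mutates `self`.
    calls: jax.Array

    def increment(self):
        self.calls = self.calls + 1
        return self.calls

    def tree_flatten(self):
        return (self.calls,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def call_functionally(tree, method_name, *args, **kwargs):
    # Hypothetical helper: run the method on a shallow copy so the
    # caller's instance is untouched, then return (value, updated copy).
    new_tree = copy.copy(tree)
    value = getattr(new_tree, method_name)(*args, **kwargs)
    return value, new_tree


@jax.jit
def update(counter):
    value, new_counter = call_functionally(counter, "increment")
    return value, new_counter


counter = Counter(calls=jnp.array(0))
value, new_counter = update(counter)
```

A shallow copy is enough here because increment rebinds an attribute rather than mutating a shared container in place; a method that appended to a list stored on the instance would also change the original, so that case would need a deeper copy.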
Benchmarks
Training a simple sequential stack of linear layers, benchmarked against flax and equinox:
| Number of layers | Flax/tc time | Equinox/tc time |
| 10 | 1.427 | 6.671 |
| 100 | 1.1130 | 2.714 |