tree-math
Matrix support
I'd like to add a Matrix class to complement Vector.
A key design question is what this needs to support. In particular: do we need to support multiple axes that correspond to flattened pytrees, or is only a single axis enough?
If we only need to support a single "tree axis", then most Matrix operations can be implemented essentially by calling vmap on a Vector, and the implementation only needs to keep track of whether the "tree axis" on the underlying pytree is at the start or the end. This would suffice for use-cases like implementing L-BFGS or GMRES, which keep track of some fixed number of state vectors in the form of a matrix.
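To sketch what I mean (not an existing API, just an illustration with a hand-rolled tree_stack helper, and assuming Vector's @ operator gives a dot product):

import jax
import jax.numpy as jnp
import tree_math

def tree_stack(trees):
    # Stack identically-structured pytrees along a new leading "tree axis".
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *trees)

# Four "rows", each an ordinary pytree of state (e.g. an L-BFGS history).
rows = [{"a": jnp.arange(2.0) + i, "b": jnp.array(float(i))} for i in range(4)]
stacked = tree_stack(rows)  # every leaf now has a leading axis of size 4

v = tree_math.Vector({"a": jnp.ones(2), "b": jnp.array(2.0)})

# "Matrix @ vector": vmap a per-row dot product over the leading tree axis.
row_dot = lambda row: tree_math.Vector(row) @ v
matvec = jax.vmap(row_dot)(stacked)  # array of shape (4,), one entry per row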
In contrast, multiple "tree axes" would be required to fully support use cases where both the inputs and outputs of a linear map correspond to (possibly different) pytrees. For example, consider the outputs of jax.jacobian on a pytree -> pytree function. Here the implementation would need to be more complex to keep track of the separate tree definitions for inputs/outputs, similar to my first attempt at implementing a tree vectorizing transformation: https://github.com/google/jax/pull/3263.
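For a concrete picture of what that looks like (f here is just a toy function):

import jax
import jax.numpy as jnp

def f(params):
    return {"out": params["w"] * jnp.sum(params["x"])}

params = {"w": jnp.array([1.0, 2.0]), "x": jnp.array([3.0, 4.0])}
jac = jax.jacobian(f)(params)
# jac is {"out": {"w": (2, 2) array, "x": (2, 2) array}}: the output tree
# definition sits outside the input tree definition, so a multi-tree-axis
# Matrix would need to carry both tree definitions around.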
My inclination is to only implement the "single tree-axis" version of Matrix, the reasoning being that it suffices to implement most "efficient" numerical algorithms on large-scale inputs, which cannot afford to use O(n^2) memory. On the other hand, it does preclude the interesting use-case of using tree-math to implement jax.jacobian (and variations).
@Sohl-Dickstein seems to be in the camp of wanting support for multiple tree axes, to compute things like Jacobians, Hessians and covariances/correlations between pytrees.
Weighing in here at @shoyer's request --
I'd put myself in the multiple-tree-axes camp as well, I think. (I am saying that primarily as an end user rather than a developer for the feature, though... !)
This starts to open up notions of "labelled tree axes". For example, performing the pytree-axes equivalent of jax.vmap(jax.vmap(operator.mul, in_axes=(0, None)), in_axes=(None, 0)) (an outer product), in which two different BatchTraces interact.
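For reference, the plain-array version of that outer product looks like this; the open question is what the analogue should be when the 0/None axes refer to flattened pytrees rather than array axes:

import operator
import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([10.0, 20.0])
# Outer vmap maps over b, inner vmap maps over a.
outer = jax.vmap(jax.vmap(operator.mul, in_axes=(0, None)), in_axes=(None, 0))(a, b)
# outer has shape (2, 3), with outer[i, j] == a[j] * b[i]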
For a thorny reference problem in which this kind of stuff might get quite useful (tree-matrix vs tree-ndarray or otherwise), I'd suggest the Ito version of Milstein's method available here:
https://github.com/patrick-kidger/diffrax/blob/10b652e1d91518ac182e8d832ff309f7c199a9a0/diffrax/solver/milstein.py#L104
This is a pretty tricky implementation! It's very heavily annotated with comments describing the various tree-axes, normal-axes, and the way in which they interact.
+1 to the @sohl-dickstein use case. Some more detail on where this would be handy: I recently needed to invert a Hessian of a function that took a pytree as its argument. The headache I ran into was that when I used jax.jacfwd(jax.jacrev(f))(x) to compute the Hessian, I got it as a pytree of pytrees, which turned out to be pretty complicated to flatten. It would be nice to be able to either transform a pytree of pytrees to and from a matrix of floats, or to perform matrix operations directly on the pytree of pytrees.
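For anyone hitting the same thing, one workaround (a sketch only; g and its parameters are made up for illustration) is to flatten the argument with jax.flatten_util.ravel_pytree before differentiating, so the Hessian comes back as an ordinary matrix:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def g(params):
    return jnp.sum(params["w"] ** 2) + jnp.sum(jnp.sin(params["b"]))

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
flat, unravel = ravel_pytree(params)      # flat is a 1D array of length 3
g_flat = lambda x: g(unravel(x))

hess = jax.hessian(g_flat)(flat)          # ordinary (3, 3) matrix
step = jnp.linalg.solve(hess, jax.grad(g_flat)(flat))  # e.g. a Newton step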
Just wanted to chime in and say that I'd love this feature, and for my use cases (which are primarily about numerical solvers for non-convex problems) a single axis is all I'd need, though I'm sure I'd find uses for a multi-axis implementation if that does get developed.
So it's not a documented feature, but Equinox actually has a tree-math like sublibrary built-in, which can be used to do this kind of multi-axis stuff.
To set the scene, here is how it is used just to broadcast vector operations together:
from equinox.internal import ω
vector1 = [0, 1, (2, 3)]
vector2 = [4, 5, (6, 7)]
summed = (ω(vector1) + ω(vector2)).ω
# Alternate notation; I prefer this when doing pure arithmetic:
summed = (vector1**ω + vector2**ω).ω
print(summed) # [4, 6, (8, 10)]
But with a bit of thinking you can nest these to accomplish higher-order operations:
# matrix has shape (2, 3)
matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])])
# vector has shape (3,)
vector = ω([6, 7, 8])
# product (2, 3) @ (3,) -> (2,) ("call" applies the specified function to every leaf of its pytree)
matvec = matrix.call(lambda row: sum((row * vector).ω))
# unwrap
matvec = matvec.ω
print(matvec) # [23, 86]
The reason this works is that ω is not a PyTree. This means that matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])]) doesn't have the outer ω looking inside the inner ωs. (I believe tree-math's Vector is a PyTree and that the same trick wouldn't work in this library, though.)
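To spell out the pytree mechanics involved (Opaque below is just a stand-in for an unregistered wrapper like ω):

import jax

# Registered containers (like lists) flatten through nesting, losing the rows:
print(jax.tree_util.tree_leaves([[0, 1, 2], [3, 4, 5]]))  # [0, 1, 2, 3, 4, 5]

# An unregistered wrapper is treated as a single opaque leaf, so the row
# structure survives nesting -- which is what the trick above relies on.
class Opaque:
    def __init__(self, value):
        self.value = value

print(jax.tree_util.tree_leaves([Opaque([0, 1, 2]), Opaque([3, 4, 5])]))
# -> two leaves, one Opaque per row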
Conversely, this does mean that you mustn't pass ω objects across JIT/grad/etc. API boundaries. (Whilst you can with tree-math.) ω is only meant to be used as a convenient syntax within the bounds of a single function.
I do still think matrix support would be awesome to have, and I actually had a use case for this just last week.
That said, at this point I'm relatively unlikely to work on it. If somebody else wants to give this a try that would be very welcome!