Matrix support

Open shoyer opened this issue 2 years ago • 6 comments

I'd like to add a Matrix class to complement Vector.

A key design question is what this needs to support. In particular: do we need to support multiple axes that correspond to flattened pytrees, or is only a single axis enough?

If we only need to support a single "tree axis", then most Matrix operations can be implemented essentially by calling vmap on a Vector, and the implementation only needs to keep track of whether the "tree axis" on the underlying pytree is at the start or the end. This would suffice for use-cases like implementing L-BFGS or GMRES, which keep track of some fixed number of state vectors in the form of a matrix.
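
A minimal sketch of the single-tree-axis idea, written in plain JAX rather than against tree-math's own API. The names `tree_dot`, `stacked`, and `coeffs` are made up for illustration, and the "matrix" is assumed to be a pytree whose leaves each carry one extra leading axis:

```python
import jax
import jax.numpy as jnp


def tree_dot(a, b):
    """Inner product of two pytrees that share the same structure."""
    products = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(products))


# Hypothetical "matrix": three state vectors stacked along a new leading
# axis of every leaf. The tree axis is conceptually the second axis.
stacked = {"w": jnp.arange(6.0).reshape(3, 2), "b": jnp.array([1.0, 2.0, 3.0])}
vector = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(2.0)}

# (3, tree) @ (tree,) -> (3,): vmap the pytree inner product over the stack.
matvec = jax.vmap(tree_dot, in_axes=(0, None))(stacked, vector)

# (3,) @ (3, tree) -> (tree,): a linear combination of the stacked vectors.
coeffs = jnp.array([1.0, 0.0, 2.0])
combo = jax.tree_util.tree_map(
    lambda leaf: jnp.tensordot(coeffs, leaf, axes=1), stacked
)

print(matvec)  # one inner product per stacked vector
print(combo)   # same tree structure as `vector`
```

These two contractions (against the tree axis and against the stacked axis) are the main operations that algorithms like L-BFGS or GMRES need from such a matrix.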

In contrast, multiple "tree axes" would be required to fully support use cases where both the inputs and outputs of a linear map correspond to (possibly different) pytrees. For example, consider the outputs of jax.jacobian on a pytree -> pytree function. Here the implementation would need to be more complex to keep track of the separate tree definitions for inputs/outputs, similar to my first attempt at implementing a tree vectorizing transformation: https://github.com/google/jax/pull/3263.
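
To make the two-tree-axes case concrete, here is a small illustration of what jax.jacobian returns for a pytree -> pytree function. The function `f` and its keys are hypothetical, chosen only to show the nesting:

```python
import jax
import jax.numpy as jnp


def f(params):
    # Hypothetical pytree -> pytree function, for illustration only.
    return {"y": params["a"] * params["b"], "z": jnp.sum(params["b"])}


params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0, 4.0])}
jac = jax.jacobian(f)(params)

# The result is a pytree of pytrees: the outer keys follow the output tree,
# the inner keys follow the input tree. Each leaf has shape
# (output leaf shape) + (input leaf shape).
for out_key, inner in jac.items():
    for in_key, block in inner.items():
        print(out_key, in_key, block.shape)
```

A multi-axis Matrix would have to carry both tree definitions around in order to treat these four blocks as a single linear map.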

My inclination is to only implement the "single tree-axis" version of Matrix, with the reasoning being that it suffices to implement most "efficient" numerical algorithms on large-scale inputs, which cannot afford to use O(n^2) memory. On the other hand, it does preclude the interesting use-case of using tree-math to implement jax.jacobian (and variations).

shoyer avatar Dec 29 '21 23:12 shoyer

@Sohl-Dickstein seems to be in the camp of wanting support for multiple tree axes, to compute things like Jacobians, Hessians and covariances/correlations between pytrees.

shoyer avatar Dec 30 '21 18:12 shoyer

Weighing in here at @shoyer's request --

I'd put myself in the multiple-tree-axes camp as well, I think. (I am saying that primarily as an end user rather than a developer for the feature, though... !)

This starts to open up notions of "labelled tree axes". For example, performing the pytree-axes equivalent of jax.vmap(jax.vmap(operator.mul, in_axes=(0, None)), in_axes=(None, 0)) (an outer product), in which two different BatchTraces interact.
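
For reference, a sketch of that nested vmap on ordinary arrays, followed by one possible pytree analogue. The helper `tree_outer` is hypothetical (it is not part of JAX or tree-math) and simply nests two tree_map calls:

```python
import operator

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0])

# Inner vmap batches over x, outer vmap batches over y: two independent
# batch axes meet in a single multiplication.
outer = jax.vmap(jax.vmap(operator.mul, in_axes=(0, None)), in_axes=(None, 0))
result = outer(x, y)  # shape (2, 3), result[j, i] == x[i] * y[j]


def tree_outer(a, b):
    """Hypothetical pytree analogue: every leaf of `a` against every leaf of `b`."""
    return jax.tree_util.tree_map(
        lambda ai: jax.tree_util.tree_map(
            lambda bj: jnp.tensordot(ai, bj, axes=0), b
        ),
        a,
    )


nested = tree_outer({"p": x, "q": jnp.array([2.0])}, [y])
print(result.shape)
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, nested))
```

In the pytree version the result is a tree of trees, and nothing records which level came from which operand. That is the bookkeeping that labelled tree axes would provide.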

For a thorny reference problem in which this kind of stuff might get quite useful (tree-matrix vs tree-ndarray or otherwise), I'd suggest the Ito version of Milstein's method available here:

https://github.com/patrick-kidger/diffrax/blob/10b652e1d91518ac182e8d832ff309f7c199a9a0/diffrax/solver/milstein.py#L104

This is a pretty tricky implementation! It's very heavily annotated with comments describing the various tree-axes, normal-axes, and the way in which they interact.

patrick-kidger avatar Feb 07 '22 19:02 patrick-kidger

+1 to the @sohl-dickstein use case. Some more detail of where this would be handy: I recently needed to invert a Hessian of a function that took a pytree as its argument. The headache I ran into was that when I used jax.jacfwd(jax.jacrev(f))(x) to compute the Hessian, I got it as a pytree of pytrees, which turned out to be pretty complicated to flatten. It would be nice to be able to either transform a pytree of pytrees to and from a matrix of floats or to be able to perform matrix operations directly on the pytree of pytrees.
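
One workaround that avoids flattening the pytree of pytrees by hand is to flatten the input instead, using jax.flatten_util.ravel_pytree, and differentiate with respect to the flat vector. This is only a sketch under the assumption that a dense Hessian fits in memory; the function `f` below is a made-up example:

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(x):
    # Hypothetical scalar-valued function of a pytree.
    return jnp.sum(x["a"] ** 2) + x["a"][0] * x["b"]


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

# Flatten the input once; `unravel` maps a flat vector back to the pytree.
flat_x, unravel = ravel_pytree(x)

# Differentiating with respect to the flat vector gives an ordinary matrix.
hess = jax.jacfwd(jax.jacrev(lambda v: f(unravel(v))))(flat_x)  # shape (3, 3)

# Solve against the flattened gradient, then restore the tree structure.
grad_flat, _ = ravel_pytree(jax.grad(f)(x))
step = unravel(jnp.linalg.solve(hess, grad_flat))
print(hess.shape, step)
```

The cost is that the tree structure is lost while working with `hess`, which is exactly what a multi-axis Matrix would avoid.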

geoff-davis avatar Feb 15 '22 20:02 geoff-davis

Just wanted to chime in and say that I'd love this feature, and for my use cases (which are primarily about numerical solvers for non-convex problems) a single axis is all I'd need, though I'm sure I'd find uses for a multi-axis implementation if that does get developed.

njwfish avatar May 18 '23 14:05 njwfish

So it's not a documented feature, but Equinox actually has a tree-math-like sublibrary built in, which can be used to do this kind of multi-axis stuff.

To set the scene, here is how it is used just to broadcast vector operations together:

from equinox.internal import ω

vector1 = [0, 1, (2, 3)]
vector2 = [4, 5, (6, 7)]
summed = (ω(vector1) + ω(vector2)).ω
# Alternate notation; I prefer this when doing pure arithmetic:
summed = (vector1**ω + vector2**ω).ω
print(summed)  # [4, 6, (8, 10)]

But with a bit of thinking you can nest these to accomplish higher-order operations:

# matrix has shape (2, 3)
matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])])
# vector has shape (3,)
vector = ω([6, 7, 8])
# product (2, 3) @ (3,) -> (2,)     ("call" applies the specified function to every leaf of its pytree)
matvec = matrix.call(lambda row: sum((row * vector).ω))
# unwrap
matvec = matvec.ω
print(matvec)  # [23, 86]

The reason this works is that ω is not a PyTree. This means that matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])]) doesn't have the outer ω looking inside the inner ωs. (I believe tree-math's Vector is a PyTree and that the same trick wouldn't work in this library, though.)

Conversely, this does mean that you mustn't pass ω objects across JIT/grad/etc. API boundaries. (Whilst you can with tree-math.) ω is only meant to be used as a convenient syntax within the bounds of a single function.

patrick-kidger avatar May 18 '23 16:05 patrick-kidger

I do still think matrix support would be awesome to have, and I actually had a use case for this just last week.

That said, at this point I'm relatively unlikely to work on it. If somebody else wants to give this a try, that would be very welcome!

shoyer avatar May 19 '23 00:05 shoyer